From 57bae648c5e5fd77af3751d06559256b0011101f Mon Sep 17 00:00:00 2001 From: Default email Date: Fri, 11 Oct 2024 13:33:20 -0700 Subject: [PATCH 1/2] Project import generated by Copybara. GitOrigin-RevId: cff9e93824c1cfae9dff2628a1fb001972d32819 --- .bazelrc | 430 ++- .github/workflows/asan.yaml | 88 + .github/workflows/ci-build.yaml | 52 +- .github/workflows/cloud-tpu-ci-nightly.yml | 4 +- .github/workflows/jax-array-api.yml | 6 +- .github/workflows/metal_plugin_ci.yml | 2 +- .github/workflows/upstream-nightly.yml | 18 +- .github/workflows/wheel_win_x64.yml | 8 +- .github/workflows/windows_ci.yml | 8 +- .pre-commit-config.yaml | 2 +- CHANGELOG.md | 46 +- README.md | 4 +- build/build.py | 11 +- build/requirements_lock_3_13.txt | 249 +- build/rocm/Dockerfile.ms | 28 +- .../Dockerfile.manylinux_2_28_x86_64.rocm | 10 + build/rocm/ci_build | 10 +- build/rocm/docker/Dockerfile.jax-ubu22 | 12 +- build/rocm/docker/Dockerfile.jax-ubu24 | 10 +- build/rocm/tools/build_wheels.py | 74 +- build/rocm/tools/get_rocm.py | 3 +- docs/Custom_Operation_for_GPUs.md | 2 +- docs/autodidax.ipynb | 195 +- docs/autodidax.md | 195 +- docs/autodidax.py | 195 +- docs/ffi.ipynb | 44 +- docs/ffi.md | 44 +- docs/jax.experimental.host_callback.rst | 20 - docs/jax.experimental.pallas.mosaic_gpu.rst | 42 + docs/jax.experimental.pallas.rst | 20 +- docs/jax.experimental.pallas.tpu.rst | 16 + docs/jax.experimental.pallas.triton.rst | 22 + docs/jax.experimental.rst | 13 +- docs/jax.lax.rst | 11 + docs/jax.numpy.rst | 1 + docs/jax.rst | 3 +- docs/notebooks/shard_map.ipynb | 2 + docs/notebooks/shard_map.md | 2 + docs/pallas/CHANGELOG.md | 9 + docs/pallas/grid_blockspec.md | 8 +- docs/pallas/quickstart.ipynb | 2 + docs/pallas/quickstart.md | 2 + docs/pallas/tpu/details.rst | 2 + docs/pallas/tpu/distributed.ipynb | 6 +- docs/pallas/tpu/distributed.md | 6 +- docs/pallas/tpu/pipelining.ipynb | 2 + docs/pallas/tpu/pipelining.md | 2 + examples/ffi/CMakeLists.txt | 4 + examples/ffi/src/jax_ffi_example/attrs.cc | 66 + examples/ffi/src/jax_ffi_example/attrs.py | 47 + examples/ffi/src/jax_ffi_example/rms_norm.py | 7 +- examples/ffi/tests/attrs_test.py | 61 + jax/BUILD | 41 +- jax/_src/api.py | 45 +- jax/_src/array.py | 3 +- jax/_src/callback.py | 159 +- jax/_src/checkify.py | 29 +- jax/_src/cloud_tpu_init.py | 4 +- jax/_src/clusters/cloud_tpu_cluster.py | 2 +- jax/_src/compiler.py | 3 + jax/_src/config.py | 8 + jax/_src/core.py | 4 + jax/_src/deprecations.py | 1 + jax/_src/dispatch.py | 129 +- jax/_src/environment_info.py | 14 +- jax/_src/export/_export.py | 42 +- jax/_src/extend/ffi.py | 161 +- .../cpu_hessenberg_lapack_gehrd.py | 526 +++ .../cuda_eigh_cusolver_syev.py | 341 +- .../cuda_qr_cusolver_geqrf.py | 270 +- .../export_back_compat_test_util.py | 3 +- jax/_src/internal_test_util/test_harnesses.py | 121 +- jax/_src/interpreters/mlir.py | 100 +- jax/_src/interpreters/partial_eval.py | 12 +- jax/_src/interpreters/pxla.py | 59 +- jax/_src/interpreters/xla.py | 143 +- jax/_src/lax/control_flow/conditionals.py | 51 +- jax/_src/lax/convolution.py | 3 +- jax/_src/lax/fft.py | 46 +- jax/_src/lax/lax.py | 904 +++-- jax/_src/lax/linalg.py | 293 +- jax/_src/lax/other.py | 12 +- jax/_src/lax/parallel.py | 3 +- jax/_src/lax/slicing.py | 576 ++- jax/_src/lax/utils.py | 10 - jax/_src/lib/__init__.py | 31 +- jax/_src/mesh.py | 5 + jax/_src/nn/functions.py | 68 + jax/_src/numpy/array_methods.py | 15 + jax/_src/numpy/fft.py | 113 +- jax/_src/numpy/lax_numpy.py | 678 +++- jax/_src/numpy/reductions.py | 59 +- jax/_src/numpy/ufuncs.py | 612 +++- jax/_src/numpy/util.py | 6 +- jax/_src/ops/scatter.py | 4 +- jax/_src/pallas/core.py | 63 +- jax/_src/pallas/mosaic/BUILD | 7 + jax/_src/pallas/mosaic/core.py | 16 + jax/_src/pallas/mosaic/error_handling.py | 3 +- jax/_src/pallas/mosaic/lowering.py | 310 +- .../pallas/mosaic/pallas_call_registration.py | 2 +- jax/_src/pallas/mosaic/pipeline.py | 66 +- jax/_src/pallas/mosaic/primitives.py | 110 +- jax/_src/pallas/mosaic/random.py | 6 + jax/_src/pallas/mosaic_gpu/BUILD | 4 +- jax/_src/pallas/mosaic_gpu/__init__.py | 19 - jax/_src/pallas/mosaic_gpu/core.py | 234 +- jax/_src/pallas/mosaic_gpu/lowering.py | 477 ++- jax/_src/pallas/mosaic_gpu/primitives.py | 378 +- jax/_src/pallas/pallas_call.py | 283 +- jax/_src/pallas/primitives.py | 12 +- jax/_src/pallas/triton/lowering.py | 8 +- jax/_src/pjit.py | 110 +- jax/_src/sharding_impls.py | 15 + jax/_src/state/discharge.py | 45 +- jax/_src/state/indexing.py | 7 + jax/_src/state/primitives.py | 22 +- jax/_src/state/types.py | 83 +- jax/_src/tpu_custom_call.py | 22 +- jax/experimental/host_callback.py | 2128 +----------- .../examples/serving/model_server_request.py | 2 +- .../tf_js/quickdraw/input_pipeline.py | 2 +- jax/experimental/jax2tf/impl_no_xla.py | 3 - jax/experimental/jax2tf/jax2tf.py | 25 +- jax/experimental/jax2tf/tests/jax2tf_test.py | 2 + .../jax2tf/tests/primitives_test.py | 6 +- .../jax2tf/tests/shape_poly_test.py | 8 +- jax/experimental/mosaic/gpu/__init__.py | 1 + .../mosaic/gpu/fragmented_array.py | 26 +- jax/experimental/mosaic/gpu/utils.py | 92 +- jax/experimental/pallas/__init__.py | 9 +- jax/experimental/pallas/gpu.py | 14 +- jax/experimental/pallas/mosaic_gpu.py | 39 + jax/experimental/pallas/ops/gpu/attention.py | 2 +- .../pallas/ops/gpu/decode_attention.py | 2 +- jax/experimental/pallas/ops/gpu/layer_norm.py | 2 +- jax/experimental/pallas/ops/gpu/rms_norm.py | 2 +- jax/experimental/pallas/ops/gpu/softmax.py | 2 +- jax/experimental/pallas/tpu.py | 2 + jax/experimental/pallas/triton.py | 20 + jax/experimental/rnn.py | 6 +- jax/experimental/shard_map.py | 15 +- jax/experimental/sparse/bcoo.py | 11 +- jax/experimental/sparse/bcsr.py | 8 +- jax/experimental/sparse/util.py | 3 +- jax/lax/__init__.py | 6 +- jax/lib/xla_client.py | 75 +- jax/numpy/__init__.py | 1 + jax/numpy/__init__.pyi | 3 +- jax/sharding.py | 19 +- jax/tools/jax_to_ir.py | 2 +- jax/version.py | 4 +- jax_plugins/BUILD.bazel | 4 + jaxlib/BUILD | 1 + jaxlib/ducc_fft.py | 92 - jaxlib/gpu/BUILD | 1 - jaxlib/gpu/solver_kernels_ffi.cc | 260 +- jaxlib/gpu/vendor.h | 6 +- jaxlib/gpu_solver.py | 257 +- jaxlib/jax.bzl | 1 + jaxlib/lapack.py | 406 +-- jaxlib/mosaic/dialect/tpu/tpu.td | 16 +- jaxlib/mosaic/dialect/tpu/tpu_dialect.h | 6 +- .../tpu/transforms/apply_vector_layout.cc | 147 +- .../tpu/transforms/canonicalize_mosaic.cc | 21 +- .../tpu/transforms/infer_memref_layout.cc | 113 +- .../tpu/transforms/infer_memref_layout.h | 5 +- .../tpu/transforms/infer_vector_layout.cc | 39 +- jaxlib/rocm_plugin_extension.cc | 26 +- jaxlib/tools/build_wheel.py | 1 - pyproject.toml | 25 +- setup.py | 6 +- tests/BUILD | 136 +- tests/api_test.py | 24 + tests/checkify_test.py | 19 + tests/compilation_cache_test.py | 13 +- tests/cudnn_fusion_test.py | 10 +- tests/custom_object_test.py | 348 -- tests/export_back_compat_test.py | 160 +- tests/extend_test.py | 77 +- tests/fft_test.py | 14 +- tests/host_callback_test.py | 3089 ----------------- tests/host_callback_to_tf_test.py | 279 -- tests/infeed_test.py | 3 - tests/jax_to_ir_test.py | 9 - tests/lax_metal_test.py | 16 +- tests/lax_numpy_indexing_test.py | 22 +- tests/lax_numpy_operators_test.py | 30 + tests/lax_numpy_test.py | 40 +- tests/lax_test.py | 645 ++-- tests/lax_vmap_test.py | 35 +- tests/layout_test.py | 33 +- tests/linalg_test.py | 2 +- tests/logging_test.py | 4 + tests/memories_test.py | 131 +- tests/mosaic/gpu_test.py | 17 + tests/mutable_array_test.py | 9 + tests/nn_test.py | 59 + tests/pallas/BUILD | 90 +- tests/pallas/mosaic_gpu_test.py | 398 ++- tests/pallas/ops_test.py | 336 +- tests/pallas/pallas_test.py | 80 +- tests/pallas/tpu_ops_test.py | 25 + tests/pallas/tpu_pallas_async_test.py | 119 + tests/pallas/tpu_pallas_mesh_test.py | 107 - tests/pallas/tpu_pallas_random_test.py | 34 + tests/pallas/tpu_pallas_state_test.py | 271 ++ tests/pjit_test.py | 135 +- tests/python_callback_test.py | 69 +- tests/shape_poly_test.py | 111 +- tests/shard_map_test.py | 3 + tests/state_test.py | 54 + tests/tree_util_test.py | 7 +- tests/xla_bridge_test.py | 14 - third_party/xla/workspace.bzl | 4 +- 215 files changed, 10876 insertions(+), 10317 deletions(-) create mode 100644 .github/workflows/asan.yaml delete mode 100644 docs/jax.experimental.host_callback.rst create mode 100644 docs/jax.experimental.pallas.mosaic_gpu.rst create mode 100644 docs/jax.experimental.pallas.tpu.rst create mode 100644 docs/jax.experimental.pallas.triton.rst create mode 100644 examples/ffi/src/jax_ffi_example/attrs.cc create mode 100644 examples/ffi/src/jax_ffi_example/attrs.py create mode 100644 examples/ffi/tests/attrs_test.py create mode 100644 jax/_src/internal_test_util/export_back_compat_test_data/cpu_hessenberg_lapack_gehrd.py create mode 100644 jax/experimental/pallas/mosaic_gpu.py create mode 100644 jax/experimental/pallas/triton.py delete mode 100644 jaxlib/ducc_fft.py delete mode 100644 tests/custom_object_test.py delete mode 100644 tests/host_callback_test.py delete mode 100644 tests/host_callback_to_tf_test.py delete mode 100644 tests/pallas/tpu_pallas_mesh_test.py create mode 100644 tests/pallas/tpu_pallas_state_test.py diff --git a/.bazelrc b/.bazelrc index 5b7bc653373b..aebd12596a33 100644 --- a/.bazelrc +++ b/.bazelrc @@ -1,43 +1,82 @@ -############################################################################ -# All default build options below. - -# Required by OpenXLA -# https://github.com/openxla/xla/issues/1323 -build --nocheck_visibility - -# Sets the default Apple platform to macOS. -build --apple_platform_type=macos -build --macos_minimum_os=10.14 - +# ############################################################################# +# All default build options below. These apply to all build commands. +# ############################################################################# # Make Bazel print out all options from rc files. build --announce_rc -build --define open_source_build=true - -build --spawn_strategy=standalone +# By default, execute all actions locally. +build --spawn_strategy=local +# Enable host OS specific configs. For instance, "build:linux" will be used +# automatically when building on Linux. build --enable_platform_specific_config build --experimental_cc_shared_library -# Disable enabled-by-default TensorFlow features that we don't care about. -build --define=no_aws_support=true -build --define=no_gcp_support=true -build --define=no_hdfs_support=true -build --define=no_kafka_support=true -build --define=no_ignite_support=true - +# Do not use C-Ares when building gRPC. build --define=grpc_no_ares=true build --define=tsl_link_protobuf=true +# Enable optimization. build -c opt -build --config=short_logs +# Suppress all warning messages. +build --output_filter=DONT_MATCH_ANYTHING build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir. -########################################################################### +# ############################################################################# +# Platform Specific configs below. These are automatically picked up by Bazel +# depending on the platform that is running the build. +# ############################################################################# +build:linux --config=posix +build:linux --copt=-Wno-unknown-warning-option + +# Workaround for gcc 10+ warnings related to upb. +# See https://github.com/tensorflow/tensorflow/issues/39467 +build:linux --copt=-Wno-stringop-truncation +build:linux --copt=-Wno-array-parameter + +build:macos --config=posix +build:macos --apple_platform_type=macos + +# Windows has a relatively short command line limit, which JAX has begun to hit. +# See https://docs.bazel.build/versions/main/windows.html +build:windows --features=compiler_param_file +build:windows --features=archive_param_file + +# XLA uses M_* math constants that only get defined by MSVC headers if +# _USE_MATH_DEFINES is defined. +build:windows --copt=/D_USE_MATH_DEFINES +build:windows --host_copt=/D_USE_MATH_DEFINES +# Make sure to include as little of windows.h as possible +build:windows --copt=-DWIN32_LEAN_AND_MEAN +build:windows --host_copt=-DWIN32_LEAN_AND_MEAN +build:windows --copt=-DNOGDI +build:windows --host_copt=-DNOGDI +# https://devblogs.microsoft.com/cppblog/announcing-full-support-for-a-c-c-conformant-preprocessor-in-msvc/ +# otherwise, there will be some compiling error due to preprocessing. +build:windows --copt=/Zc:preprocessor +build:windows --cxxopt=/std:c++17 +build:windows --host_cxxopt=/std:c++17 +# Generate PDB files, to generate useful PDBs, in opt compilation_mode +# --copt /Z7 is needed. +build:windows --linkopt=/DEBUG +build:windows --host_linkopt=/DEBUG +build:windows --linkopt=/OPT:REF +build:windows --host_linkopt=/OPT:REF +build:windows --linkopt=/OPT:ICF +build:windows --host_linkopt=/OPT:ICF +build:windows --incompatible_strict_action_env=true + +# ############################################################################# +# Feature-specific configurations. These are used by the Local and CI configs +# below depending on the type of build. E.g. `local_linux_x86_64` inherits the +# Linux x86 configs such as `avx_linux` and `mkl_open_source_only`, +# `local_cuda_base` inherits `cuda` and `build_cuda_with_nvcc`, etc. +# ############################################################################# +build:nonccl --define=no_nccl_support=true build:posix --copt=-fvisibility=hidden build:posix --copt=-Wno-sign-compare @@ -47,13 +86,13 @@ build:posix --host_cxxopt=-std=c++17 build:avx_posix --copt=-mavx build:avx_posix --host_copt=-mavx -build:avx_windows --copt=/arch=AVX +build:native_arch_posix --copt=-march=native +build:native_arch_posix --host_copt=-march=native build:avx_linux --copt=-mavx build:avx_linux --host_copt=-mavx -build:native_arch_posix --copt=-march=native -build:native_arch_posix --host_copt=-march=native +build:avx_windows --copt=/arch:AVX build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1 @@ -66,6 +105,7 @@ build:clang --copt=-Wno-gnu-offsetof-extensions # Disable clang extention that rejects unknown arguments. build:clang --copt=-Qunused-arguments +# Configs for CUDA build:cuda --repo_env TF_NEED_CUDA=1 build:cuda --repo_env TF_NCCL_USE_STUB=1 # "sm" means we emit only cubin, which is forward compatible within a GPU generation. @@ -74,9 +114,14 @@ build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_8 build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --@local_config_cuda//:enable_cuda build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true + # Default hermetic CUDA and CUDNN versions. build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" + +# This flag is needed to include CUDA libraries for bazel tests. +test:cuda --@local_config_cuda//cuda:include_cuda_libs=true + # Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries, # ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to # point to the $ORIGIN-relative location of the pip-installed NVIDIA CUDA @@ -91,167 +136,180 @@ build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" # The list of CUDA pip packages that JAX depends on are present in setup.py. build:cuda --linkopt=-Wl,--disable-new-dtags -# This flag is needed to include CUDA libraries for bazel tests. -test:cuda --@local_config_cuda//cuda:include_cuda_libs=true - -build:cuda_clang --config=clang -build:cuda_clang --@local_config_cuda//:cuda_compiler=clang -build:cuda_clang --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" - -# Build with NVCC for CUDA -build:cuda_nvcc --config=cuda -build:cuda_nvcc --config=clang -build:cuda_nvcc --@local_config_cuda//:cuda_compiler=nvcc -build:cuda_nvcc --action_env=TF_NVCC_CLANG="1" -build:cuda_nvcc --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" - -build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain -build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true -build:rocm --repo_env TF_NEED_ROCM=1 -build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1030" +# Build CUDA and other C++ targets with Clang +build:build_cuda_with_clang --@local_config_cuda//:cuda_compiler=clang -build:nonccl --define=no_nccl_support=true +# Build CUDA with NVCC and other C++ targets with Clang +build:build_cuda_with_nvcc --action_env=TF_NVCC_CLANG="1" +build:build_cuda_with_nvcc --@local_config_cuda//:cuda_compiler=nvcc # Requires MSVC and LLVM to be installed build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl build:win_clang --extra_execution_platforms=//jax/tools/toolchains:x64_windows-clang-cl build:win_clang --compiler=clang-cl -# Windows has a relatively short command line limit, which JAX has begun to hit. -# See https://docs.bazel.build/versions/main/windows.html -build:windows --features=compiler_param_file -build:windows --features=archive_param_file - -# Tensorflow uses M_* math constants that only get defined by MSVC headers if -# _USE_MATH_DEFINES is defined. -build:windows --copt=/D_USE_MATH_DEFINES -build:windows --host_copt=/D_USE_MATH_DEFINES -# Make sure to include as little of windows.h as possible -build:windows --copt=-DWIN32_LEAN_AND_MEAN -build:windows --host_copt=-DWIN32_LEAN_AND_MEAN -build:windows --copt=-DNOGDI -build:windows --host_copt=-DNOGDI -# https://devblogs.microsoft.com/cppblog/announcing-full-support-for-a-c-c-conformant-preprocessor-in-msvc/ -# otherwise, there will be some compiling error due to preprocessing. -build:windows --copt=/Zc:preprocessor -build:windows --cxxopt=/std:c++17 -build:windows --host_cxxopt=/std:c++17 -# Generate PDB files, to generate useful PDBs, in opt compilation_mode -# --copt /Z7 is needed. -build:windows --linkopt=/DEBUG -build:windows --host_linkopt=/DEBUG -build:windows --linkopt=/OPT:REF -build:windows --host_linkopt=/OPT:REF -build:windows --linkopt=/OPT:ICF -build:windows --host_linkopt=/OPT:ICF -build:windows --incompatible_strict_action_env=true - -build:linux --config=posix -build:linux --copt=-Wno-unknown-warning-option -# Workaround for gcc 10+ warnings related to upb. -# See https://github.com/tensorflow/tensorflow/issues/39467 -build:linux --copt=-Wno-stringop-truncation -build:linux --copt=-Wno-array-parameter - -build:macos --config=posix - -# Public cache for macOS builds. The "oct2023" in the URL is just the -# date when the bucket was created and can be disregarded. It still contains the -# latest cache that is being used. +build:rocm_base --crosstool_top=@local_config_rocm//crosstool:toolchain +build:rocm_base --define=using_rocm=true --define=using_rocm_hipcc=true +build:rocm_base --repo_env TF_NEED_ROCM=1 +build:rocm_base --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100" + +# Build with hipcc for ROCm and clang for the host. +build:rocm --config=rocm_base +build:rocm --action_env=TF_ROCM_CLANG="1" +build:rocm --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" +build:rocm --copt=-Wno-gnu-offsetof-extensions +build:rocm --copt=-Qunused-arguments +build:rocm --action_env=TF_HIPCC_CLANG="1" + +# ############################################################################# +# Cache options below. +# ############################################################################# +# Public read-only cache for Mac builds. JAX uses a GCS bucket to store cache +# from JAX's Mac CI build. By applying --config=macos_cache, any local Mac build +# should be able to read from this cache and potentially see a speedup. The +# "oct2023" in the URL is just the date when the bucket was created and can be +# disregarded. It still contains the latest cache that is being used. build:macos_cache --remote_cache="https://storage.googleapis.com/tensorflow-macos-bazel-cache/oct2023" --remote_upload_local_results=false -# Cache pushes are limited to Jax's CI system. -build:macos_cache_push --config=macos_cache --remote_upload_local_results=true --google_default_credentials -# Suppress all warning messages. -build:short_logs --output_filter=DONT_MATCH_ANYTHING +# Cache pushes are limited to JAX's CI system. +build:macos_cache_push --config=macos_cache --remote_upload_local_results=true --google_default_credentials -######################################################################### -# RBE config options below. +# ############################################################################# +# CI Build config options below. +# JAX uses these configs in CI builds for building artifacts and when running +# Bazel tests. +# ############################################################################# +# Linux x86 CI configs +build:ci_linux_x86_64 --config=avx_linux --config=avx_posix +build:ci_linux_x86_64 --config=mkl_open_source_only +build:ci_linux_x86_64 --config=clang --verbose_failures=true + +# TODO(b/356695103): We do not have a CPU only toolchain so we use the CUDA +# toolchain for both CPU and GPU builds. +build:ci_linux_x86_64 --host_crosstool_top="@local_config_cuda//crosstool:toolchain" +build:ci_linux_x86_64 --crosstool_top="@local_config_cuda//crosstool:toolchain" +build:ci_linux_x86_64 --extra_toolchains="@local_config_cuda//crosstool:toolchain-linux-x86_64" +build:ci_linux_x86_64 --repo_env=TF_SYSROOT="/dt9" + +# Clang path needs to be set for remote toolchain to be configured correctly. +build:ci_linux_x86_64 --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" + +# The toolchain in `--config=cuda` needs to be read before the toolchain in +# `--config=ci_linux_x86_64`. Otherwise, we run into issues with manylinux +# compliance. +build:ci_linux_x86_64_cuda --config=cuda --config=build_cuda_with_nvcc +build:ci_linux_x86_64_cuda --config=ci_linux_x86_64 + +# Linux Aarch64 CI configs +build:ci_linux_aarch64_base --config=clang --verbose_failures=true +build:ci_linux_aarch64_base --action_env=TF_SYSROOT="/dt10" + +build:ci_linux_aarch64 --config=ci_linux_aarch64_base +build:ci_linux_aarch64 --host_crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain" +build:ci_linux_aarch64 --crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain" + +# CUDA configs for Linux Aarch64 do not pass in the crosstool_top flag from +# above because the Aarch64 toolchain rule does not support building with NVCC. +# Instead, we use `@local_config_cuda//crosstool:toolchain` from --config=cuda +# and set `CLANG_CUDA_COMPILER_PATH` to define the toolchain so that we can +# use Clang for the C++ targets and NVCC to build CUDA targets. +build:ci_linux_aarch64_cuda --config=ci_linux_aarch64_base +build:ci_linux_aarch64_cuda --config=cuda --config=build_cuda_with_nvcc +build:ci_linux_aarch64_cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" + +# Mac x86 CI configs +build:ci_darwin_x86_64 --macos_minimum_os=10.14 +build:ci_darwin_x86_64 --config=macos_cache_push +build:ci_darwin_x86_64 --verbose_failures=true + +# Mac Arm64 CI configs +build:ci_darwin_arm64 --macos_minimum_os=11.0 +build:ci_darwin_arm64 --config=macos_cache_push +build:ci_darwin_arm64 --verbose_failures=true + +# Windows x86 CI configs +build:ci_windows_amd64 --config=avx_windows +build:ci_windows_amd64 --compiler=clang-cl --config=clang --verbose_failures=true +build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win/20240424:toolchain" +build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win/20240424:cc-toolchain-x64_windows-clang-cl" +build:ci_windows_amd64 --host_linkopt=/FORCE:MULTIPLE --linkopt=/FORCE:MULTIPLE + +# ############################################################################# +# RBE config options below. These inherit the CI configs above and set the +# remote execution backend and authentication options required to run builds +# with RBE. Linux x86 and Windows builds use RBE. +# ############################################################################# # Flag to enable remote config common --experimental_repo_remote_exec +# Allow creation of resultstore URLs for any bazel invocation +build:resultstore --google_default_credentials +build:resultstore --bes_backend=buildeventservice.googleapis.com +build:resultstore --bes_instance_name="tensorflow-testing" +build:resultstore --bes_results_url="https://source.cloud.google.com/results/invocations" +build:resultstore --bes_timeout=600s + +build:rbe --config=resultstore build:rbe --repo_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1 -build:rbe --google_default_credentials -build:rbe --bes_backend=buildeventservice.googleapis.com -build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations" -build:rbe --bes_timeout=600s build:rbe --define=EXECUTOR=remote build:rbe --flaky_test_attempts=3 build:rbe --jobs=200 build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com build:rbe --remote_timeout=3600 build:rbe --spawn_strategy=remote,worker,standalone,local -test:rbe --test_env=USER=anon # Attempt to minimize the amount of data transfer between bazel and the remote # workers: build:rbe --remote_download_toplevel +test:rbe --test_env=USER=anon -build:rbe_linux --config=rbe -build:rbe_linux --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin" -build:rbe_linux --host_javabase=@bazel_toolchains//configs/ubuntu16_04_clang/1.1:jdk8 -build:rbe_linux --javabase=@bazel_toolchains//configs/ubuntu16_04_clang/1.1:jdk8 -build:rbe_linux --host_java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8 -build:rbe_linux --java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8 - -# Non-rbe settings we should include because we do not run configure -build:rbe_linux --config=avx_linux -build:rbe_linux --linkopt=-lrt -build:rbe_linux --host_linkopt=-lrt -build:rbe_linux --linkopt=-lm -build:rbe_linux --host_linkopt=-lm - -# Use the GPU toolchain until the CPU one is ready. -# https://github.com/bazelbuild/bazel/issues/13623 -build:rbe_cpu_linux_base --config=rbe_linux -build:rbe_cpu_linux_base --config=cuda_clang -build:rbe_cpu_linux_base --host_crosstool_top="@local_config_cuda//crosstool:toolchain" -build:rbe_cpu_linux_base --crosstool_top="@local_config_cuda//crosstool:toolchain" -build:rbe_cpu_linux_base --extra_toolchains="@local_config_cuda//crosstool:toolchain-linux-x86_64" -build:rbe_cpu_linux_base --repo_env=TF_SYSROOT="/dt9" -build:rbe_cpu_linux_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" -build:rbe_cpu_linux_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" -build:rbe_cpu_linux_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" - -build:rbe_cpu_linux_py3.10 --config=rbe_cpu_linux_base -build:rbe_cpu_linux_py3.10 --repo_env HERMETIC_PYTHON_VERSION="3.10" -build:rbe_cpu_linux_py3.11 --config=rbe_cpu_linux_base -build:rbe_cpu_linux_py3.11 --repo_env HERMETIC_PYTHON_VERSION="3.11" -build:rbe_cpu_linux_py3.12 --config=rbe_cpu_linux_base -build:rbe_cpu_linux_py3.12 --repo_env HERMETIC_PYTHON_VERSION="3.12" -build:rbe_cpu_linux_py3.13 --config=rbe_cpu_linux_base -build:rbe_cpu_linux_py3.13 --repo_env HERMETIC_PYTHON_VERSION="3.13" - -build:rbe_linux_cuda_base --config=rbe_linux -build:rbe_linux_cuda_base --config=cuda -build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1 - -build:rbe_linux_cuda12.3_nvcc_base --config=rbe_linux_cuda_base -build:rbe_linux_cuda12.3_nvcc_base --config=cuda_nvcc -build:rbe_linux_cuda12.3_nvcc_base --repo_env=HERMETIC_CUDA_VERSION="12.3.2" -build:rbe_linux_cuda12.3_nvcc_base --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" -build:rbe_linux_cuda12.3_nvcc_base --host_crosstool_top="@local_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda12.3_nvcc_base --crosstool_top="@local_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda12.3_nvcc_base --extra_toolchains="@local_config_cuda//crosstool:toolchain-linux-x86_64" -build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_SYSROOT="/dt9" -build:rbe_linux_cuda12.3_nvcc_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" -build:rbe_linux_cuda12.3_nvcc_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" -build:rbe_linux_cuda12.3_nvcc_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" -build:rbe_linux_cuda12.3_nvcc_py3.10 --config=rbe_linux_cuda12.3_nvcc_base -build:rbe_linux_cuda12.3_nvcc_py3.10 --repo_env HERMETIC_PYTHON_VERSION="3.10" -build:rbe_linux_cuda12.3_nvcc_py3.11 --config=rbe_linux_cuda12.3_nvcc_base -build:rbe_linux_cuda12.3_nvcc_py3.11 --repo_env HERMETIC_PYTHON_VERSION="3.11" -build:rbe_linux_cuda12.3_nvcc_py3.12 --config=rbe_linux_cuda12.3_nvcc_base -build:rbe_linux_cuda12.3_nvcc_py3.12 --repo_env HERMETIC_PYTHON_VERSION="3.12" -build:rbe_linux_cuda12.3_nvcc_py3.13 --config=rbe_linux_cuda12.3_nvcc_base -build:rbe_linux_cuda12.3_nvcc_py3.13 --repo_env HERMETIC_PYTHON_VERSION="3.13" - -# These you may need to change for your own GCP project. -build:tensorflow_testing_rbe --project_id=tensorflow-testing -common:tensorflow_testing_rbe_linux --remote_instance_name=projects/tensorflow-testing/instances/default_instance -build:tensorflow_testing_rbe_linux --config=tensorflow_testing_rbe +# RBE configs for Linux x86 +# Set the remote worker pool +common:rbe_linux_x86_64_base --remote_instance_name=projects/tensorflow-testing/instances/default_instance -# START CROSS-COMPILE CONFIGS +build:rbe_linux_x86_64_base --config=rbe +build:rbe_linux_x86_64_base --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin" +build:rbe_linux_x86_64_base --linkopt=-lrt +build:rbe_linux_x86_64_base --host_linkopt=-lrt +build:rbe_linux_x86_64_base --linkopt=-lm +build:rbe_linux_x86_64_base --host_linkopt=-lm +# Set the host, execution, and target platform +build:rbe_linux_x86_64_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" +build:rbe_linux_x86_64_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" +build:rbe_linux_x86_64_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" + +build:rbe_linux_x86_64 --config=rbe_linux_x86_64_base +build:rbe_linux_x86_64 --config=ci_linux_x86_64 + +build:rbe_linux_x86_64_cuda --config=rbe_linux_x86_64_base +build:rbe_linux_x86_64_cuda --config=ci_linux_x86_64_cuda +build:rbe_linux_x86_64_cuda --repo_env=REMOTE_GPU_TESTING=1 + +# RBE configs for Windows +# Set the remote worker pool +common:rbe_windows_amd64 --remote_instance_name=projects/tensorflow-testing/instances/windows + +build:rbe_windows_amd64 --config=rbe + +# Set the host, execution, and target platform +build:rbe_windows_amd64 --host_platform="@xla//tools/toolchains/win:x64_windows-clang-cl" +build:rbe_windows_amd64 --extra_execution_platforms="@xla//tools/toolchains/win:x64_windows-clang-cl" +build:rbe_windows_amd64 --platforms="@xla//tools/toolchains/win:x64_windows-clang-cl" + +build:rbe_windows_amd64 --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe +build:rbe_windows_amd64 --enable_runfiles +build:rbe_windows_amd64 --define=override_eigen_strong_inline=true + +# Don't build the python zip archive in the RBE build. +build:rbe_windows_amd64 --nobuild_python_zip + +build:rbe_windows_amd64 --config=ci_windows_amd64 + +# ############################################################################# +# Cross-compile config options below. Native RBE support does not exist for +# Linux Aarch64 and Mac x86. So, we use a cross-compile toolchain to build +# targets for Linux Aarch64 and Mac x86 on the Linux x86 RBE pool. +# ############################################################################# # Set execution platform to Linux x86 # Note: Lot of the "host_" flags such as "host_cpu" and "host_crosstool_top" # flags seem to be actually used to specify the execution platform details. It @@ -261,48 +319,43 @@ build:cross_compile_base --host_cpu=k8 build:cross_compile_base --host_crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite build:cross_compile_base --extra_execution_platforms=@xla//tools/toolchains/cross_compile/config:linux_x86_64 -# START LINUX AARCH64 CROSS-COMPILE CONFIGS -build:cross_compile_linux_arm64 --config=cross_compile_base +# Linux Aarch64 +build:cross_compile_linux_aarch64 --config=cross_compile_base # Set the target CPU to Aarch64 -build:cross_compile_linux_arm64 --platforms=@xla//tools/toolchains/cross_compile/config:linux_aarch64 -build:cross_compile_linux_arm64 --cpu=aarch64 -build:cross_compile_linux_arm64 --crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite +build:cross_compile_linux_aarch64 --platforms=@xla//tools/toolchains/cross_compile/config:linux_aarch64 +build:cross_compile_linux_aarch64 --cpu=aarch64 +build:cross_compile_linux_aarch64 --crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite build:rbe_cross_compile_base --config=rbe +build:rbe_cross_compile_base --remote_instance_name=projects/tensorflow-testing/instances/default_instance # RBE cross-compile configs for Linux Aarch64 -build:rbe_cross_compile_linux_arm64 --config=cross_compile_linux_arm64 -build:rbe_cross_compile_linux_arm64 --config=rbe_cross_compile_base -# END LINUX AARCH64 CROSS-COMPILE CONFIGS +build:rbe_cross_compile_linux_aarch64 --config=cross_compile_linux_aarch64 +build:rbe_cross_compile_linux_aarch64 --config=rbe_cross_compile_base -# START MACOS CROSS-COMPILE CONFIGS -build:cross_compile_macos_x86 --config=cross_compile_base -build:cross_compile_macos_x86 --config=nonccl +# Mac x86 +build:cross_compile_darwin_x86_64 --config=cross_compile_base +build:cross_compile_darwin_x86_64 --config=nonccl # Target Catalina (10.15) as the minimum supported OS -build:cross_compile_macos_x86 --action_env MACOSX_DEPLOYMENT_TARGET=10.15 +build:cross_compile_darwin_x86_64 --action_env MACOSX_DEPLOYMENT_TARGET=10.15 # Set the target CPU to Darwin x86 -build:cross_compile_macos_x86 --platforms=@xla//tools/toolchains/cross_compile/config:darwin_x86_64 -build:cross_compile_macos_x86 --cpu=darwin -build:cross_compile_macos_x86 --crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite +build:cross_compile_darwin_x86_64 --platforms=@xla//tools/toolchains/cross_compile/config:darwin_x86_64 +build:cross_compile_darwin_x86_64 --cpu=darwin +build:cross_compile_darwin_x86_64 --crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite # When RBE cross-compiling for macOS, we need to explicitly register the # toolchain. Otherwise, oddly, RBE complains that a "docker container must be # specified". -build:cross_compile_macos_x86 --extra_toolchains=@xla//tools/toolchains/cross_compile/config:macos-x86-cross-compile-cc-toolchain +build:cross_compile_darwin_x86_64 --extra_toolchains=@xla//tools/toolchains/cross_compile/config:macos-x86-cross-compile-cc-toolchain # Map --platforms=darwin_x86_64 to --cpu=darwin and vice-versa to make selects() # and transistions that use these flags work. The flag --platform_mappings needs # to be set to a file that exists relative to the package path roots. -build:cross_compile_macos_x86 --platform_mappings=platform_mappings +build:cross_compile_darwin_x86_64 --platform_mappings=platform_mappings # RBE cross-compile configs for Darwin x86 -build:rbe_cross_compile_macos_x86 --config=cross_compile_macos_x86 -build:rbe_cross_compile_macos_x86 --config=rbe_cross_compile_base -# END MACOS CROSS-COMPILE CONFIGS - -# END CROSS-COMPILE CONFIGS - -############################################################################# +build:rbe_cross_compile_darwin_x86_64 --config=cross_compile_darwin_x86_64 +build:rbe_cross_compile_darwin_x86_64 --config=rbe_cross_compile_base ############################################################################# # Some configs to make getting some forms of debug builds. In general, the @@ -327,3 +380,10 @@ try-import %workspace%/.jax_configure.bazelrc # Load rc file with user-specific options. try-import %workspace%/.bazelrc.user + +# Temporary aliases to not break existing presubmit builds +build:rbe_cpu_linux_py3.13 --config=rbe_linux_x86_64 --repo_env=HERMETIC_PYTHON_VERSION=3.13 +build:rbe_linux_cuda12.3_nvcc_py3.10 --config=rbe_linux_x86_64_cuda --repo_env=HERMETIC_PYTHON_VERSION=3.10 +build:rbe_cross_compile_linux_arm64 --config=rbe_cross_compile_linux_aarch64 +build:tensorflow_testing_rbe_linux --project_id=tensorflow-testing +common:tensorflow_testing_rbe_linux --remote_instance_name=projects/tensorflow-testing/instances/default_instance diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml new file mode 100644 index 000000000000..a4ec78f96c97 --- /dev/null +++ b/.github/workflows/asan.yaml @@ -0,0 +1,88 @@ +name: CI - Address Sanitizer (nightly) + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +on: + schedule: + - cron: "0 12 * * *" # Daily at 12:00 UTC + workflow_dispatch: # allows triggering the workflow run manually + pull_request: # Automatically trigger on pull requests affecting this file + branches: + - main + paths: + - '**/workflows/asan.yml' + +jobs: + asan: + runs-on: linux-x86-n2-64 + container: + image: index.docker.io/library/ubuntu@sha256:b359f1067efa76f37863778f7b6d0e8d911e3ee8efa807ad01fbf5dc1ef9006b # ratchet:ubuntu:24.04 + strategy: + fail-fast: false + defaults: + run: + shell: bash -l {0} + steps: + - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + with: + path: jax + - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + with: + repository: python/cpython + path: cpython + ref: v3.13.0 + - name: Install clang 18 + env: + DEBIAN_FRONTEND: noninteractive + run: | + apt update + apt install -y clang-18 libstdc++-14-dev build-essential libssl-dev \ + zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \ + libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \ + libffi-dev liblzma-dev + - name: Build CPython with ASAN enabled + env: + ASAN_OPTIONS: detect_leaks=0 + run: | + cd cpython + mkdir ${GITHUB_WORKSPACE}/cpythonasan + CC=clang-18 CXX=clang++-18 ./configure --prefix ${GITHUB_WORKSPACE}/cpythonasan --with-address-sanitizer --without-pymalloc + make -j64 + make install + ${GITHUB_WORKSPACE}/cpythonasan/bin/python3 -m venv ${GITHUB_WORKSPACE}/venv + - name: Install JAX test requirements + env: + ASAN_OPTIONS: detect_leaks=0 + run: | + source ${GITHUB_WORKSPACE}/venv/bin/activate + cd jax + pip install -r build/test-requirements.txt + - name: Build and install JAX + env: + ASAN_OPTIONS: detect_leaks=0 + run: | + source ${GITHUB_WORKSPACE}/venv/bin/activate + cd jax + python build/build.py \ + --bazel_options=--color=yes \ + --bazel_options=--copt=-fsanitize=address \ + --clang_path=/usr/bin/clang-18 + pip install dist/jaxlib-*.whl + pip install -e . + - name: Run tests + env: + ASAN_OPTIONS: detect_leaks=0 + JAX_NUM_GENERATED_CASES: 1 + JAX_ENABLE_X64: true + JAX_SKIP_SLOW_TESTS: true + PY_COLORS: 1 + run: | + source ${GITHUB_WORKSPACE}/venv/bin/activate + cd jax + echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES" + echo "JAX_ENABLE_X64=$JAX_ENABLE_X64" + echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS" + # The LD_PRELOAD works around https://github.com/google/sanitizers/issues/934#issuecomment-649516500 + LD_PRELOAD=/lib/x86_64-linux-gnu/libstdc++.so.6 python -m pytest -n auto --tb=short --maxfail=20 tests diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 315db489a818..03e7a040570f 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -29,16 +29,18 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 5 steps: - - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 + - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 - name: Set up Python 3.11 - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: python-version: 3.11 - - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # ratchet: pre-commit/action@v3.0.1 + - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 build: - name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})" - runs-on: ${{ matrix.os }} + name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})" + runs-on: ROCM-Ubuntu + container: + image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04 timeout-minutes: 60 strategy: matrix: @@ -46,20 +48,22 @@ jobs: include: - name-prefix: "with 3.10" python-version: "3.10" - os: ubuntu-20.04-16core enable-x64: 1 prng-upgrade: 1 num_generated_cases: 1 - - name-prefix: "with 3.12" - python-version: "3.12" - os: ubuntu-20.04-16core + - name-prefix: "with 3.13" + python-version: "3.13" enable-x64: 0 prng-upgrade: 0 num_generated_cases: 1 steps: - - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 + - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + - name: Image Setup + run: | + apt update + apt install -y libssl-dev - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: python-version: ${{ matrix.python-version }} - name: Get pip cache dir @@ -68,7 +72,7 @@ jobs: python -m pip install --upgrade pip wheel echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - name: pip cache - uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # ratchet: actions/cache@v4 + uses: actions/cache@2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2 # v4.1.0 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }} @@ -104,9 +108,9 @@ jobs: matrix: python-version: ['3.10'] steps: - - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 + - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: python-version: ${{ matrix.python-version }} - name: Get pip cache dir @@ -115,7 +119,7 @@ jobs: python -m pip install --upgrade pip wheel echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - name: pip cache - uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # ratchet: actions/cache@v4 + uses: actions/cache@2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2 # v4.1.0 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }} @@ -141,9 +145,9 @@ jobs: matrix: python-version: ['3.10'] steps: - - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 + - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: python-version: ${{ matrix.python-version }} - name: Get pip cache dir @@ -152,7 +156,7 @@ jobs: python -m pip install --upgrade pip wheel echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - name: pip cache - uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # ratchet: actions/cache@v4 + uses: actions/cache@2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2 # v4.1.0 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }} @@ -177,9 +181,9 @@ jobs: enable-x64: 0 num_generated_cases: 10 steps: - - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 + - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: python-version: ${{ matrix.python-version }} - name: Get pip cache dir @@ -188,7 +192,7 @@ jobs: python -m pip install --upgrade pip wheel echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - name: pip cache - uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # ratchet: actions/cache@v4 + uses: actions/cache@2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2 # v4.1.0 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }} @@ -216,9 +220,9 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 5 steps: - - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 + - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 - name: Set up Python 3.11 - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: python-version: 3.11 - name: Get pip cache dir @@ -227,7 +231,7 @@ jobs: python -m pip install --upgrade pip wheel echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - name: pip cache - uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # ratchet: actions/cache@v4 + uses: actions/cache@2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2 # v4.1.0 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }} diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index 4842e097e10a..d617178254a4 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -32,7 +32,7 @@ jobs: ] name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})" env: - LIBTPU_OLDEST_VERSION_DATE: 20240228 + LIBTPU_OLDEST_VERSION_DATE: 20240722 ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }} runs-on: ["self-hosted", "tpu", "${{ matrix.tpu.type }}"] timeout-minutes: 120 @@ -43,7 +43,7 @@ jobs: # https://opensource.google/documentation/reference/github/services#actions # mandates using a specific commit for non-Google actions. We use # https://github.com/sethvargo/ratchet to pin specific versions. - - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 + - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 - name: Install JAX test requirements run: | pip install -U -r build/test-requirements.txt diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index f1dc8eee8a75..010ebae78c43 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -22,9 +22,9 @@ jobs: steps: - name: Checkout jax - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet: actions/checkout@v4 + uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 - name: Checkout array-api-tests - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet: actions/checkout@v4 + uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 with: repository: data-apis/array-api-tests # TODO(jakevdp) update this to a stable release/tag when available. @@ -32,7 +32,7 @@ jobs: submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies diff --git a/.github/workflows/metal_plugin_ci.yml b/.github/workflows/metal_plugin_ci.yml index 75f4bba1a367..3f6d4be94323 100644 --- a/.github/workflows/metal_plugin_ci.yml +++ b/.github/workflows/metal_plugin_ci.yml @@ -27,7 +27,7 @@ jobs: steps: - name: Get repo - uses: actions/checkout@v4 + uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 with: path: jax - name: Setup build and test enviroment diff --git a/.github/workflows/upstream-nightly.yml b/.github/workflows/upstream-nightly.yml index 2bdd8ba5192e..ccaa02832b7a 100644 --- a/.github/workflows/upstream-nightly.yml +++ b/.github/workflows/upstream-nightly.yml @@ -22,7 +22,7 @@ on: jobs: upstream-dev: - runs-on: ubuntu-20.04-16core + runs-on: ROCM-Ubuntu permissions: contents: read checks: write # for upload-artifact @@ -32,13 +32,13 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.12"] + python-version: ["3.13"] outputs: artifacts_availability: ${{ steps.status.outputs.ARTIFACTS_AVAILABLE }} steps: - - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 + - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: python-version: ${{ matrix.python-version }} - name: Install JAX test requirements @@ -85,7 +85,7 @@ jobs: && steps.status.outcome == 'failure' && github.event_name == 'schedule' && github.repository == 'jax-ml/jax' - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # ratchet: actions/upload-artifact@v4 + uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4.4.1 with: name: output-${{ matrix.python-version }}-log.jsonl path: output-${{ matrix.python-version }}-log.jsonl @@ -106,11 +106,11 @@ jobs: run: shell: bash steps: - - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 + - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: python-version: "3.x" - - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # ratchet:actions/download-artifact@v4 + - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: path: /tmp/workspace/logs - name: install requirements @@ -123,7 +123,7 @@ jobs: cat logs/*.jsonl > pytest-logs.txt python .github/workflows/parse_logs.py pytest-logs.txt --outfile=parsed-logs.txt - name: Report failures - uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # ratchet:actions/github-script@v7 + uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index f4fb7727da6b..41f3ab8a0548 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -17,7 +17,7 @@ jobs: matrix: os: [windows-2019-32core] arch: [AMD64] - pyver: ['3.10', '3.11', '3.12', '3.13.0-rc.2'] + pyver: ['3.10', '3.11', '3.12', '3.13'] name: ${{ matrix.os }} ${{ matrix.pyver }} jaxlib wheel build runs-on: ${{ matrix.os }} @@ -25,9 +25,9 @@ jobs: - name: Install LLVM/Clang run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade - - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 + - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 - - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 + - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: python-version: ${{ matrix.pyver }} cache: 'pip' @@ -45,7 +45,7 @@ jobs: --bazel_options=--config=win_clang ` --verbose - - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # ratchet: actions/upload-artifact@v4 + - uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4.4.1 with: name: wheels-${{ matrix.os }}-${{ matrix.pyver }} path: ${{ github.workspace }}\dist\*.whl diff --git a/.github/workflows/windows_ci.yml b/.github/workflows/windows_ci.yml index 03a6876cdbb1..61da606ae17a 100644 --- a/.github/workflows/windows_ci.yml +++ b/.github/workflows/windows_ci.yml @@ -31,11 +31,11 @@ jobs: - name: Install LLVM/Clang run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade - - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 + - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 with: path: jax - - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 + - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: python-version: ${{ matrix.pyver }} cache: 'pip' @@ -53,7 +53,7 @@ jobs: --bazel_options=--color=yes ` --bazel_options=--config=win_clang - - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # ratchet: actions/upload-artifact@v4 + - uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4.4.1 with: name: wheels path: ${{ github.workspace }}\jax\dist\*.whl @@ -66,7 +66,7 @@ jobs: PY_COLORS: 1 run: | cd jax + python -m pip install --pre --find-links ${{ github.workspace }}\jax\dist jaxlib python -m pip install -e ${{ github.workspace }}\jax - python -m pip install --no-index --find-links ${{ github.workspace }}\jax\dist jaxlib echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" pytest -n auto --tb=short tests examples diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c89aa934d95d..1e2127b310ed 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -36,7 +36,7 @@ repos: - id: mypy files: (jax/|tests/typing_test\.py) exclude: jax/_src/basearray.py|jax/numpy/__init__.py # Use pyi instead - additional_dependencies: [types-requests==2.31.0, jaxlib==0.4.31, ml_dtypes==0.3.2, numpy==1.26.3, scipy==1.11.4] + additional_dependencies: [types-requests==2.31.0, jaxlib] args: [--config=pyproject.toml] - repo: https://github.com/mwouts/jupytext diff --git a/CHANGELOG.md b/CHANGELOG.md index 079e055aa994..7fb6c1bd0fde 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,39 @@ Remember to align the itemized text with the first line of an item within a list When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md. --> -## jax 0.4.34 +## jax 0.4.35 + +* Breaking Changes + * {func}`jax.numpy.isscalar` now returns True for any array-like object with + zero dimensions. Previously it only returned True for zero-dimensional + array-like objects with a weak dtype. + * `jax.experimental.host_callback` has been deprecated since March 2024, with + JAX version 0.4.26. Now we removed it. + See {jax-issue}`#20385` for a discussion of alternatives. + +* Changes: + * `jax.lax.FftType` was introduced as a public name for the enum of FFT + operations. The semi-public API `jax.lib.xla_client.FftType` has been + deprecated. + +* Deprecations: + * The semi-public API `jax.lib.xla_client.PaddingType` has been deprecated. + No JAX APIs consume this type, so there is no replacement. + * The default behavior of {func}`jax.pure_callback` and + {func}`jax.extend.ffi.ffi_call` under `vmap` has been deprecated and so has + the `vectorized` parameter to those functions. The `vmap_method` parameter + should be used instead for better defined behavior. See the discussion in + {jax-issue}`#23881` for more details. + * The semi-public API `jax.lib.xla_client.register_custom_call_target` has + been deprecated. Use the JAX FFI instead. + * The semi-public APIs `jax.lib.xla_client.dtype_to_etype`, + `jax.lib.xla_client.ops`, + `jax.lib.xla_client.shape_from_pyval`, `jax.lib.xla_client.PrimitiveType`, + `jax.lib.xla_client.Shape`, `jax.lib.xla_client.XlaBuilder`, and + `jax.lib.xla_client.XlaComputation` have been deprecated. Use StableHLO + instead. + +## jax 0.4.34 (October 4, 2024) * New Functionality * This release includes wheels for Python 3.13. Free-threading mode is not yet @@ -27,6 +59,15 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. that directly accesses shards accordingly. The rank of the per-shard-shape now matches that of the global shape which is the same behavior as jit. This avoids costly reshapes when passing results from pmap into jit. + * `jax.experimental.host_callback` has been deprecated since March 2024, with + JAX version 0.4.26. Now we set the default value of the + `--jax_host_callback_legacy` configuration value to `True`, which means that + if your code uses `jax.experimental.host_callback` APIs, those API calls + will be implemented in terms of the new `jax.experimental.io_callback` API. + If this breaks your code, for a very limited time, you can set the + `--jax_host_callback_legacy` to `True`. Soon we will remove that + configuration option, so you should instead transition to using the + new JAX callback APIs. See {jax-issue}`#20385` for a discussion. * Deprecations * In {func}`jax.numpy.trim_zeros`, non-arraylike arguments or arraylike @@ -56,10 +97,13 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. is only a tree-prefix of itself. To preserve the current behavior, you can ask `jax.tree.map` to treat `None` as a leaf value by writing: `jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`. + * `jax.sharding.XLACompatibleSharding` has been removed. Please use + `jax.sharding.Sharding`. * Bug fixes * Fixed a bug where {func}`jax.numpy.cumsum` would produce incorrect outputs if a non-boolean input was provided and `dtype=bool` was specified. + * Edit implementation of {func}`jax.numpy.ldexp` to get correct gradient. ## jax 0.4.33 (September 16, 2024) diff --git a/README.md b/README.md index d67bdac82414..79cf8fc6d358 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,8 @@ # Transformable numerical computing at scale -![Continuous integration](https://github.com/jax-ml/jax/actions/workflows/ci-build.yaml/badge.svg) -![PyPI version](https://img.shields.io/pypi/v/jax) +[![Continuous integration](https://github.com/jax-ml/jax/actions/workflows/ci-build.yaml/badge.svg)](https://github.com/jax-ml/jax/actions/workflows/ci-build.yaml) +[![PyPI version](https://img.shields.io/pypi/v/jax)](https://pypi.org/project/jax/) [**Quickstart**](#quickstart-colab-in-the-cloud) | [**Transformations**](#transformations) diff --git a/build/build.py b/build/build.py index de0d5a9817fb..44343ebab4ef 100755 --- a/build/build.py +++ b/build/build.py @@ -285,9 +285,9 @@ def write_bazelrc(*, remote_build, if enable_cuda: f.write("build --config=cuda\n") if use_cuda_nvcc: - f.write("build --config=cuda_nvcc\n") + f.write("build --config=build_cuda_with_nvcc\n") else: - f.write("build --config=cuda_clang\n") + f.write("build --config=build_cuda_with_clang\n") f.write(f"build --action_env=CLANG_CUDA_COMPILER_PATH={clang_path}\n") if not enable_nccl: f.write("build --config=nonccl\n") @@ -301,9 +301,12 @@ def write_bazelrc(*, remote_build, f.write( f'build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}"\n') if enable_rocm: - f.write("build --config=rocm\n") + f.write("build --config=rocm_base\n") if not enable_nccl: f.write("build --config=nonccl\n") + if use_clang: + f.write("build --config=rocm\n") + f.write(f"build --action_env=CLANG_COMPILER_PATH={clang_path}\n") if python_version: f.write( "build --repo_env HERMETIC_PYTHON_VERSION=\"{python_version}\"".format( @@ -495,7 +498,7 @@ def main(): help="A comma-separated list of CUDA compute capabilities to support.") parser.add_argument( "--rocm_amdgpu_targets", - default="gfx900,gfx906,gfx908,gfx90a,gfx1030", + default="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100", help="A comma-separated list of ROCm amdgpu targets to support.") parser.add_argument( "--rocm_path", diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index e2369a8001bb..019c088fbd91 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -12,9 +12,9 @@ attrs==24.2.0 \ --hash=sha256:5cfb1b9148b5b086569baec03f20d7b6bf3bcacc9a42bebf87ffaaca362f6346 \ --hash=sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2 # via hypothesis -build==1.2.2 \ - --hash=sha256:119b2fb462adef986483438377a13b2f42064a2a3a4161f24a0cca698a07ac8c \ - --hash=sha256:277ccc71619d98afdd841a0e96ac9fe1593b823af481d3b0cea748e8894e0613 +build==1.2.2.post1 \ + --hash=sha256:1d61c0887fa860c01971625baae8bdd338e517b836a2f70dd1f7aa3a6b2fc5b5 \ + --hash=sha256:b36993e92ca9375a219c99e606a122ff365a760a2d4bba0caa09bd5278b608b7 # via -r build/test-requirements.txt cloudpickle==3.0.0 \ --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ @@ -103,65 +103,71 @@ execnet==2.1.1 \ --hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \ --hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3 # via pytest-xdist -filelock==3.16.0 \ - --hash=sha256:81de9eb8453c769b63369f87f11131a7ab04e367f8d97ad39dc230daa07e3bec \ - --hash=sha256:f6ed4c963184f4c84dd5557ce8fece759a3724b37b80c6c4f20a2f63a4dc6609 +filelock==3.16.1 \ + --hash=sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0 \ + --hash=sha256:c249fbfcd5db47e5e2d6d62198e565475ee65e4831e2561c8e313fa7eb961435 # via -r build/test-requirements.txt flatbuffers==24.3.25 \ --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 # via -r build/test-requirements.txt -fonttools==4.53.1 \ - --hash=sha256:02569e9a810f9d11f4ae82c391ebc6fb5730d95a0657d24d754ed7763fb2d122 \ - --hash=sha256:0679a30b59d74b6242909945429dbddb08496935b82f91ea9bf6ad240ec23397 \ - --hash=sha256:10f5e6c3510b79ea27bb1ebfcc67048cde9ec67afa87c7dd7efa5c700491ac7f \ - --hash=sha256:2af40ae9cdcb204fc1d8f26b190aa16534fcd4f0df756268df674a270eab575d \ - --hash=sha256:32f029c095ad66c425b0ee85553d0dc326d45d7059dbc227330fc29b43e8ba60 \ - --hash=sha256:35250099b0cfb32d799fb5d6c651220a642fe2e3c7d2560490e6f1d3f9ae9169 \ - --hash=sha256:3b3c8ebafbee8d9002bd8f1195d09ed2bd9ff134ddec37ee8f6a6375e6a4f0e8 \ - --hash=sha256:4824c198f714ab5559c5be10fd1adf876712aa7989882a4ec887bf1ef3e00e31 \ - --hash=sha256:5ff7e5e9bad94e3a70c5cd2fa27f20b9bb9385e10cddab567b85ce5d306ea923 \ - --hash=sha256:651390c3b26b0c7d1f4407cad281ee7a5a85a31a110cbac5269de72a51551ba2 \ - --hash=sha256:6e08f572625a1ee682115223eabebc4c6a2035a6917eac6f60350aba297ccadb \ - --hash=sha256:6ed170b5e17da0264b9f6fae86073be3db15fa1bd74061c8331022bca6d09bab \ - --hash=sha256:73379d3ffdeecb376640cd8ed03e9d2d0e568c9d1a4e9b16504a834ebadc2dfb \ - --hash=sha256:75a157d8d26c06e64ace9df037ee93a4938a4606a38cb7ffaf6635e60e253b7a \ - --hash=sha256:791b31ebbc05197d7aa096bbc7bd76d591f05905d2fd908bf103af4488e60670 \ - --hash=sha256:7b6b35e52ddc8fb0db562133894e6ef5b4e54e1283dff606fda3eed938c36fc8 \ - --hash=sha256:84ec3fb43befb54be490147b4a922b5314e16372a643004f182babee9f9c3407 \ - --hash=sha256:8959a59de5af6d2bec27489e98ef25a397cfa1774b375d5787509c06659b3671 \ - --hash=sha256:9dfdae43b7996af46ff9da520998a32b105c7f098aeea06b2226b30e74fbba88 \ - --hash=sha256:9e6ceba2a01b448e36754983d376064730690401da1dd104ddb543519470a15f \ - --hash=sha256:9efd176f874cb6402e607e4cc9b4a9cd584d82fc34a4b0c811970b32ba62501f \ - --hash=sha256:a1c7c5aa18dd3b17995898b4a9b5929d69ef6ae2af5b96d585ff4005033d82f0 \ - --hash=sha256:aae7bd54187e8bf7fd69f8ab87b2885253d3575163ad4d669a262fe97f0136cb \ - --hash=sha256:b21952c092ffd827504de7e66b62aba26fdb5f9d1e435c52477e6486e9d128b2 \ - --hash=sha256:b96cd370a61f4d083c9c0053bf634279b094308d52fdc2dd9a22d8372fdd590d \ - --hash=sha256:becc5d7cb89c7b7afa8321b6bb3dbee0eec2b57855c90b3e9bf5fb816671fa7c \ - --hash=sha256:bee32ea8765e859670c4447b0817514ca79054463b6b79784b08a8df3a4d78e3 \ - --hash=sha256:c6e7170d675d12eac12ad1a981d90f118c06cf680b42a2d74c6c931e54b50719 \ - --hash=sha256:c818c058404eb2bba05e728d38049438afd649e3c409796723dfc17cd3f08749 \ - --hash=sha256:c8696544c964500aa9439efb6761947393b70b17ef4e82d73277413f291260a4 \ - --hash=sha256:c9cd19cf4fe0595ebdd1d4915882b9440c3a6d30b008f3cc7587c1da7b95be5f \ - --hash=sha256:d4d0096cb1ac7a77b3b41cd78c9b6bc4a400550e21dc7a92f2b5ab53ed74eb02 \ - --hash=sha256:d92d3c2a1b39631a6131c2fa25b5406855f97969b068e7e08413325bc0afba58 \ - --hash=sha256:da33440b1413bad53a8674393c5d29ce64d8c1a15ef8a77c642ffd900d07bfe1 \ - --hash=sha256:e013aae589c1c12505da64a7d8d023e584987e51e62006e1bb30d72f26522c41 \ - --hash=sha256:e128778a8e9bc11159ce5447f76766cefbd876f44bd79aff030287254e4752c4 \ - --hash=sha256:e54f1bba2f655924c1138bbc7fa91abd61f45c68bd65ab5ed985942712864bbb \ - --hash=sha256:e5b708073ea3d684235648786f5f6153a48dc8762cdfe5563c57e80787c29fbb \ - --hash=sha256:e8bf06b94694251861ba7fdeea15c8ec0967f84c3d4143ae9daf42bbc7717fe3 \ - --hash=sha256:f08df60fbd8d289152079a65da4e66a447efc1d5d5a4d3f299cdd39e3b2e4a7d \ - --hash=sha256:f1f8758a2ad110bd6432203a344269f445a2907dc24ef6bccfd0ac4e14e0d71d \ - --hash=sha256:f677ce218976496a587ab17140da141557beb91d2a5c1a14212c994093f2eae2 +fonttools==4.54.1 \ + --hash=sha256:07e005dc454eee1cc60105d6a29593459a06321c21897f769a281ff2d08939f6 \ + --hash=sha256:0a911591200114969befa7f2cb74ac148bce5a91df5645443371aba6d222e263 \ + --hash=sha256:0d1d353ef198c422515a3e974a1e8d5b304cd54a4c2eebcae708e37cd9eeffb1 \ + --hash=sha256:0e88e3018ac809b9662615072dcd6b84dca4c2d991c6d66e1970a112503bba7e \ + --hash=sha256:1d152d1be65652fc65e695e5619e0aa0982295a95a9b29b52b85775243c06556 \ + --hash=sha256:262705b1663f18c04250bd1242b0515d3bbae177bee7752be67c979b7d47f43d \ + --hash=sha256:278913a168f90d53378c20c23b80f4e599dca62fbffae4cc620c8eed476b723e \ + --hash=sha256:301540e89cf4ce89d462eb23a89464fef50915255ece765d10eee8b2bf9d75b2 \ + --hash=sha256:31c32d7d4b0958600eac75eaf524b7b7cb68d3a8c196635252b7a2c30d80e986 \ + --hash=sha256:357cacb988a18aace66e5e55fe1247f2ee706e01debc4b1a20d77400354cddeb \ + --hash=sha256:37cddd62d83dc4f72f7c3f3c2bcf2697e89a30efb152079896544a93907733bd \ + --hash=sha256:41bb0b250c8132b2fcac148e2e9198e62ff06f3cc472065dff839327945c5882 \ + --hash=sha256:4aa4817f0031206e637d1e685251ac61be64d1adef111060df84fdcbc6ab6c44 \ + --hash=sha256:4e10d2e0a12e18f4e2dd031e1bf7c3d7017be5c8dbe524d07706179f355c5dac \ + --hash=sha256:5419771b64248484299fa77689d4f3aeed643ea6630b2ea750eeab219588ba20 \ + --hash=sha256:54471032f7cb5fca694b5f1a0aaeba4af6e10ae989df408e0216f7fd6cdc405d \ + --hash=sha256:58974b4987b2a71ee08ade1e7f47f410c367cdfc5a94fabd599c88165f56213a \ + --hash=sha256:58d29b9a294573d8319f16f2f79e42428ba9b6480442fa1836e4eb89c4d9d61c \ + --hash=sha256:5eb2474a7c5be8a5331146758debb2669bf5635c021aee00fd7c353558fc659d \ + --hash=sha256:6e37561751b017cf5c40fce0d90fd9e8274716de327ec4ffb0df957160be3bff \ + --hash=sha256:76ae5091547e74e7efecc3cbf8e75200bc92daaeb88e5433c5e3e95ea8ce5aa7 \ + --hash=sha256:7965af9b67dd546e52afcf2e38641b5be956d68c425bef2158e95af11d229f10 \ + --hash=sha256:7e3b7d44e18c085fd8c16dcc6f1ad6c61b71ff463636fcb13df7b1b818bd0c02 \ + --hash=sha256:7ed7ee041ff7b34cc62f07545e55e1468808691dddfd315d51dd82a6b37ddef2 \ + --hash=sha256:82834962b3d7c5ca98cb56001c33cf20eb110ecf442725dc5fdf36d16ed1ab07 \ + --hash=sha256:8583e563df41fdecef31b793b4dd3af8a9caa03397be648945ad32717a92885b \ + --hash=sha256:8fa92cb248e573daab8d032919623cc309c005086d743afb014c836636166f08 \ + --hash=sha256:93d458c8a6a354dc8b48fc78d66d2a8a90b941f7fec30e94c7ad9982b1fa6bab \ + --hash=sha256:957f669d4922f92c171ba01bef7f29410668db09f6c02111e22b2bce446f3285 \ + --hash=sha256:9dc080e5a1c3b2656caff2ac2633d009b3a9ff7b5e93d0452f40cd76d3da3b3c \ + --hash=sha256:9ef1b167e22709b46bf8168368b7b5d3efeaaa746c6d39661c1b4405b6352e58 \ + --hash=sha256:a7a310c6e0471602fe3bf8efaf193d396ea561486aeaa7adc1f132e02d30c4b9 \ + --hash=sha256:ab774fa225238986218a463f3fe151e04d8c25d7de09df7f0f5fce27b1243dbc \ + --hash=sha256:ada215fd079e23e060157aab12eba0d66704316547f334eee9ff26f8c0d7b8ab \ + --hash=sha256:c39287f5c8f4a0c5a55daf9eaf9ccd223ea59eed3f6d467133cc727d7b943a55 \ + --hash=sha256:c9c563351ddc230725c4bdf7d9e1e92cbe6ae8553942bd1fb2b2ff0884e8b714 \ + --hash=sha256:d26732ae002cc3d2ecab04897bb02ae3f11f06dd7575d1df46acd2f7c012a8d8 \ + --hash=sha256:d3b659d1029946f4ff9b6183984578041b520ce0f8fb7078bb37ec7445806b33 \ + --hash=sha256:dd9cc95b8d6e27d01e1e1f1fae8559ef3c02c76317da650a19047f249acd519d \ + --hash=sha256:e4564cf40cebcb53f3dc825e85910bf54835e8a8b6880d59e5159f0f325e637e \ + --hash=sha256:e7d82b9e56716ed32574ee106cabca80992e6bbdcf25a88d97d21f73a0aae664 \ + --hash=sha256:e8a4b261c1ef91e7188a30571be6ad98d1c6d9fa2427244c545e2fa0a2494dd7 \ + --hash=sha256:e96bc94c8cda58f577277d4a71f51c8e2129b8b36fd05adece6320dd3d57de8a \ + --hash=sha256:ed2f80ca07025551636c555dec2b755dd005e2ea8fbeb99fc5cdff319b70b23b \ + --hash=sha256:f5b8a096e649768c2f4233f947cf9737f8dbf8728b90e2771e2497c6e3d21d13 \ + --hash=sha256:f8e953cc0bddc2beaf3a3c3b5dd9ab7554677da72dfaf46951e193c9653e515a \ + --hash=sha256:fda582236fee135d4daeca056c8c88ec5f6f6d88a004a79b84a02547c8f57386 \ + --hash=sha256:fdb062893fd6d47b527d39346e0c5578b7957dcea6d6a3b6794569370013d9ac # via matplotlib fsspec==2024.9.0 \ --hash=sha256:4b0afb90c2f21832df142f292649035d80b421f60a9e1c027802e5a0da2b04e8 \ --hash=sha256:a0947d552d8a6efa72cc2c730b12c41d043509156966cca4fb157b0f2a0c574b # via etils -hypothesis==6.112.1 \ - --hash=sha256:93631b1498b20d2c205ed304cbd41d50e9c069d78a9c773c1324ca094c5e30ce \ - --hash=sha256:b070d7a1bb9bd84706c31885c9aeddc138e2b36a9c112a91984f49501c567856 +hypothesis==6.112.5 \ + --hash=sha256:82fbd28a92c4d88743740e3ec05415ea25119d825d1fdac9ab7bf717fe56297b \ + --hash=sha256:e6b7c8ba1126e07cfbf76b8bb544cedd89cb7f7bcf6c315bd759cd2efc2063ff # via -r build/test-requirements.txt importlib-resources==6.4.5 \ --hash=sha256:980862a1d16c9e147a59603677fa2aa5fd82b87f223b6cb870695bcfce830065 \ @@ -360,75 +366,74 @@ ml-dtypes==0.5.0 \ --hash=sha256:dc74fd9995513d33eac63d64e436240f5494ec74d522a9f0920194942fc3d2d7 \ --hash=sha256:e04fde367b2fe901b1d47234426fe8819909bd1dd862a5adb630f27789c20599 # via -r build/requirements.in -mpmath==1.4.0a1 \ - --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ - --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 +mpmath==1.3.0 \ + --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ + --hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c # via -r build/test-requirements.txt -numpy==2.1.1 ; python_version >= "3.13" \ - --hash=sha256:046356b19d7ad1890c751b99acad5e82dc4a02232013bd9a9a712fddf8eb60f5 \ - --hash=sha256:0b8cc2715a84b7c3b161f9ebbd942740aaed913584cae9cdc7f8ad5ad41943d0 \ - --hash=sha256:0d07841fd284718feffe7dd17a63a2e6c78679b2d386d3e82f44f0108c905550 \ - --hash=sha256:13cc11c00000848702322af4de0147ced365c81d66053a67c2e962a485b3717c \ - --hash=sha256:13ce49a34c44b6de5241f0b38b07e44c1b2dcacd9e36c30f9c2fcb1bb5135db7 \ - --hash=sha256:24c2ad697bd8593887b019817ddd9974a7f429c14a5469d7fad413f28340a6d2 \ - --hash=sha256:251105b7c42abe40e3a689881e1793370cc9724ad50d64b30b358bbb3a97553b \ - --hash=sha256:2ca4b53e1e0b279142113b8c5eb7d7a877e967c306edc34f3b58e9be12fda8df \ - --hash=sha256:3269c9eb8745e8d975980b3a7411a98976824e1fdef11f0aacf76147f662b15f \ - --hash=sha256:397bc5ce62d3fb73f304bec332171535c187e0643e176a6e9421a6e3eacef06d \ - --hash=sha256:3fc5eabfc720db95d68e6646e88f8b399bfedd235994016351b1d9e062c4b270 \ - --hash=sha256:50a95ca3560a6058d6ea91d4629a83a897ee27c00630aed9d933dff191f170cd \ - --hash=sha256:52ac2e48f5ad847cd43c4755520a2317f3380213493b9d8a4c5e37f3b87df504 \ - --hash=sha256:53e27293b3a2b661c03f79aa51c3987492bd4641ef933e366e0f9f6c9bf257ec \ - --hash=sha256:57eb525e7c2a8fdee02d731f647146ff54ea8c973364f3b850069ffb42799647 \ - --hash=sha256:5889dd24f03ca5a5b1e8a90a33b5a0846d8977565e4ae003a63d22ecddf6782f \ - --hash=sha256:59ca673ad11d4b84ceb385290ed0ebe60266e356641428c845b39cd9df6713ab \ - --hash=sha256:6435c48250c12f001920f0751fe50c0348f5f240852cfddc5e2f97e007544cbe \ - --hash=sha256:6e5a9cb2be39350ae6c8f79410744e80154df658d5bea06e06e0ac5bb75480d5 \ - --hash=sha256:7be6a07520b88214ea85d8ac8b7d6d8a1839b0b5cb87412ac9f49fa934eb15d5 \ - --hash=sha256:7c803b7934a7f59563db459292e6aa078bb38b7ab1446ca38dd138646a38203e \ - --hash=sha256:7dd86dfaf7c900c0bbdcb8b16e2f6ddf1eb1fe39c6c8cca6e94844ed3152a8fd \ - --hash=sha256:8661c94e3aad18e1ea17a11f60f843a4933ccaf1a25a7c6a9182af70610b2313 \ - --hash=sha256:8ae0fd135e0b157365ac7cc31fff27f07a5572bdfc38f9c2d43b2aff416cc8b0 \ - --hash=sha256:910b47a6d0635ec1bd53b88f86120a52bf56dcc27b51f18c7b4a2e2224c29f0f \ - --hash=sha256:913cc1d311060b1d409e609947fa1b9753701dac96e6581b58afc36b7ee35af6 \ - --hash=sha256:920b0911bb2e4414c50e55bd658baeb78281a47feeb064ab40c2b66ecba85553 \ - --hash=sha256:950802d17a33c07cba7fd7c3dcfa7d64705509206be1606f196d179e539111ed \ - --hash=sha256:981707f6b31b59c0c24bcda52e5605f9701cb46da4b86c2e8023656ad3e833cb \ - --hash=sha256:98ce7fb5b8063cfdd86596b9c762bf2b5e35a2cdd7e967494ab78a1fa7f8b86e \ - --hash=sha256:99f4a9ee60eed1385a86e82288971a51e71df052ed0b2900ed30bc840c0f2e39 \ - --hash=sha256:9a8e06c7a980869ea67bbf551283bbed2856915f0a792dc32dd0f9dd2fb56728 \ - --hash=sha256:ae8ce252404cdd4de56dcfce8b11eac3c594a9c16c231d081fb705cf23bd4d9e \ - --hash=sha256:afd9c680df4de71cd58582b51e88a61feed4abcc7530bcd3d48483f20fc76f2a \ - --hash=sha256:b49742cdb85f1f81e4dc1b39dcf328244f4d8d1ded95dea725b316bd2cf18c95 \ - --hash=sha256:b5613cfeb1adfe791e8e681128f5f49f22f3fcaa942255a6124d58ca59d9528f \ - --hash=sha256:bab7c09454460a487e631ffc0c42057e3d8f2a9ddccd1e60c7bb8ed774992480 \ - --hash=sha256:c8a0e34993b510fc19b9a2ce7f31cb8e94ecf6e924a40c0c9dd4f62d0aac47d9 \ - --hash=sha256:caf5d284ddea7462c32b8d4a6b8af030b6c9fd5332afb70e7414d7fdded4bfd0 \ - --hash=sha256:cea427d1350f3fd0d2818ce7350095c1a2ee33e30961d2f0fef48576ddbbe90f \ - --hash=sha256:d0cf7d55b1051387807405b3898efafa862997b4cba8aa5dbe657be794afeafd \ - --hash=sha256:d10c39947a2d351d6d466b4ae83dad4c37cd6c3cdd6d5d0fa797da56f710a6ae \ - --hash=sha256:d2b9cd92c8f8e7b313b80e93cedc12c0112088541dcedd9197b5dee3738c1201 \ - --hash=sha256:d4c57b68c8ef5e1ebf47238e99bf27657511ec3f071c465f6b1bccbef12d4136 \ - --hash=sha256:d51fc141ddbe3f919e91a096ec739f49d686df8af254b2053ba21a910ae518bf \ - --hash=sha256:e097507396c0be4e547ff15b13dc3866f45f3680f789c1a1301b07dadd3fbc78 \ - --hash=sha256:e30356d530528a42eeba51420ae8bf6c6c09559051887196599d96ee5f536468 \ - --hash=sha256:e8d5f8a8e3bc87334f025194c6193e408903d21ebaeb10952264943a985066ca \ - --hash=sha256:e8dfa9e94fc127c40979c3eacbae1e61fda4fe71d84869cc129e2721973231ef \ - --hash=sha256:f212d4f46b67ff604d11fff7cc62d36b3e8714edf68e44e9760e19be38c03eb0 \ - --hash=sha256:f7506387e191fe8cdb267f912469a3cccc538ab108471291636a96a54e599556 \ - --hash=sha256:fac6e277a41163d27dfab5f4ec1f7a83fac94e170665a4a50191b545721c6521 \ - --hash=sha256:fcd8f556cdc8cfe35e70efb92463082b7f43dd7e547eb071ffc36abc0ca4699b +numpy==2.1.2 ; python_version >= "3.13" \ + --hash=sha256:05b2d4e667895cc55e3ff2b56077e4c8a5604361fc21a042845ea3ad67465aa8 \ + --hash=sha256:12edb90831ff481f7ef5f6bc6431a9d74dc0e5ff401559a71e5e4611d4f2d466 \ + --hash=sha256:13311c2db4c5f7609b462bc0f43d3c465424d25c626d95040f073e30f7570e35 \ + --hash=sha256:13532a088217fa624c99b843eeb54640de23b3414b14aa66d023805eb731066c \ + --hash=sha256:13602b3174432a35b16c4cfb5de9a12d229727c3dd47a6ce35111f2ebdf66ff4 \ + --hash=sha256:1600068c262af1ca9580a527d43dc9d959b0b1d8e56f8a05d830eea39b7c8af6 \ + --hash=sha256:1b8cde4f11f0a975d1fd59373b32e2f5a562ade7cde4f85b7137f3de8fbb29a0 \ + --hash=sha256:1c193d0b0238638e6fc5f10f1b074a6993cb13b0b431f64079a509d63d3aa8b7 \ + --hash=sha256:1ebec5fd716c5a5b3d8dfcc439be82a8407b7b24b230d0ad28a81b61c2f4659a \ + --hash=sha256:242b39d00e4944431a3cd2db2f5377e15b5785920421993770cddb89992c3f3a \ + --hash=sha256:259ec80d54999cc34cd1eb8ded513cb053c3bf4829152a2e00de2371bd406f5e \ + --hash=sha256:2abbf905a0b568706391ec6fa15161fad0fb5d8b68d73c461b3c1bab6064dd62 \ + --hash=sha256:2cbba4b30bf31ddbe97f1c7205ef976909a93a66bb1583e983adbd155ba72ac2 \ + --hash=sha256:2ffef621c14ebb0188a8633348504a35c13680d6da93ab5cb86f4e54b7e922b5 \ + --hash=sha256:30d53720b726ec36a7f88dc873f0eec8447fbc93d93a8f079dfac2629598d6ee \ + --hash=sha256:32e16a03138cabe0cb28e1007ee82264296ac0983714094380b408097a418cfe \ + --hash=sha256:43cca367bf94a14aca50b89e9bc2061683116cfe864e56740e083392f533ce7a \ + --hash=sha256:456e3b11cb79ac9946c822a56346ec80275eaf2950314b249b512896c0d2505e \ + --hash=sha256:4d6ec0d4222e8ffdab1744da2560f07856421b367928026fb540e1945f2eeeaf \ + --hash=sha256:5006b13a06e0b38d561fab5ccc37581f23c9511879be7693bd33c7cd15ca227c \ + --hash=sha256:675c741d4739af2dc20cd6c6a5c4b7355c728167845e3c6b0e824e4e5d36a6c3 \ + --hash=sha256:6cdb606a7478f9ad91c6283e238544451e3a95f30fb5467fbf715964341a8a86 \ + --hash=sha256:6d95f286b8244b3649b477ac066c6906fbb2905f8ac19b170e2175d3d799f4df \ + --hash=sha256:76322dcdb16fccf2ac56f99048af32259dcc488d9b7e25b51e5eca5147a3fb98 \ + --hash=sha256:7c1c60328bd964b53f8b835df69ae8198659e2b9302ff9ebb7de4e5a5994db3d \ + --hash=sha256:860ec6e63e2c5c2ee5e9121808145c7bf86c96cca9ad396c0bd3e0f2798ccbe2 \ + --hash=sha256:8e00ea6fc82e8a804433d3e9cedaa1051a1422cb6e443011590c14d2dea59146 \ + --hash=sha256:9c6c754df29ce6a89ed23afb25550d1c2d5fdb9901d9c67a16e0b16eaf7e2550 \ + --hash=sha256:a26ae94658d3ba3781d5e103ac07a876b3e9b29db53f68ed7df432fd033358a8 \ + --hash=sha256:a65acfdb9c6ebb8368490dbafe83c03c7e277b37e6857f0caeadbbc56e12f4fb \ + --hash=sha256:a7d80b2e904faa63068ead63107189164ca443b42dd1930299e0d1cb041cec2e \ + --hash=sha256:a84498e0d0a1174f2b3ed769b67b656aa5460c92c9554039e11f20a05650f00d \ + --hash=sha256:ab4754d432e3ac42d33a269c8567413bdb541689b02d93788af4131018cbf366 \ + --hash=sha256:ad369ed238b1959dfbade9018a740fb9392c5ac4f9b5173f420bd4f37ba1f7a0 \ + --hash=sha256:b1d0fcae4f0949f215d4632be684a539859b295e2d0cb14f78ec231915d644db \ + --hash=sha256:b42a1a511c81cc78cbc4539675713bbcf9d9c3913386243ceff0e9429ca892fe \ + --hash=sha256:bd33f82e95ba7ad632bc57837ee99dba3d7e006536200c4e9124089e1bf42426 \ + --hash=sha256:bdd407c40483463898b84490770199d5714dcc9dd9b792f6c6caccc523c00952 \ + --hash=sha256:c6eef7a2dbd0abfb0d9eaf78b73017dbfd0b54051102ff4e6a7b2980d5ac1a03 \ + --hash=sha256:c82af4b2ddd2ee72d1fc0c6695048d457e00b3582ccde72d8a1c991b808bb20f \ + --hash=sha256:d666cb72687559689e9906197e3bec7b736764df6a2e58ee265e360663e9baf7 \ + --hash=sha256:d7bf0a4f9f15b32b5ba53147369e94296f5fffb783db5aacc1be15b4bf72f43b \ + --hash=sha256:d82075752f40c0ddf57e6e02673a17f6cb0f8eb3f587f63ca1eaab5594da5b17 \ + --hash=sha256:da65fb46d4cbb75cb417cddf6ba5e7582eb7bb0b47db4b99c9fe5787ce5d91f5 \ + --hash=sha256:e2b49c3c0804e8ecb05d59af8386ec2f74877f7ca8fd9c1e00be2672e4d399b1 \ + --hash=sha256:e585c8ae871fd38ac50598f4763d73ec5497b0de9a0ab4ef5b69f01c6a046142 \ + --hash=sha256:e8d3ca0a72dd8846eb6f7dfe8f19088060fcb76931ed592d29128e0219652884 \ + --hash=sha256:ef444c57d664d35cac4e18c298c47d7b504c66b17c2ea91312e979fcfbdfb08a \ + --hash=sha256:f1eb068ead09f4994dec71c24b2844f1e4e4e013b9629f812f292f04bd1510d9 \ + --hash=sha256:f2ded8d9b6f68cc26f8425eda5d3877b47343e68ca23d0d0846f4d312ecaa445 \ + --hash=sha256:f751ed0a2f250541e19dfca9f1eafa31a392c71c832b6bb9e113b10d050cb0f1 \ + --hash=sha256:faa88bc527d0f097abdc2c663cddf37c05a1c2f113716601555249805cf573f1 \ + --hash=sha256:fc44e3c68ff00fd991b59092a54350e6e4911152682b4782f68070985aa9e648 # via # -r build/requirements.in # -r build/test-requirements.txt # contourpy # matplotlib # ml-dtypes - # opt-einsum # scipy -opt-einsum==3.3.0 \ - --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ - --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 +opt-einsum==3.4.0 \ + --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ + --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac # via -r build/requirements.in packaging==24.1 \ --hash=sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002 \ @@ -552,13 +557,13 @@ pygments==2.18.0 \ --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \ --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a # via rich -pyparsing==3.2.0b1 \ - --hash=sha256:51e00c907f7b2ac2d2c35c4d431e944c525ddcfd58b09517f308f40d70e0ddca \ - --hash=sha256:ecf0805530839936196a802cd6d6d65ffa9328eebdc8ee5b8f4b358be5f16666 +pyparsing==3.1.4 \ + --hash=sha256:a6a7ee4235a3f944aa1fa2249307708f893fe5717dc603503c6c7969c070fb7c \ + --hash=sha256:f86ec8d1a83f11977c9a6ea7598e8c27fc5cddfa5b07ea2241edbbde1d7bc032 # via matplotlib -pyproject-hooks==1.1.0 \ - --hash=sha256:4b37730834edbd6bd37f26ece6b44802fb1c1ee2ece0e54ddff8bfc06db86965 \ - --hash=sha256:7ceeefe9aec63a1064c18d939bdc3adf2d8aa1988a510afec15151578b232aa2 +pyproject-hooks==1.2.0 \ + --hash=sha256:1e859bd5c40fae9448642dd871adf459e5e2084186e8d2c2a79a824c970da1f8 \ + --hash=sha256:9e5c6bfa8dcc30091c74b0cf803c81fdd29d94f01992a7707bc97babb1141913 # via build pytest==8.3.3 \ --hash=sha256:70b98107bd648308a7952b06e6ca9a50bc660be218d53c257cc1fc94fda10181 \ @@ -572,9 +577,9 @@ python-dateutil==2.9.0.post0 \ --hash=sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3 \ --hash=sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427 # via matplotlib -rich==13.8.1 \ - --hash=sha256:1760a3c0848469b97b558fc61c85233e3dafb69c7a071b4d60c38099d3cd4c06 \ - --hash=sha256:8260cda28e3db6bf04d2d1ef4dbc03ba80a824c88b0e7668a0f23126a424844a +rich==13.9.2 \ + --hash=sha256:51a2c62057461aaf7152b4d611168f93a9fc73068f8ded2790f29fe2b5366d0c \ + --hash=sha256:8c82a3d3f8dcfe9e734771313e606b39d8247bb6b826e196f4914b333b743cf1 # via -r build/test-requirements.txt scipy==1.14.1 \ --hash=sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e \ diff --git a/build/rocm/Dockerfile.ms b/build/rocm/Dockerfile.ms index 0bcc89f493ce..e20291cefd63 100644 --- a/build/rocm/Dockerfile.ms +++ b/build/rocm/Dockerfile.ms @@ -5,11 +5,17 @@ FROM ubuntu:20.04 AS rocm_base RUN --mount=type=cache,target=/var/cache/apt \ apt-get update && apt-get install -y python3 python-is-python3 +# Install bzip2 and sqlite3 packages +RUN apt-get update && apt-get install -y \ + sqlite3 libsqlite3-dev \ + libbz2-dev \ + && rm -rf /var/lib/apt/lists/* + # Add target file to help determine which device(s) to build for -ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" +ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS} -# Install ROCM +# Install ROCm ARG ROCM_VERSION=6.0.0 ARG ROCM_PATH=/opt/rocm-${ROCM_VERSION} ENV ROCM_PATH=${ROCM_PATH} @@ -19,13 +25,8 @@ RUN --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ --mount=type=cache,target=/var/cache/apt \ python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM -# Set up paths -ENV HCC_HOME=$ROCM_PATH/hcc -ENV HIP_PATH=$ROCM_PATH/ -ENV OPENCL_ROOT=$ROCM_PATH/opencl -ENV PATH="$HCC_HOME/bin:$HIP_PATH/bin:${PATH}" +# add ROCm bins to PATH ENV PATH="$ROCM_PATH/bin:${PATH}" -ENV PATH="$OPENCL_ROOT/bin:${PATH}" ENV PATH="/root/bin:/root/.local/bin:$PATH" # install pyenv and python build dependencies @@ -75,6 +76,7 @@ FROM rocm_base AS rt_build ARG JAX_VERSION ARG JAX_COMMIT ARG XLA_COMMIT +ARG JAX_USE_CLANG LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \ com.amdgpu.python_version="$PYTHON_VERSION" \ @@ -82,7 +84,15 @@ LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \ com.amdgpu.jax_commit="$JAX_COMMIT" \ com.amdgpu.xla_commit="$XLA_COMMIT" + +# Create a directory to copy and retain the wheels in the image. +RUN mkdir -p /rocm_jax_wheels + RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ --mount=type=bind,source=wheelhouse,target=/wheelhouse \ - pip install --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt + cp /wheelhouse/* /rocm_jax_wheels/ && \ + ls -lah /wheelhouse && \ + pip3 install wheelhouse/*none*.whl wheelhouse/*jaxlib*.whl && \ + pip3 install wheelhouse/*rocm60*.whl + diff --git a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm index caf303d45ff3..a67a7ecb2e22 100644 --- a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm +++ b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm @@ -7,3 +7,13 @@ ARG ROCM_BUILD_NUM RUN --mount=type=cache,target=/var/cache/dnf \ --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM + +ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" +RUN printf '%s\n' > /opt/rocm/bin/target.lst ${GPU_DEVICE_TARGETS} + +# Install LLVM 18 and dependencies. +RUN --mount=type=cache,target=/var/cache/dnf \ + dnf install -y wget && dnf clean all +RUN mkdir /tmp/llvm-project && wget -qO - https://github.com/llvm/llvm-project/archive/refs/tags/llvmorg-18.1.8.tar.gz | tar -xz -C /tmp/llvm-project --strip-components 1 && \ + mkdir /tmp/llvm-project/build && cd /tmp/llvm-project/build && cmake -DLLVM_ENABLE_PROJECTS='clang;lld' -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/usr/lib/llvm-18/ ../llvm && \ + make -j$(nproc) && make -j$(nproc) install && rm -rf /tmp/llvm-project diff --git a/build/rocm/ci_build b/build/rocm/ci_build index aeb0201e27ed..9fb0ebd77f87 100755 --- a/build/rocm/ci_build +++ b/build/rocm/ci_build @@ -34,8 +34,12 @@ def image_by_name(name): def dist_wheels( - rocm_version, python_versions, xla_path, rocm_build_job="", rocm_build_num="", - compiler="gcc" + rocm_version, + python_versions, + xla_path, + rocm_build_job="", + rocm_build_num="", + compiler="gcc", ): if xla_path: xla_path = os.path.abspath(xla_path) @@ -260,7 +264,7 @@ def parse_args(): p.add_argument( "--compiler", choices=["gcc", "clang"], - help="Compiler backend to use when compiling jax/jaxlib" + help="Compiler backend to use when compiling jax/jaxlib", ) subp = p.add_subparsers(dest="action", required=True) diff --git a/build/rocm/docker/Dockerfile.jax-ubu22 b/build/rocm/docker/Dockerfile.jax-ubu22 index ba64efbbc682..eb971482f708 100644 --- a/build/rocm/docker/Dockerfile.jax-ubu22 +++ b/build/rocm/docker/Dockerfile.jax-ubu22 @@ -3,8 +3,14 @@ FROM ubuntu:22.04 RUN --mount=type=cache,target=/var/cache/apt \ apt-get update && apt-get install -y python3 python-is-python3 +# Install bzip2 and sqlite3 packages +RUN apt-get update && apt-get install -y \ + sqlite3 libsqlite3-dev \ + libbz2-dev \ + && rm -rf /var/lib/apt/lists/* + # Add target file to help determine which device(s) to build for -ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" +ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS} # Install ROCM @@ -61,4 +67,6 @@ LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \ RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ --mount=type=bind,source=wheelhouse,target=/wheelhouse \ - pip3 install --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt + ls -lah /wheelhouse && \ + pip3 install wheelhouse/*none*.whl wheelhouse/*jaxlib*.whl && \ + pip3 install wheelhouse/*rocm60*.whl diff --git a/build/rocm/docker/Dockerfile.jax-ubu24 b/build/rocm/docker/Dockerfile.jax-ubu24 index 44c59b1b7e6b..da714542e5f8 100644 --- a/build/rocm/docker/Dockerfile.jax-ubu24 +++ b/build/rocm/docker/Dockerfile.jax-ubu24 @@ -3,6 +3,12 @@ FROM ubuntu:24.04 RUN --mount=type=cache,target=/var/cache/apt \ apt-get update && apt-get install -y python3 python-is-python3 python3-pip +# Install bzip2 and sqlite3 packages +RUN apt-get update && apt-get install -y \ + sqlite3 libsqlite3-dev \ + libbz2-dev \ + && rm -rf /var/lib/apt/lists/* + # Add target file to help determine which device(s) to build for ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS} @@ -60,4 +66,6 @@ LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \ RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ --mount=type=bind,source=wheelhouse,target=/wheelhouse \ - pip3 install --break-system-packages --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt + ls -lah /wheelhouse && \ + pip3 install wheelhouse/*none*.whl wheelhouse/*jaxlib*.whl && \ + pip3 install wheelhouse/*rocm60*.whl diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py index b6dd1256e2f5..deb6ab703391 100644 --- a/build/rocm/tools/build_wheels.py +++ b/build/rocm/tools/build_wheels.py @@ -56,7 +56,39 @@ def update_rocm_targets(rocm_path, targets): open(version_fp, "a").close() -def build_jaxlib_wheel(jax_path, rocm_path, python_version, xla_path=None, compiler="gcc"): +def find_clang_path(): + llvm_base_path = "/usr/lib/" + # Search for llvm directories and pick the highest version. + llvm_dirs = [d for d in os.listdir(llvm_base_path) if d.startswith("llvm-")] + if llvm_dirs: + # Sort to get the highest llvm version. + llvm_dirs.sort(reverse=True) + clang_bin_dir = os.path.join(llvm_base_path, llvm_dirs[0], "bin") + + # Prefer versioned clang binaries (e.g., clang-18). + versioned_clang = None + generic_clang = None + + for f in os.listdir(clang_bin_dir): + # Checks for versioned clang binaries. + if f.startswith("clang-") and f[6:].isdigit(): + versioned_clang = os.path.join(clang_bin_dir, f) + # Fallback to non-versioned clang. + elif f == "clang": + generic_clang = os.path.join(clang_bin_dir, f) + + # Return versioned clang if available, otherwise return generic clang. + if versioned_clang: + return versioned_clang + elif generic_clang: + return generic_clang + + return None + + +def build_jaxlib_wheel( + jax_path, rocm_path, python_version, xla_path=None, compiler="gcc" +): use_clang = "true" if compiler == "clang" else "false" cmd = [ "python", @@ -68,6 +100,14 @@ def build_jaxlib_wheel(jax_path, rocm_path, python_version, xla_path=None, compi "--use_clang=%s" % use_clang, ] + # Add clang path if clang is used. + if compiler == "clang": + clang_path = find_clang_path() + if clang_path: + cmd.append("--clang_path=%s" % clang_path) + else: + raise RuntimeError("Clang binary not found in /usr/lib/llvm-*") + if xla_path: cmd.append("--bazel_options=--override_repository=xla=%s" % xla_path) @@ -166,18 +206,26 @@ def to_cpy_ver(python_version): def fix_wheel(path, jax_path): - # NOTE(mrodden): fixwheel needs auditwheel 6.0.0, which has a min python of 3.8 - # so use one of the CPythons in /opt to run - env = dict(os.environ) - py_bin = "/opt/python/cp310-cp310/bin" - env["PATH"] = "%s:%s" % (py_bin, env["PATH"]) - - cmd = ["pip", "install", "auditwheel>=6"] - subprocess.run(cmd, check=True, env=env) - - fixwheel_path = os.path.join(jax_path, "build/rocm/tools/fixwheel.py") - cmd = ["python", fixwheel_path, path] - subprocess.run(cmd, check=True, env=env) + try: + # NOTE(mrodden): fixwheel needs auditwheel 6.0.0, which has a min python of 3.8 + # so use one of the CPythons in /opt to run + env = dict(os.environ) + py_bin = "/opt/python/cp310-cp310/bin" + env["PATH"] = "%s:%s" % (py_bin, env["PATH"]) + + cmd = ["pip", "install", "auditwheel>=6"] + subprocess.run(cmd, check=True, env=env) + + fixwheel_path = os.path.join(jax_path, "build/rocm/tools/fixwheel.py") + cmd = ["python", fixwheel_path, path] + subprocess.run(cmd, check=True, env=env) + LOG.info("Wheel fix completed successfully.") + except subprocess.CalledProcessError as cpe: + LOG.error(f"Subprocess failed with error: {cpe}") + raise + except Exception as e: + LOG.error(f"An unexpected error occurred: {e}") + raise def parse_args(): diff --git a/build/rocm/tools/get_rocm.py b/build/rocm/tools/get_rocm.py index 5334bf40ece7..2bcae5f9064c 100644 --- a/build/rocm/tools/get_rocm.py +++ b/build/rocm/tools/get_rocm.py @@ -320,11 +320,12 @@ def setup_repos_el8(rocm_version_str): """ [amdgpu] name=amdgpu -baseurl=https://repo.radeon.com/amdgpu/latest/rhel/8.8/main/x86_64/ +baseurl=https://repo.radeon.com/amdgpu/%s/rhel/8.8/main/x86_64/ enabled=1 gpgcheck=1 gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key """ + % rocm_version_str ) diff --git a/docs/Custom_Operation_for_GPUs.md b/docs/Custom_Operation_for_GPUs.md index 8490bd489608..fcb7b570e493 100644 --- a/docs/Custom_Operation_for_GPUs.md +++ b/docs/Custom_Operation_for_GPUs.md @@ -656,7 +656,7 @@ class RmsNormFwdClass: assert len(weight_info.shape) == 2 # partition() will force all dims of all inputs to be replicated except the # first dim of x that will be kept as is. - # This is because the implementaion can only be sharded on the batch dimensions. + # This is because the implementation can only be sharded on the batch dimensions. x_spec = arg_infos[0].sharding.spec # None mean that we replicate on that dimension. diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index 9a956670ceea..8b418b16f878 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -1982,10 +1982,15 @@ "metadata": {}, "outputs": [], "source": [ + "import io\n", + "from jax.extend.mlir import ir\n", + "from jax.extend.mlir.dialects import func\n", + "from jax.extend.mlir.dialects import stablehlo as hlo\n", "from jax._src import xla_bridge as xb\n", - "from jax._src.lib import xla_client as xc\n", - "xe = xc._xla\n", - "xops = xc._xla.ops\n", + "\n", + "class MlirContext(NamedTuple):\n", + " module: ir.Module\n", + " symbol_table: ir.SymbolTable\n", "\n", "def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int):\n", " consts, args = args[:num_consts], args[num_consts:]\n", @@ -2001,26 +2006,48 @@ " typecheck_jaxpr(jaxpr)\n", " consts = [x.val for x in hashable_consts]\n", " in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]\n", - " c = xc.XlaBuilder('xla_call')\n", - " xla_consts = _xla_consts(c, consts)\n", - " xla_params = _xla_params(c, in_avals)\n", - " outs = jaxpr_subcomp(c, jaxpr, xla_consts + xla_params)\n", - " out = xops.Tuple(c, outs)\n", - " compiled = xb.get_backend(None).compile(\n", - " xc._xla.mlir.xla_computation_to_mlir_module(c.build(out)))\n", + "\n", + " with ir.Context() as ctx, ir.Location.unknown(ctx):\n", + " hlo.register_dialect(ctx)\n", + " m = ir.Module.create()\n", + " c = MlirContext(m, ir.SymbolTable(m.operation))\n", + "\n", + " with ir.InsertionPoint(c.module.body):\n", + " @func.func(*(aval_to_ir_type(aval) for aval in in_avals))\n", + " def main(*params):\n", + " return jaxpr_subcomp(c, jaxpr, _hlo_consts(consts) + params)\n", + "\n", + " output = io.StringIO()\n", + " c.module.operation.print(file=output)\n", + " compiled = xb.get_backend(None).compile(output.getvalue())\n", " return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs])\n", "\n", - "def _xla_consts(c: xe.XlaBuilder, consts: list[Any]) -> list[xe.XlaOp]:\n", - " unique_consts = {id(cnst): cnst for cnst in consts}\n", - " xla_consts = {\n", - " id_: xops.ConstantLiteral(c, cnst) for id_, cnst in unique_consts.items()}\n", - " return [xla_consts[id(cnst)] for cnst in consts]\n", + "def _mlir_dtype(dtype: np.dtype) -> ir.Type:\n", + " if np.issubdtype(dtype, np.signedinteger):\n", + " return ir.IntegerType.get_signless(np.iinfo(dtype).bits)\n", + " elif dtype == np.float32:\n", + " return ir.F32Type.get()\n", + " elif dtype == np.float64:\n", + " return ir.F64Type.get()\n", + " else:\n", + " raise NotImplementedError(\"MLIR conversion not implemented for \", dtype)\n", "\n", - "def _xla_params(c: xe.XlaBuilder, avals_in: list[ShapedArray]) -> list[xe.XlaOp]:\n", - " return [xops.Parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]\n", + "def aval_to_ir_type(aval: ShapedArray) -> ir.Type:\n", + " return ir.RankedTensorType.get(aval.shape, _mlir_dtype(aval.dtype))\n", "\n", - "def _xla_shape(aval: ShapedArray) -> xe.Shape:\n", - " return xc.Shape.array_shape(xc.dtype_to_etype(aval.dtype), aval.shape)" + "def _hlo_const(x: Any) -> ir.Value:\n", + " a = np.asarray(x)\n", + " if a.dtype == np.bool_:\n", + " return hlo.constant(ir.DenseElementsAttr.get(\n", + " np.packbits(a, bitorder='little'), type=ir.IntegerType.get_signless(1),\n", + " shape=a.shape))\n", + " else:\n", + " return hlo.constant(ir.DenseElementsAttr.get(a))\n", + "\n", + "def _hlo_consts(consts: list[Any]) -> list[ir.Value]:\n", + " unique_consts = {id(cnst): cnst for cnst in consts}\n", + " ir_consts = {id_: _hlo_const(cnst) for id_, cnst in unique_consts.items()}\n", + " return tuple(ir_consts[id(cnst)] for cnst in consts)" ] }, { @@ -2038,23 +2065,25 @@ "metadata": {}, "outputs": [], "source": [ - "def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: list[xe.XlaOp]\n", - " ) -> list[xe.XlaOp]:\n", - " env: dict[Var, xe.XlaOp] = {}\n", + "def jaxpr_subcomp(c: MlirContext, jaxpr: Jaxpr, args: list[ir.Value]) -> list[ir.Value]:\n", + " env: dict[Var, ir.Value] = {}\n", "\n", - " def read(x: Atom) -> xe.XlaOp:\n", - " return env[x] if type(x) is Var else xops.Constant(c, np.asarray(x.val))\n", + " def read(x: Atom) -> ir.Value:\n", + " return env[x] if type(x) is Var else _hlo_const(np.asarray(x.val))\n", "\n", - " def write(v: Var, val: xe.XlaOp) -> None:\n", + " def write(v: Var, val: ir.Value) -> None:\n", " env[v] = val\n", "\n", " map(write, jaxpr.in_binders, args)\n", " for eqn in jaxpr.eqns:\n", " in_avals = [x.aval for x in eqn.inputs]\n", " in_vals = map(read, eqn.inputs)\n", - " rule = xla_translations[eqn.primitive]\n", - " out_vals = rule(c, in_avals, in_vals, **eqn.params)\n", - " map(write, eqn.out_binders, out_vals)\n", + " out_avals = [x.aval for x in eqn.out_binders]\n", + " rule = hlo_translations[eqn.primitive]\n", + " assert all(isinstance(v, ir.Value) for v in in_vals), in_vals\n", + " out_vals = rule(c, in_avals, out_avals, in_vals, **eqn.params)\n", + " assert all(isinstance(v, ir.Value) for v in out_vals), out_vals\n", + " map(write, eqn.out_binders, out_vals), out_vals\n", " return map(read, jaxpr.outs)\n", "\n", "def execute_compiled(compiled, out_avals, *args):\n", @@ -2070,7 +2099,7 @@ " del aval # Unused for now\n", " return np.asarray(buf)\n", "\n", - "xla_translations = {}" + "hlo_translations = {}" ] }, { @@ -2089,32 +2118,43 @@ "metadata": {}, "outputs": [], "source": [ - "def direct_translation(op, c, in_avals, in_vals):\n", - " del c, in_avals\n", + "def direct_translation(op, c, in_avals, out_avals, in_vals):\n", + " del c, in_avals, out_avals\n", " return [op(*in_vals)]\n", "\n", - "xla_translations[add_p] = partial(direct_translation, xops.Add)\n", - "xla_translations[mul_p] = partial(direct_translation, xops.Mul)\n", - "xla_translations[neg_p] = partial(direct_translation, xops.Neg)\n", - "xla_translations[sin_p] = partial(direct_translation, xops.Sin)\n", - "xla_translations[cos_p] = partial(direct_translation, xops.Cos)\n", - "xla_translations[greater_p] = partial(direct_translation, xops.Gt)\n", - "xla_translations[less_p] = partial(direct_translation, xops.Lt)\n", - "\n", - "def reduce_sum_translation(c, in_avals, in_vals, *, axis):\n", - " (x_aval,), (x,) = in_avals, in_vals\n", - " zero = xops.ConstantLiteral(c, np.array(0, x_aval.dtype))\n", - " subc = xc.XlaBuilder('add')\n", - " shape = _xla_shape(ShapedArray((), x_aval.dtype))\n", - " xops.Add(xops.Parameter(subc, 0, shape), xops.Parameter(subc, 1, shape))\n", - " return [xops.Reduce(c, [x], [zero], subc.build(), axis)]\n", - "xla_translations[reduce_sum_p] = reduce_sum_translation\n", - "\n", - "def broadcast_translation(c, in_avals, in_vals, *, shape, axes):\n", - " x, = in_vals\n", + "hlo_translations[add_p] = partial(direct_translation, hlo.add)\n", + "hlo_translations[mul_p] = partial(direct_translation, hlo.multiply)\n", + "hlo_translations[neg_p] = partial(direct_translation, hlo.negate)\n", + "hlo_translations[sin_p] = partial(direct_translation, hlo.sine)\n", + "hlo_translations[cos_p] = partial(direct_translation, hlo.cosine)\n", + "\n", + "def compare_translation(op, c, in_avals, out_avals, in_vals):\n", + " del c, out_avals\n", + " return [hlo.compare(*in_vals, hlo.ComparisonDirectionAttr.get(op))]\n", + "\n", + "hlo_translations[greater_p] = partial(compare_translation, \"GT\")\n", + "hlo_translations[less_p] = partial(compare_translation, \"LT\")\n", + "\n", + "def reduce_sum_translation(c, in_avals, out_avals, in_vals, *, axis):\n", + " del c\n", + " (x_aval,), (out_aval,), (x,) = in_avals, out_avals, in_vals\n", + " op = hlo.ReduceOp(\n", + " [aval_to_ir_type(out_aval)], [x], [_hlo_const(np.array(0, x_aval.dtype))],\n", + " axis)\n", + " scalar_type = aval_to_ir_type(ShapedArray((), x_aval.dtype))\n", + " reducer_region = op.body.blocks.append(scalar_type, scalar_type)\n", + " with ir.InsertionPoint(reducer_region):\n", + " hlo.return_([hlo.add(*reducer_region.arguments)])\n", + " return op.results\n", + "\n", + "hlo_translations[reduce_sum_p] = reduce_sum_translation\n", + "\n", + "def broadcast_translation(c, in_avals, out_avals, in_vals, *, shape, axes):\n", + " del c\n", + " (x,), (out_aval,) = in_vals, out_avals\n", " dims_complement = [i for i in range(len(shape)) if i not in axes]\n", - " return [xops.BroadcastInDim(x, shape, dims_complement)]\n", - "xla_translations[broadcast_p] = broadcast_translation" + " return [hlo.broadcast_in_dim(aval_to_ir_type(out_aval), x, dims_complement)]\n", + "hlo_translations[broadcast_p] = broadcast_translation" ] }, { @@ -2286,19 +2326,16 @@ " return jaxpr_type.out_types\n", "abstract_eval_rules[xla_call_p] = xla_call_abstract_eval_rule\n", "\n", - "def xla_call_translation(c, in_avals, in_vals, *, jaxpr, num_consts):\n", - " del num_consts # Only used at top-level.\n", + "def xla_call_translation(c, in_avals, out_avals, in_vals, *, jaxpr, num_consts):\n", + " del num_consts, out_avals\n", " # Calling jaxpr_subcomp directly would inline. We generate a Call HLO instead.\n", - " subc = xc.XlaBuilder('inner xla_call')\n", - " xla_params = _xla_params(subc, in_avals)\n", - " outs = jaxpr_subcomp(subc, jaxpr, xla_params)\n", - " subc = subc.build(xops.Tuple(subc, outs))\n", - " return destructure_tuple(c, xops.Call(c, subc, in_vals))\n", - "xla_translations[xla_call_p] = xla_call_translation\n", - "\n", - "def destructure_tuple(c, tup):\n", - " num_elements = len(c.get_shape(tup).tuple_shapes())\n", - " return [xops.GetTupleElement(tup, i) for i in range(num_elements)]" + " with ir.InsertionPoint(c.module.body):\n", + " @func.func(*(aval_to_ir_type(aval) for aval in in_avals))\n", + " def inner_xla_call(*params):\n", + " return jaxpr_subcomp(c, jaxpr, params)\n", + " name = c.symbol_table.insert(inner_xla_call.func_op)\n", + " return func.CallOp(inner_xla_call.func_op, in_vals).results\n", + "hlo_translations[xla_call_p] = xla_call_translation" ] }, { @@ -3639,28 +3676,18 @@ " return jaxpr_type.out_types\n", "abstract_eval_rules[cond_p] = cond_abstract_eval\n", "\n", - "def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):\n", + "def cond_translation(c, in_avals, out_avals, in_vals, *, true_jaxpr, false_jaxpr):\n", " del in_avals # Unused\n", " pred, *in_vals = in_vals\n", - " flat_vals, in_tree = tree_flatten(in_vals)\n", - " operand = xops.Tuple(c, flat_vals)\n", - " operand_shape = c.get_shape(operand)\n", - "\n", - " def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:\n", - " c = xc.XlaBuilder(name)\n", - " operand = xops.Parameter(c, 0, operand_shape)\n", - " operands = tree_unflatten(in_tree, destructure_tuple(c, operand))\n", - " outs = jaxpr_subcomp(c, jaxpr, operands)\n", - " return c.build(xops.Tuple(c, outs))\n", - "\n", - " true_comp = make_comp('true_fn', true_jaxpr)\n", - " false_comp = make_comp('false_fn', false_jaxpr)\n", - "\n", - " int_etype = xc.dtype_to_etype(np.dtype('int32'))\n", - " out = xops.Conditional(xops.ConvertElementType(pred, int_etype),\n", - " [false_comp, true_comp], [operand] * 2)\n", - " return destructure_tuple(c, out)\n", - "xla_translations[cond_p] = cond_translation" + "\n", + " op = hlo.IfOp([aval_to_ir_type(aval) for aval in out_avals], pred)\n", + " with ir.InsertionPoint(op.true_branch.blocks.append()):\n", + " hlo.return_(jaxpr_subcomp(c, true_jaxpr, in_vals))\n", + " with ir.InsertionPoint(op.false_branch.blocks.append()):\n", + " hlo.return_(jaxpr_subcomp(c, false_jaxpr, in_vals))\n", + " return op.results\n", + "\n", + "hlo_translations[cond_p] = cond_translation" ] }, { diff --git a/docs/autodidax.md b/docs/autodidax.md index 937e1012a230..9e726e5ed82e 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -1552,10 +1552,15 @@ class IDHashable: Next, we'll define the evaluation rule for `xla_call`: ```{code-cell} +import io +from jax.extend.mlir import ir +from jax.extend.mlir.dialects import func +from jax.extend.mlir.dialects import stablehlo as hlo from jax._src import xla_bridge as xb -from jax._src.lib import xla_client as xc -xe = xc._xla -xops = xc._xla.ops + +class MlirContext(NamedTuple): + module: ir.Module + symbol_table: ir.SymbolTable def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int): consts, args = args[:num_consts], args[num_consts:] @@ -1571,26 +1576,48 @@ def xla_callable(hashable_jaxpr: IDHashable, typecheck_jaxpr(jaxpr) consts = [x.val for x in hashable_consts] in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]] - c = xc.XlaBuilder('xla_call') - xla_consts = _xla_consts(c, consts) - xla_params = _xla_params(c, in_avals) - outs = jaxpr_subcomp(c, jaxpr, xla_consts + xla_params) - out = xops.Tuple(c, outs) - compiled = xb.get_backend(None).compile( - xc._xla.mlir.xla_computation_to_mlir_module(c.build(out))) + + with ir.Context() as ctx, ir.Location.unknown(ctx): + hlo.register_dialect(ctx) + m = ir.Module.create() + c = MlirContext(m, ir.SymbolTable(m.operation)) + + with ir.InsertionPoint(c.module.body): + @func.func(*(aval_to_ir_type(aval) for aval in in_avals)) + def main(*params): + return jaxpr_subcomp(c, jaxpr, _hlo_consts(consts) + params) + + output = io.StringIO() + c.module.operation.print(file=output) + compiled = xb.get_backend(None).compile(output.getvalue()) return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs]) -def _xla_consts(c: xe.XlaBuilder, consts: list[Any]) -> list[xe.XlaOp]: - unique_consts = {id(cnst): cnst for cnst in consts} - xla_consts = { - id_: xops.ConstantLiteral(c, cnst) for id_, cnst in unique_consts.items()} - return [xla_consts[id(cnst)] for cnst in consts] +def _mlir_dtype(dtype: np.dtype) -> ir.Type: + if np.issubdtype(dtype, np.signedinteger): + return ir.IntegerType.get_signless(np.iinfo(dtype).bits) + elif dtype == np.float32: + return ir.F32Type.get() + elif dtype == np.float64: + return ir.F64Type.get() + else: + raise NotImplementedError("MLIR conversion not implemented for ", dtype) -def _xla_params(c: xe.XlaBuilder, avals_in: list[ShapedArray]) -> list[xe.XlaOp]: - return [xops.Parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)] +def aval_to_ir_type(aval: ShapedArray) -> ir.Type: + return ir.RankedTensorType.get(aval.shape, _mlir_dtype(aval.dtype)) -def _xla_shape(aval: ShapedArray) -> xe.Shape: - return xc.Shape.array_shape(xc.dtype_to_etype(aval.dtype), aval.shape) +def _hlo_const(x: Any) -> ir.Value: + a = np.asarray(x) + if a.dtype == np.bool_: + return hlo.constant(ir.DenseElementsAttr.get( + np.packbits(a, bitorder='little'), type=ir.IntegerType.get_signless(1), + shape=a.shape)) + else: + return hlo.constant(ir.DenseElementsAttr.get(a)) + +def _hlo_consts(consts: list[Any]) -> list[ir.Value]: + unique_consts = {id(cnst): cnst for cnst in consts} + ir_consts = {id_: _hlo_const(cnst) for id_, cnst in unique_consts.items()} + return tuple(ir_consts[id(cnst)] for cnst in consts) ``` The main action is in `xla_callable`, which compiles a jaxpr into an XLA HLO @@ -1598,23 +1625,25 @@ program using `jaxpr_subcomp`, then returns a callable which executes the compiled program: ```{code-cell} -def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: list[xe.XlaOp] - ) -> list[xe.XlaOp]: - env: dict[Var, xe.XlaOp] = {} +def jaxpr_subcomp(c: MlirContext, jaxpr: Jaxpr, args: list[ir.Value]) -> list[ir.Value]: + env: dict[Var, ir.Value] = {} - def read(x: Atom) -> xe.XlaOp: - return env[x] if type(x) is Var else xops.Constant(c, np.asarray(x.val)) + def read(x: Atom) -> ir.Value: + return env[x] if type(x) is Var else _hlo_const(np.asarray(x.val)) - def write(v: Var, val: xe.XlaOp) -> None: + def write(v: Var, val: ir.Value) -> None: env[v] = val map(write, jaxpr.in_binders, args) for eqn in jaxpr.eqns: in_avals = [x.aval for x in eqn.inputs] in_vals = map(read, eqn.inputs) - rule = xla_translations[eqn.primitive] - out_vals = rule(c, in_avals, in_vals, **eqn.params) - map(write, eqn.out_binders, out_vals) + out_avals = [x.aval for x in eqn.out_binders] + rule = hlo_translations[eqn.primitive] + assert all(isinstance(v, ir.Value) for v in in_vals), in_vals + out_vals = rule(c, in_avals, out_avals, in_vals, **eqn.params) + assert all(isinstance(v, ir.Value) for v in out_vals), out_vals + map(write, eqn.out_binders, out_vals), out_vals return map(read, jaxpr.outs) def execute_compiled(compiled, out_avals, *args): @@ -1630,7 +1659,7 @@ def handle_result(aval: ShapedArray, buf): del aval # Unused for now return np.asarray(buf) -xla_translations = {} +hlo_translations = {} ``` Notice that `jaxpr_subcomp` has the structure of a simple interpreter. That's @@ -1639,32 +1668,43 @@ And as with any interpreter, we need an interpretation rule for each primitive: ```{code-cell} -def direct_translation(op, c, in_avals, in_vals): - del c, in_avals +def direct_translation(op, c, in_avals, out_avals, in_vals): + del c, in_avals, out_avals return [op(*in_vals)] -xla_translations[add_p] = partial(direct_translation, xops.Add) -xla_translations[mul_p] = partial(direct_translation, xops.Mul) -xla_translations[neg_p] = partial(direct_translation, xops.Neg) -xla_translations[sin_p] = partial(direct_translation, xops.Sin) -xla_translations[cos_p] = partial(direct_translation, xops.Cos) -xla_translations[greater_p] = partial(direct_translation, xops.Gt) -xla_translations[less_p] = partial(direct_translation, xops.Lt) - -def reduce_sum_translation(c, in_avals, in_vals, *, axis): - (x_aval,), (x,) = in_avals, in_vals - zero = xops.ConstantLiteral(c, np.array(0, x_aval.dtype)) - subc = xc.XlaBuilder('add') - shape = _xla_shape(ShapedArray((), x_aval.dtype)) - xops.Add(xops.Parameter(subc, 0, shape), xops.Parameter(subc, 1, shape)) - return [xops.Reduce(c, [x], [zero], subc.build(), axis)] -xla_translations[reduce_sum_p] = reduce_sum_translation - -def broadcast_translation(c, in_avals, in_vals, *, shape, axes): - x, = in_vals +hlo_translations[add_p] = partial(direct_translation, hlo.add) +hlo_translations[mul_p] = partial(direct_translation, hlo.multiply) +hlo_translations[neg_p] = partial(direct_translation, hlo.negate) +hlo_translations[sin_p] = partial(direct_translation, hlo.sine) +hlo_translations[cos_p] = partial(direct_translation, hlo.cosine) + +def compare_translation(op, c, in_avals, out_avals, in_vals): + del c, out_avals + return [hlo.compare(*in_vals, hlo.ComparisonDirectionAttr.get(op))] + +hlo_translations[greater_p] = partial(compare_translation, "GT") +hlo_translations[less_p] = partial(compare_translation, "LT") + +def reduce_sum_translation(c, in_avals, out_avals, in_vals, *, axis): + del c + (x_aval,), (out_aval,), (x,) = in_avals, out_avals, in_vals + op = hlo.ReduceOp( + [aval_to_ir_type(out_aval)], [x], [_hlo_const(np.array(0, x_aval.dtype))], + axis) + scalar_type = aval_to_ir_type(ShapedArray((), x_aval.dtype)) + reducer_region = op.body.blocks.append(scalar_type, scalar_type) + with ir.InsertionPoint(reducer_region): + hlo.return_([hlo.add(*reducer_region.arguments)]) + return op.results + +hlo_translations[reduce_sum_p] = reduce_sum_translation + +def broadcast_translation(c, in_avals, out_avals, in_vals, *, shape, axes): + del c + (x,), (out_aval,) = in_vals, out_avals dims_complement = [i for i in range(len(shape)) if i not in axes] - return [xops.BroadcastInDim(x, shape, dims_complement)] -xla_translations[broadcast_p] = broadcast_translation + return [hlo.broadcast_in_dim(aval_to_ir_type(out_aval), x, dims_complement)] +hlo_translations[broadcast_p] = broadcast_translation ``` With that, we can now use `jit` to stage out, compile, and execute programs @@ -1783,19 +1823,16 @@ def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts): return jaxpr_type.out_types abstract_eval_rules[xla_call_p] = xla_call_abstract_eval_rule -def xla_call_translation(c, in_avals, in_vals, *, jaxpr, num_consts): - del num_consts # Only used at top-level. +def xla_call_translation(c, in_avals, out_avals, in_vals, *, jaxpr, num_consts): + del num_consts, out_avals # Calling jaxpr_subcomp directly would inline. We generate a Call HLO instead. - subc = xc.XlaBuilder('inner xla_call') - xla_params = _xla_params(subc, in_avals) - outs = jaxpr_subcomp(subc, jaxpr, xla_params) - subc = subc.build(xops.Tuple(subc, outs)) - return destructure_tuple(c, xops.Call(c, subc, in_vals)) -xla_translations[xla_call_p] = xla_call_translation - -def destructure_tuple(c, tup): - num_elements = len(c.get_shape(tup).tuple_shapes()) - return [xops.GetTupleElement(tup, i) for i in range(num_elements)] + with ir.InsertionPoint(c.module.body): + @func.func(*(aval_to_ir_type(aval) for aval in in_avals)) + def inner_xla_call(*params): + return jaxpr_subcomp(c, jaxpr, params) + name = c.symbol_table.insert(inner_xla_call.func_op) + return func.CallOp(inner_xla_call.func_op, in_vals).results +hlo_translations[xla_call_p] = xla_call_translation ``` ```{code-cell} @@ -2853,28 +2890,18 @@ def cond_abstract_eval(pred_type, *in_types, true_jaxpr, false_jaxpr): return jaxpr_type.out_types abstract_eval_rules[cond_p] = cond_abstract_eval -def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr): +def cond_translation(c, in_avals, out_avals, in_vals, *, true_jaxpr, false_jaxpr): del in_avals # Unused pred, *in_vals = in_vals - flat_vals, in_tree = tree_flatten(in_vals) - operand = xops.Tuple(c, flat_vals) - operand_shape = c.get_shape(operand) - - def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation: - c = xc.XlaBuilder(name) - operand = xops.Parameter(c, 0, operand_shape) - operands = tree_unflatten(in_tree, destructure_tuple(c, operand)) - outs = jaxpr_subcomp(c, jaxpr, operands) - return c.build(xops.Tuple(c, outs)) - - true_comp = make_comp('true_fn', true_jaxpr) - false_comp = make_comp('false_fn', false_jaxpr) - - int_etype = xc.dtype_to_etype(np.dtype('int32')) - out = xops.Conditional(xops.ConvertElementType(pred, int_etype), - [false_comp, true_comp], [operand] * 2) - return destructure_tuple(c, out) -xla_translations[cond_p] = cond_translation + + op = hlo.IfOp([aval_to_ir_type(aval) for aval in out_avals], pred) + with ir.InsertionPoint(op.true_branch.blocks.append()): + hlo.return_(jaxpr_subcomp(c, true_jaxpr, in_vals)) + with ir.InsertionPoint(op.false_branch.blocks.append()): + hlo.return_(jaxpr_subcomp(c, false_jaxpr, in_vals)) + return op.results + +hlo_translations[cond_p] = cond_translation ``` ```{code-cell} diff --git a/docs/autodidax.py b/docs/autodidax.py index c10e6365e62d..f57af2cd96f2 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -1544,10 +1544,15 @@ def __eq__(self, other): # Next, we'll define the evaluation rule for `xla_call`: # + +import io +from jax.extend.mlir import ir +from jax.extend.mlir.dialects import func +from jax.extend.mlir.dialects import stablehlo as hlo from jax._src import xla_bridge as xb -from jax._src.lib import xla_client as xc -xe = xc._xla -xops = xc._xla.ops + +class MlirContext(NamedTuple): + module: ir.Module + symbol_table: ir.SymbolTable def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int): consts, args = args[:num_consts], args[num_consts:] @@ -1563,26 +1568,48 @@ def xla_callable(hashable_jaxpr: IDHashable, typecheck_jaxpr(jaxpr) consts = [x.val for x in hashable_consts] in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]] - c = xc.XlaBuilder('xla_call') - xla_consts = _xla_consts(c, consts) - xla_params = _xla_params(c, in_avals) - outs = jaxpr_subcomp(c, jaxpr, xla_consts + xla_params) - out = xops.Tuple(c, outs) - compiled = xb.get_backend(None).compile( - xc._xla.mlir.xla_computation_to_mlir_module(c.build(out))) + + with ir.Context() as ctx, ir.Location.unknown(ctx): + hlo.register_dialect(ctx) + m = ir.Module.create() + c = MlirContext(m, ir.SymbolTable(m.operation)) + + with ir.InsertionPoint(c.module.body): + @func.func(*(aval_to_ir_type(aval) for aval in in_avals)) + def main(*params): + return jaxpr_subcomp(c, jaxpr, _hlo_consts(consts) + params) + + output = io.StringIO() + c.module.operation.print(file=output) + compiled = xb.get_backend(None).compile(output.getvalue()) return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs]) -def _xla_consts(c: xe.XlaBuilder, consts: list[Any]) -> list[xe.XlaOp]: - unique_consts = {id(cnst): cnst for cnst in consts} - xla_consts = { - id_: xops.ConstantLiteral(c, cnst) for id_, cnst in unique_consts.items()} - return [xla_consts[id(cnst)] for cnst in consts] +def _mlir_dtype(dtype: np.dtype) -> ir.Type: + if np.issubdtype(dtype, np.signedinteger): + return ir.IntegerType.get_signless(np.iinfo(dtype).bits) + elif dtype == np.float32: + return ir.F32Type.get() + elif dtype == np.float64: + return ir.F64Type.get() + else: + raise NotImplementedError("MLIR conversion not implemented for ", dtype) -def _xla_params(c: xe.XlaBuilder, avals_in: list[ShapedArray]) -> list[xe.XlaOp]: - return [xops.Parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)] +def aval_to_ir_type(aval: ShapedArray) -> ir.Type: + return ir.RankedTensorType.get(aval.shape, _mlir_dtype(aval.dtype)) -def _xla_shape(aval: ShapedArray) -> xe.Shape: - return xc.Shape.array_shape(xc.dtype_to_etype(aval.dtype), aval.shape) +def _hlo_const(x: Any) -> ir.Value: + a = np.asarray(x) + if a.dtype == np.bool_: + return hlo.constant(ir.DenseElementsAttr.get( + np.packbits(a, bitorder='little'), type=ir.IntegerType.get_signless(1), + shape=a.shape)) + else: + return hlo.constant(ir.DenseElementsAttr.get(a)) + +def _hlo_consts(consts: list[Any]) -> list[ir.Value]: + unique_consts = {id(cnst): cnst for cnst in consts} + ir_consts = {id_: _hlo_const(cnst) for id_, cnst in unique_consts.items()} + return tuple(ir_consts[id(cnst)] for cnst in consts) # - @@ -1592,23 +1619,25 @@ def _xla_shape(aval: ShapedArray) -> xe.Shape: # compiled program: # + -def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: list[xe.XlaOp] - ) -> list[xe.XlaOp]: - env: dict[Var, xe.XlaOp] = {} +def jaxpr_subcomp(c: MlirContext, jaxpr: Jaxpr, args: list[ir.Value]) -> list[ir.Value]: + env: dict[Var, ir.Value] = {} - def read(x: Atom) -> xe.XlaOp: - return env[x] if type(x) is Var else xops.Constant(c, np.asarray(x.val)) + def read(x: Atom) -> ir.Value: + return env[x] if type(x) is Var else _hlo_const(np.asarray(x.val)) - def write(v: Var, val: xe.XlaOp) -> None: + def write(v: Var, val: ir.Value) -> None: env[v] = val map(write, jaxpr.in_binders, args) for eqn in jaxpr.eqns: in_avals = [x.aval for x in eqn.inputs] in_vals = map(read, eqn.inputs) - rule = xla_translations[eqn.primitive] - out_vals = rule(c, in_avals, in_vals, **eqn.params) - map(write, eqn.out_binders, out_vals) + out_avals = [x.aval for x in eqn.out_binders] + rule = hlo_translations[eqn.primitive] + assert all(isinstance(v, ir.Value) for v in in_vals), in_vals + out_vals = rule(c, in_avals, out_avals, in_vals, **eqn.params) + assert all(isinstance(v, ir.Value) for v in out_vals), out_vals + map(write, eqn.out_binders, out_vals), out_vals return map(read, jaxpr.outs) def execute_compiled(compiled, out_avals, *args): @@ -1624,7 +1653,7 @@ def handle_result(aval: ShapedArray, buf): del aval # Unused for now return np.asarray(buf) -xla_translations = {} +hlo_translations = {} # - @@ -1635,32 +1664,43 @@ def handle_result(aval: ShapedArray, buf): # primitive: # + -def direct_translation(op, c, in_avals, in_vals): - del c, in_avals +def direct_translation(op, c, in_avals, out_avals, in_vals): + del c, in_avals, out_avals return [op(*in_vals)] -xla_translations[add_p] = partial(direct_translation, xops.Add) -xla_translations[mul_p] = partial(direct_translation, xops.Mul) -xla_translations[neg_p] = partial(direct_translation, xops.Neg) -xla_translations[sin_p] = partial(direct_translation, xops.Sin) -xla_translations[cos_p] = partial(direct_translation, xops.Cos) -xla_translations[greater_p] = partial(direct_translation, xops.Gt) -xla_translations[less_p] = partial(direct_translation, xops.Lt) - -def reduce_sum_translation(c, in_avals, in_vals, *, axis): - (x_aval,), (x,) = in_avals, in_vals - zero = xops.ConstantLiteral(c, np.array(0, x_aval.dtype)) - subc = xc.XlaBuilder('add') - shape = _xla_shape(ShapedArray((), x_aval.dtype)) - xops.Add(xops.Parameter(subc, 0, shape), xops.Parameter(subc, 1, shape)) - return [xops.Reduce(c, [x], [zero], subc.build(), axis)] -xla_translations[reduce_sum_p] = reduce_sum_translation - -def broadcast_translation(c, in_avals, in_vals, *, shape, axes): - x, = in_vals +hlo_translations[add_p] = partial(direct_translation, hlo.add) +hlo_translations[mul_p] = partial(direct_translation, hlo.multiply) +hlo_translations[neg_p] = partial(direct_translation, hlo.negate) +hlo_translations[sin_p] = partial(direct_translation, hlo.sine) +hlo_translations[cos_p] = partial(direct_translation, hlo.cosine) + +def compare_translation(op, c, in_avals, out_avals, in_vals): + del c, out_avals + return [hlo.compare(*in_vals, hlo.ComparisonDirectionAttr.get(op))] + +hlo_translations[greater_p] = partial(compare_translation, "GT") +hlo_translations[less_p] = partial(compare_translation, "LT") + +def reduce_sum_translation(c, in_avals, out_avals, in_vals, *, axis): + del c + (x_aval,), (out_aval,), (x,) = in_avals, out_avals, in_vals + op = hlo.ReduceOp( + [aval_to_ir_type(out_aval)], [x], [_hlo_const(np.array(0, x_aval.dtype))], + axis) + scalar_type = aval_to_ir_type(ShapedArray((), x_aval.dtype)) + reducer_region = op.body.blocks.append(scalar_type, scalar_type) + with ir.InsertionPoint(reducer_region): + hlo.return_([hlo.add(*reducer_region.arguments)]) + return op.results + +hlo_translations[reduce_sum_p] = reduce_sum_translation + +def broadcast_translation(c, in_avals, out_avals, in_vals, *, shape, axes): + del c + (x,), (out_aval,) = in_vals, out_avals dims_complement = [i for i in range(len(shape)) if i not in axes] - return [xops.BroadcastInDim(x, shape, dims_complement)] -xla_translations[broadcast_p] = broadcast_translation + return [hlo.broadcast_in_dim(aval_to_ir_type(out_aval), x, dims_complement)] +hlo_translations[broadcast_p] = broadcast_translation # - @@ -1777,19 +1817,16 @@ def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts): return jaxpr_type.out_types abstract_eval_rules[xla_call_p] = xla_call_abstract_eval_rule -def xla_call_translation(c, in_avals, in_vals, *, jaxpr, num_consts): - del num_consts # Only used at top-level. +def xla_call_translation(c, in_avals, out_avals, in_vals, *, jaxpr, num_consts): + del num_consts, out_avals # Calling jaxpr_subcomp directly would inline. We generate a Call HLO instead. - subc = xc.XlaBuilder('inner xla_call') - xla_params = _xla_params(subc, in_avals) - outs = jaxpr_subcomp(subc, jaxpr, xla_params) - subc = subc.build(xops.Tuple(subc, outs)) - return destructure_tuple(c, xops.Call(c, subc, in_vals)) -xla_translations[xla_call_p] = xla_call_translation - -def destructure_tuple(c, tup): - num_elements = len(c.get_shape(tup).tuple_shapes()) - return [xops.GetTupleElement(tup, i) for i in range(num_elements)] + with ir.InsertionPoint(c.module.body): + @func.func(*(aval_to_ir_type(aval) for aval in in_avals)) + def inner_xla_call(*params): + return jaxpr_subcomp(c, jaxpr, params) + name = c.symbol_table.insert(inner_xla_call.func_op) + return func.CallOp(inner_xla_call.func_op, in_vals).results +hlo_translations[xla_call_p] = xla_call_translation # + @@ -2845,28 +2882,18 @@ def cond_abstract_eval(pred_type, *in_types, true_jaxpr, false_jaxpr): return jaxpr_type.out_types abstract_eval_rules[cond_p] = cond_abstract_eval -def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr): +def cond_translation(c, in_avals, out_avals, in_vals, *, true_jaxpr, false_jaxpr): del in_avals # Unused pred, *in_vals = in_vals - flat_vals, in_tree = tree_flatten(in_vals) - operand = xops.Tuple(c, flat_vals) - operand_shape = c.get_shape(operand) - - def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation: - c = xc.XlaBuilder(name) - operand = xops.Parameter(c, 0, operand_shape) - operands = tree_unflatten(in_tree, destructure_tuple(c, operand)) - outs = jaxpr_subcomp(c, jaxpr, operands) - return c.build(xops.Tuple(c, outs)) - - true_comp = make_comp('true_fn', true_jaxpr) - false_comp = make_comp('false_fn', false_jaxpr) - - int_etype = xc.dtype_to_etype(np.dtype('int32')) - out = xops.Conditional(xops.ConvertElementType(pred, int_etype), - [false_comp, true_comp], [operand] * 2) - return destructure_tuple(c, out) -xla_translations[cond_p] = cond_translation + + op = hlo.IfOp([aval_to_ir_type(aval) for aval in out_avals], pred) + with ir.InsertionPoint(op.true_branch.blocks.append()): + hlo.return_(jaxpr_subcomp(c, true_jaxpr, in_vals)) + with ir.InsertionPoint(op.false_branch.blocks.append()): + hlo.return_(jaxpr_subcomp(c, false_jaxpr, in_vals)) + return op.results + +hlo_translations[cond_p] = cond_translation # - out = jit(lambda: cond(False, lambda: 1, lambda: 2))() diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb index a8cd5219d4b5..8b5d5ea6c907 100644 --- a/docs/ffi.ipynb +++ b/docs/ffi.ipynb @@ -303,9 +303,9 @@ " # type (which corresponds to numpy's `float32` type), and it must be a\n", " # static parameter (i.e. not a JAX array).\n", " eps=np.float32(eps),\n", - " # The `vectorized` parameter controls this function's behavior under `vmap`\n", + " # The `vmap_method` parameter controls this function's behavior under `vmap`\n", " # as discussed below.\n", - " vectorized=True,\n", + " vmap_method=\"broadcast_fullrank\",\n", " )\n", "\n", "\n", @@ -325,7 +325,7 @@ "Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`.\n", "Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments.\n", "\n", - "The `vectorized` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n", + "The `vmap_method` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n", "\n", "```{tip}\n", "If you are familiar with the earlier \"custom call\" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.extend.ffi.ffi_call`.\n", @@ -336,19 +336,29 @@ "(ffi-call-vmap)=\n", "### Batching with `vmap`\n", "\n", - "All uses of {func}`~jax.extend.ffi.ffi_call` support {func}`~jax.vmap` out of the box, but this implementation won't necessarily be very efficient.\n", - "By default, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n", - "This default implementation is general purpose, but it doesn't parallelize very well.\n", - "But, many FFI calls provide more efficient batching behavior and, in some simple cases, the `vectorized` parameter to {func}`~jax.extend.ffi.ffi_call` can be used to expose a better implementation.\n", + "{func}`~jax.extend.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.\n", + "The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.extend.ffi.ffi_call`.\n", "\n", - "The specific assumption required to use the `vectorized` parameter is that all leading dimensions of the inputs should be treated as batch axes.\n", + "The simplest `vmap_method` is `\"sequential\"`.\n", + "In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n", + "This implementation is general purpose, but it doesn't parallelize very well.\n", + "Many FFI calls provide more efficient batching behavior and, in some simple cases, the `\"broadcast\"` or `\"broadcast_fullrank\"` methods can be used to expose a better implementation.\n", + "\n", + "In this case, since we only have one input argument, `\"broadcast\"` and `\"broadcast_fullrank\"` actually have the same behavior.\n", + "The specific assumption required to use these methods is that the foreign function knows how to handle batch dimensions.\n", "Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly:\n", "\n", "```python\n", "ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])\n", "```\n", "\n", - "Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vectorized=True` out of the box:" + "```{tip}\n", + "Note that things get a bit more complicated when we have multiple input arguments.\n", + "For simplicity, we will use the `\"broadcast_fullrank\"` throughout this tutorial, which guarantees that all inputs will be broadcasted to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `\"broadcast\"` method.\n", + "The documentation for {func}`~jax.pure_callback` includes some examples of this\n", + "```\n", + "\n", + "Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method=\"broadcast_fullrank\"` out of the box:" ] }, { @@ -380,7 +390,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "If `vectorized` is `False` or omitted, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body:" + "Using `vmap_method=\"sequential\"`, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body:" ] }, { @@ -389,24 +399,24 @@ "metadata": {}, "outputs": [], "source": [ - "def rms_norm_not_vectorized(x, eps=1e-5):\n", + "def rms_norm_sequential(x, eps=1e-5):\n", " return jex.ffi.ffi_call(\n", " \"rms_norm\",\n", " jax.ShapeDtypeStruct(x.shape, x.dtype),\n", " x,\n", " eps=np.float32(eps),\n", - " vectorized=False, # This is the default behavior\n", + " vmap_method=\"sequential\",\n", " )\n", "\n", "\n", - "jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x)" + "jax.make_jaxpr(jax.vmap(rms_norm_sequential))(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues)." + "If your foreign function provides an efficient batching rule that isn't supported by this simple `vmap_method` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues)." ] }, { @@ -454,7 +464,7 @@ " ),\n", " x,\n", " eps=np.float32(eps),\n", - " vectorized=True,\n", + " vmap_method=\"broadcast_fullrank\",\n", " )\n", " return y, (res, x)\n", "\n", @@ -471,7 +481,7 @@ " res,\n", " x,\n", " ct,\n", - " vectorized=True,\n", + " vmap_method=\"broadcast_fullrank\",\n", " ),\n", " )\n", "\n", @@ -561,7 +571,7 @@ " out_type,\n", " x,\n", " eps=np.float32(eps),\n", - " vectorized=True,\n", + " vmap_method=\"broadcast_fullrank\",\n", " )\n", "\n", " return jax.lax.platform_dependent(x, cpu=impl(\"rms_norm\"), cuda=impl(\"rms_norm_cuda\"))\n", diff --git a/docs/ffi.md b/docs/ffi.md index cc3863ed99b2..b3d1dcf46364 100644 --- a/docs/ffi.md +++ b/docs/ffi.md @@ -264,9 +264,9 @@ def rms_norm(x, eps=1e-5): # type (which corresponds to numpy's `float32` type), and it must be a # static parameter (i.e. not a JAX array). eps=np.float32(eps), - # The `vectorized` parameter controls this function's behavior under `vmap` + # The `vmap_method` parameter controls this function's behavior under `vmap` # as discussed below. - vectorized=True, + vmap_method="broadcast_fullrank", ) @@ -282,7 +282,7 @@ It's important to note that the first argument to {func}`~jax.extend.ffi.ffi_cal Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`. Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments. -The `vectorized` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next. +The `vmap_method` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next. ```{tip} If you are familiar with the earlier "custom call" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.extend.ffi.ffi_call`. @@ -293,19 +293,29 @@ One major perk of this change is {func}`~jax.extend.ffi.ffi_call` can support so (ffi-call-vmap)= ### Batching with `vmap` -All uses of {func}`~jax.extend.ffi.ffi_call` support {func}`~jax.vmap` out of the box, but this implementation won't necessarily be very efficient. -By default, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body. -This default implementation is general purpose, but it doesn't parallelize very well. -But, many FFI calls provide more efficient batching behavior and, in some simple cases, the `vectorized` parameter to {func}`~jax.extend.ffi.ffi_call` can be used to expose a better implementation. +{func}`~jax.extend.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter. +The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.extend.ffi.ffi_call`. -The specific assumption required to use the `vectorized` parameter is that all leading dimensions of the inputs should be treated as batch axes. +The simplest `vmap_method` is `"sequential"`. +In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body. +This implementation is general purpose, but it doesn't parallelize very well. +Many FFI calls provide more efficient batching behavior and, in some simple cases, the `"broadcast"` or `"broadcast_fullrank"` methods can be used to expose a better implementation. + +In this case, since we only have one input argument, `"broadcast"` and `"broadcast_fullrank"` actually have the same behavior. +The specific assumption required to use these methods is that the foreign function knows how to handle batch dimensions. Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly: ```python ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs]) ``` -Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vectorized=True` out of the box: +```{tip} +Note that things get a bit more complicated when we have multiple input arguments. +For simplicity, we will use the `"broadcast_fullrank"` throughout this tutorial, which guarantees that all inputs will be broadcasted to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `"broadcast"` method. +The documentation for {func}`~jax.pure_callback` includes some examples of this +``` + +Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method="broadcast_fullrank"` out of the box: ```{code-cell} ipython3 np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5) @@ -317,23 +327,23 @@ We can inspect the [jaxpr](jax-internals-jaxpr) of the {func}`~jax.vmap` of `rms jax.make_jaxpr(jax.vmap(rms_norm))(x) ``` -If `vectorized` is `False` or omitted, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body: +Using `vmap_method="sequential"`, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body: ```{code-cell} ipython3 -def rms_norm_not_vectorized(x, eps=1e-5): +def rms_norm_sequential(x, eps=1e-5): return jex.ffi.ffi_call( "rms_norm", jax.ShapeDtypeStruct(x.shape, x.dtype), x, eps=np.float32(eps), - vectorized=False, # This is the default behavior + vmap_method="sequential", ) -jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x) +jax.make_jaxpr(jax.vmap(rms_norm_sequential))(x) ``` -If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues). +If your foreign function provides an efficient batching rule that isn't supported by this simple `vmap_method` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues). +++ @@ -372,7 +382,7 @@ def rms_norm_fwd(x, eps=1e-5): ), x, eps=np.float32(eps), - vectorized=True, + vmap_method="broadcast_fullrank", ) return y, (res, x) @@ -389,7 +399,7 @@ def rms_norm_bwd(eps, res, ct): res, x, ct, - vectorized=True, + vmap_method="broadcast_fullrank", ), ) @@ -469,7 +479,7 @@ def rms_norm_cross_platform(x, eps=1e-5): out_type, x, eps=np.float32(eps), - vectorized=True, + vmap_method="broadcast_fullrank", ) return jax.lax.platform_dependent(x, cpu=impl("rms_norm"), cuda=impl("rms_norm_cuda")) diff --git a/docs/jax.experimental.host_callback.rst b/docs/jax.experimental.host_callback.rst deleted file mode 100644 index 8ac26b2c3702..000000000000 --- a/docs/jax.experimental.host_callback.rst +++ /dev/null @@ -1,20 +0,0 @@ -``jax.experimental.host_callback`` module -========================================= - - -.. automodule:: jax.experimental.host_callback - -API ---- - -.. autosummary:: - :toctree: _autosummary - - id_tap - id_print - call - barrier_wait - CallbackException - - - diff --git a/docs/jax.experimental.pallas.mosaic_gpu.rst b/docs/jax.experimental.pallas.mosaic_gpu.rst new file mode 100644 index 000000000000..65ae4195b7be --- /dev/null +++ b/docs/jax.experimental.pallas.mosaic_gpu.rst @@ -0,0 +1,42 @@ +``jax.experimental.pallas.mosaic_gpu`` module +============================================= + +.. automodule:: jax.experimental.pallas.mosaic_gpu + +Classes +------- + +.. autosummary:: + :toctree: _autosummary + + Barrier + GPUBlockSpec + GPUCompilerParams + GPUMemorySpace + SwizzleTransform + TilingTransform + TransposeTransform + WGMMAAccumulatorRef + +Functions +--------- + +.. autosummary:: + :toctree: _autosummary + + copy_gmem_to_smem + copy_smem_to_gmem + wait_barrier + wait_smem_to_gmem + wgmma + wgmma_wait + +Aliases +------- + +.. autosummary:: + :toctree: _autosummary + + ACC + GMEM + SMEM diff --git a/docs/jax.experimental.pallas.rst b/docs/jax.experimental.pallas.rst index d250d682d32a..c945f939fa4d 100644 --- a/docs/jax.experimental.pallas.rst +++ b/docs/jax.experimental.pallas.rst @@ -3,6 +3,16 @@ .. automodule:: jax.experimental.pallas +Backends +-------- + +.. toctree:: + :maxdepth: 1 + + jax.experimental.pallas.mosaic_gpu + jax.experimental.pallas.triton + jax.experimental.pallas.tpu + Classes ------- @@ -13,6 +23,8 @@ Classes GridSpec Slice + MemoryRef + Functions --------- @@ -34,5 +46,11 @@ Functions atomic_min atomic_or atomic_xchg - + atomic_xor + broadcast_to debug_print + dot + max_contiguous + multiple_of + run_scoped + when diff --git a/docs/jax.experimental.pallas.tpu.rst b/docs/jax.experimental.pallas.tpu.rst new file mode 100644 index 000000000000..ae4e2c2253e4 --- /dev/null +++ b/docs/jax.experimental.pallas.tpu.rst @@ -0,0 +1,16 @@ +``jax.experimental.pallas.tpu`` module +====================================== + +.. automodule:: jax.experimental.pallas.tpu + +Classes +------- + +.. autosummary:: + :toctree: _autosummary + +Functions +--------- + +.. autosummary:: + :toctree: _autosummary \ No newline at end of file diff --git a/docs/jax.experimental.pallas.triton.rst b/docs/jax.experimental.pallas.triton.rst new file mode 100644 index 000000000000..76b0896ccf17 --- /dev/null +++ b/docs/jax.experimental.pallas.triton.rst @@ -0,0 +1,22 @@ +``jax.experimental.pallas.triton`` module +========================================= + +.. automodule:: jax.experimental.pallas.triton + +Classes +------- + +.. autosummary:: + :toctree: _autosummary + + TritonCompilerParams + +Functions +--------- + +.. autosummary:: + :toctree: _autosummary + + approx_tanh + debug_barrier + elementwise_inline_asm \ No newline at end of file diff --git a/docs/jax.experimental.rst b/docs/jax.experimental.rst index 78db1d4907a4..7672c94c6b52 100644 --- a/docs/jax.experimental.rst +++ b/docs/jax.experimental.rst @@ -16,18 +16,17 @@ Experimental Modules jax.experimental.array_api jax.experimental.checkify - jax.experimental.host_callback - jax.experimental.pjit - jax.experimental.sparse - jax.experimental.jet - jax.experimental.custom_partitioning - jax.experimental.multihost_utils jax.experimental.compilation_cache + jax.experimental.custom_partitioning + jax.experimental.jet jax.experimental.key_reuse jax.experimental.mesh_utils + jax.experimental.multihost_utils + jax.experimental.pallas + jax.experimental.pjit jax.experimental.serialize_executable jax.experimental.shard_map - jax.experimental.pallas + jax.experimental.sparse Experimental APIs ----------------- diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index 3a03665b3217..065127718c54 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -250,9 +250,20 @@ Argument classes .. autoclass:: ConvDimensionNumbers .. autoclass:: ConvGeneralDilatedDimensionNumbers .. autoclass:: DotAlgorithm +.. autoclass:: DotAlgorithmPreset + :members: + :undoc-members: + :member-order: bysource +.. autoclass:: FftType + :members: .. autoclass:: GatherDimensionNumbers .. autoclass:: GatherScatterMode .. autoclass:: Precision .. autoclass:: PrecisionLike +.. autoclass:: RandomAlgorithm + :members: + :member-order: bysource .. autoclass:: RoundingMethod + :members: + :member-order: bysource .. autoclass:: ScatterDimensionNumbers diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index d6b7d74bd429..9eb518464b4e 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -376,6 +376,7 @@ namespace; they are listed below. size sort sort_complex + spacing split sqrt square diff --git a/docs/jax.rst b/docs/jax.rst index ecfeaf29e3c0..673c71685c52 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -73,13 +73,12 @@ Just-in-time compilation (:code:`jit`) eval_shape ShapeDtypeStruct device_put - device_put_replicated - device_put_sharded device_get default_backend named_call named_scope block_until_ready + make_mesh .. _jax-grad: diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb index 1315783c340c..aa355f471a20 100644 --- a/docs/notebooks/shard_map.ipynb +++ b/docs/notebooks/shard_map.ipynb @@ -510,6 +510,8 @@ "the corresponding `PartitionSpec` `spec` as roughly\n", "`tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))`.\n", "\n", + "(shard_map_collectives_tutorial)=\n", + "\n", "## Collectives tutorial\n", "\n", "A `shard_map` need not be a pure map: function applications can communicate\n", diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md index 96667e709ac6..d77dec652068 100644 --- a/docs/notebooks/shard_map.md +++ b/docs/notebooks/shard_map.md @@ -357,6 +357,8 @@ from the shape `shape` of the corresponding argument to `shard_map`-of-`f` and the corresponding `PartitionSpec` `spec` as roughly `tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))`. +(shard_map_collectives_tutorial)= + ## Collectives tutorial A `shard_map` need not be a pure map: function applications can communicate diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md index 43ba3ebd6afb..d7ed91011a95 100644 --- a/docs/pallas/CHANGELOG.md +++ b/docs/pallas/CHANGELOG.md @@ -18,14 +18,23 @@ Remember to align the itemized text with the first line of an item within a list * {func}`jax.experimental.pallas.debug_print` no longer requires all arguments to be scalars. The restrictions on the arguments are backend-specific: Non-scalar arguments are currently only supported on GPU, when using Triton. + * {class}`jax.experimental.pallas.BlockSpec` no longer supports the previously + deprecated argument order, where `index_map` comes before `block_shape`. * Deprecations + * The {mod}`jax.experimental.pallas.gpu` submodule is deprecated to avoid + ambiguite with {mod}`jax.experimental.pallas.mosaic_gpu`. To use the + Triton backend import {mod}`jax.experimental.pallas.triton`. + * New functionality * {func}`jax.experimental.pallas.pallas_call` now accepts `scratch_shapes`, a PyTree specifying backend-specific temporary objects needed by the kernel, for example, buffers, synchronization primitives etc. + * {func}`checkify.check` can now be used to insert runtime asserts when + pallas_call is called with the `pltpu.enable_runtime_assert(True)` context + manager. ## Released with jax 0.4.33 (September 16, 2024) diff --git a/docs/pallas/grid_blockspec.md b/docs/pallas/grid_blockspec.md index 267199128283..cde200528785 100644 --- a/docs/pallas/grid_blockspec.md +++ b/docs/pallas/grid_blockspec.md @@ -75,7 +75,8 @@ programs write to disjoint places in HBM to avoid these parallel writes. On TPUs, programs are executed in a combination of parallel and sequential (depending on the architecture) so there are slightly different considerations. -See [the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#noteworthy-properties-and-restrictions). + +See {ref}`pallas_tpu_noteworthy_properties`. (pallas_blockspec)= @@ -88,8 +89,7 @@ to *which block of our inputs and outputs to be operated on*. This is provided via {class}`jax.experimental.pallas.BlockSpec` objects. Before we get into the details of `BlockSpec`s, you may want -to revisit the -[Pallas Quickstart BlockSpecs example](https://jax.readthedocs.io/en/latest/pallas/quickstart.html#block-specs-by-example). +to revisit {ref}`pallas_block_specs_by_example` in Pallas Quickstart. `BlockSpec`s are provided to `pallas_call` via the `in_specs` and `out_specs`, one for each input and output respectively. @@ -239,7 +239,7 @@ The output shown below was generated on CPU using `interpret=True` mode, which at the moment executes the invocation sequentially. On TPUs, programs are executed in a combination of parallel and sequential, and this function generates the output shown. -See [the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#noteworthy-properties-and-restrictions). +See {ref}`pallas_tpu_noteworthy_properties`. ```python >>> show_program_ids(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2, 10), diff --git a/docs/pallas/quickstart.ipynb b/docs/pallas/quickstart.ipynb index 0e759a493a61..50464ce8ffd4 100644 --- a/docs/pallas/quickstart.ipynb +++ b/docs/pallas/quickstart.ipynb @@ -319,6 +319,8 @@ "cell_type": "markdown", "metadata": {}, "source": [ + "(pallas_block_specs_by_example)=\n", + "\n", "### Block specs by example" ] }, diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md index a8b13ea38eaf..b9acd6497fb5 100644 --- a/docs/pallas/quickstart.md +++ b/docs/pallas/quickstart.md @@ -209,6 +209,8 @@ You can read more details at {ref}`pallas_grid`. +++ +(pallas_block_specs_by_example)= + ### Block specs by example +++ diff --git a/docs/pallas/tpu/details.rst b/docs/pallas/tpu/details.rst index 4a2d4daa637f..b7ce10d564f6 100644 --- a/docs/pallas/tpu/details.rst +++ b/docs/pallas/tpu/details.rst @@ -59,6 +59,8 @@ ideas described transfer to later generations as well. * `TPU v4: An Optically Reconfigurable Supercomputer for Machine Learning with Hardware Support for Embeddings `_ +.. _pallas_tpu_noteworthy_properties: + Noteworthy properties and restrictions -------------------------------------- diff --git a/docs/pallas/tpu/distributed.ipynb b/docs/pallas/tpu/distributed.ipynb index 8552e10d8552..95abf803a780 100644 --- a/docs/pallas/tpu/distributed.ipynb +++ b/docs/pallas/tpu/distributed.ipynb @@ -11,8 +11,8 @@ "In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`.\n", "\n", "Some recommended readings beforehand:\n", - " - [Pallas Pipelining on TPU](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html)\n", - " - [Collectives with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#collectives-tutorial)" + " - [Pallas Pipelining on TPU](pallas_tpu_pipelining)\n", + " - [Collectives with `shard_map`](shard_map_collectives_tutorial)" ] }, { @@ -1703,7 +1703,7 @@ "\n", "### Megacore\n", "\n", - "Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `\"parallel\"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core.\n", + "Certain TPUs contain multiple cores in a [Megacore](pallas_tpu_megacore) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `\"parallel\"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core.\n", "\n", "### Interaction with XLA\n", "\n", diff --git a/docs/pallas/tpu/distributed.md b/docs/pallas/tpu/distributed.md index fc3f929866bd..c71f75ec6040 100644 --- a/docs/pallas/tpu/distributed.md +++ b/docs/pallas/tpu/distributed.md @@ -20,8 +20,8 @@ kernelspec: In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`. Some recommended readings beforehand: - - [Pallas Pipelining on TPU](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html) - - [Collectives with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#collectives-tutorial) + - [Pallas Pipelining on TPU](pallas_tpu_pipelining) + - [Collectives with `shard_map`](shard_map_collectives_tutorial) ```{code-cell} ipython3 --- @@ -1516,7 +1516,7 @@ print( ### Megacore -Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `"parallel"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core. +Certain TPUs contain multiple cores in a [Megacore](pallas_tpu_megacore) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `"parallel"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core. ### Interaction with XLA diff --git a/docs/pallas/tpu/pipelining.ipynb b/docs/pallas/tpu/pipelining.ipynb index b5f2c652b5a5..9774e08dcda8 100644 --- a/docs/pallas/tpu/pipelining.ipynb +++ b/docs/pallas/tpu/pipelining.ipynb @@ -645,6 +645,8 @@ "id": "KvPFez9N8cKJ" }, "source": [ + "(pallas_tpu_megacore)=\n", + "\n", "## TPUs in Megacore configuration" ] }, diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md index 19150b3832fa..21865430178d 100644 --- a/docs/pallas/tpu/pipelining.md +++ b/docs/pallas/tpu/pipelining.md @@ -436,6 +436,8 @@ dimensions. +++ {"id": "KvPFez9N8cKJ"} +(pallas_tpu_megacore)= + ## TPUs in Megacore configuration +++ {"id": "0f4HAVzQ8n71"} diff --git a/examples/ffi/CMakeLists.txt b/examples/ffi/CMakeLists.txt index 8d9b811374d1..62142fd49034 100644 --- a/examples/ffi/CMakeLists.txt +++ b/examples/ffi/CMakeLists.txt @@ -13,3 +13,7 @@ find_package(nanobind CONFIG REQUIRED) nanobind_add_module(_rms_norm NB_STATIC "src/jax_ffi_example/rms_norm.cc") target_include_directories(_rms_norm PUBLIC ${XLA_DIR}) install(TARGETS _rms_norm LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) + +nanobind_add_module(_attrs NB_STATIC "src/jax_ffi_example/attrs.cc") +target_include_directories(_attrs PUBLIC ${XLA_DIR}) +install(TARGETS _attrs LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) diff --git a/examples/ffi/src/jax_ffi_example/attrs.cc b/examples/ffi/src/jax_ffi_example/attrs.cc new file mode 100644 index 000000000000..2a6e8d847cf4 --- /dev/null +++ b/examples/ffi/src/jax_ffi_example/attrs.cc @@ -0,0 +1,66 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "nanobind/nanobind.h" +#include "xla/ffi/api/ffi.h" + +namespace nb = nanobind; +namespace ffi = xla::ffi; + +ffi::Error ArrayAttrImpl(ffi::Span array, + ffi::Result> res) { + int64_t total = 0; + for (int32_t x : array) { + total += x; + } + res->typed_data()[0] = total; + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ArrayAttr, ArrayAttrImpl, + ffi::Ffi::Bind() + .Attr>("array") + .Ret>()); + +ffi::Error DictionaryAttrImpl(ffi::Dictionary attrs, + ffi::Result> secret, + ffi::Result> count) { + auto maybe_secret = attrs.get("secret"); + if (maybe_secret.has_error()) { + return maybe_secret.error(); + } + secret->typed_data()[0] = maybe_secret.value(); + count->typed_data()[0] = attrs.size(); + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(DictionaryAttr, DictionaryAttrImpl, + ffi::Ffi::Bind() + .Attrs() + .Ret>() + .Ret>()); + +NB_MODULE(_attrs, m) { + m.def("registrations", []() { + nb::dict registrations; + registrations["array_attr"] = + nb::capsule(reinterpret_cast(ArrayAttr)); + registrations["dictionary_attr"] = + nb::capsule(reinterpret_cast(DictionaryAttr)); + return registrations; + }); +} diff --git a/examples/ffi/src/jax_ffi_example/attrs.py b/examples/ffi/src/jax_ffi_example/attrs.py new file mode 100644 index 000000000000..30d7d6c74344 --- /dev/null +++ b/examples/ffi/src/jax_ffi_example/attrs.py @@ -0,0 +1,47 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An example demonstrating the different ways that attributes can be passed to +the FFI. + +For example, we can pass arrays, variadic attributes, and user-defined types. +Full support of user-defined types isn't yet supported by XLA, so that example +will be added in the future. +""" + +import numpy as np + +import jax +import jax.extend as jex + +from jax_ffi_example import _attrs + +for name, target in _attrs.registrations().items(): + jex.ffi.register_ffi_target(name, target) + + +def array_attr(num: int): + return jex.ffi.ffi_call( + "array_attr", + jax.ShapeDtypeStruct((), np.int32), + array=np.arange(num, dtype=np.int32), + ) + + +def dictionary_attr(**kwargs): + return jex.ffi.ffi_call( + "dictionary_attr", + (jax.ShapeDtypeStruct((), np.int32), jax.ShapeDtypeStruct((), np.int32)), + **kwargs, + ) diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.py b/examples/ffi/src/jax_ffi_example/rms_norm.py index 4e0ed1d195b4..d063f1cf319c 100644 --- a/examples/ffi/src/jax_ffi_example/rms_norm.py +++ b/examples/ffi/src/jax_ffi_example/rms_norm.py @@ -60,8 +60,7 @@ def rms_norm(x, eps=1e-5): # type (which corresponds to numpy's `float32` type), and it must be a # static parameter (i.e. not a JAX array). eps=np.float32(eps), - # The `vectorized` parameter controls this function's behavior under `vmap`. - vectorized=True, + vmap_method="broadcast_fullrank", ) @@ -74,7 +73,7 @@ def rms_norm_fwd(x, eps=1e-5): ), x, eps=np.float32(eps), - vectorized=True, + vmap_method="broadcast_fullrank", ) return y, (res, x) @@ -91,7 +90,7 @@ def rms_norm_bwd(eps, res, ct): res, x, ct, - vectorized=True, + vmap_method="broadcast_fullrank", ), ) diff --git a/examples/ffi/tests/attrs_test.py b/examples/ffi/tests/attrs_test.py new file mode 100644 index 000000000000..0288b31cf9fa --- /dev/null +++ b/examples/ffi/tests/attrs_test.py @@ -0,0 +1,61 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest + +import jax +import jax.numpy as jnp +from jax._src import test_util as jtu + +from jax_ffi_example import attrs + +jax.config.parse_flags_with_absl() + + +class AttrsTests(jtu.JaxTestCase): + def test_array_attr(self): + self.assertEqual(attrs.array_attr(5), jnp.arange(5).sum()) + self.assertEqual(attrs.array_attr(3), jnp.arange(3).sum()) + + def test_array_attr_jit_cache(self): + jit_array_attr = jax.jit(attrs.array_attr, static_argnums=(0,)) + with jtu.count_jit_and_pmap_lowerings() as count: + jit_array_attr(5) + self.assertEqual(count[0], 1) # compiles once the first time + with jtu.count_jit_and_pmap_lowerings() as count: + jit_array_attr(5) + self.assertEqual(count[0], 0) # cache hit + + def test_array_attr_no_jit(self): + with jax.disable_jit(): + attrs.array_attr(5) # doesn't crash + + def test_dictionary_attr(self): + secret, count = attrs.dictionary_attr(secret=5) + self.assertEqual(secret, 5) + self.assertEqual(count, 1) + + secret, count = attrs.dictionary_attr(secret=3, a_string="hello") + self.assertEqual(secret, 3) + self.assertEqual(count, 2) + + with self.assertRaisesRegex(Exception, "Unexpected attribute"): + attrs.dictionary_attr() + + with self.assertRaisesRegex(Exception, "Wrong attribute type"): + attrs.dictionary_attr(secret="invalid") + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/jax/BUILD b/jax/BUILD index c25d0004e772..12c239a2d63e 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -147,7 +147,11 @@ py_library( srcs = ["_src/internal_test_util/test_harnesses.py"], visibility = [":internal"] + jax_internal_test_harnesses_visibility, deps = [ + ":ad_util", + ":config", ":jax", + ":test_util", + "//jax/_src/lib", ] + py_deps("numpy"), ) @@ -583,9 +587,11 @@ pytype_strict_library( ], exclude = [ "experimental/pallas/gpu.py", - "experimental/pallas/tpu.py", + "experimental/pallas/mosaic_gpu.py", "experimental/pallas/ops/gpu/**/*.py", "experimental/pallas/ops/tpu/**/*.py", + "experimental/pallas/tpu.py", + "experimental/pallas/triton.py", ], ), visibility = [ @@ -647,19 +653,45 @@ pytype_strict_library( pytype_strict_library( name = "pallas_gpu", - srcs = ["experimental/pallas/gpu.py"], visibility = [ ":pallas_gpu_users", ], deps = [ - "//jax/_src/pallas/mosaic_gpu:core", # build_cleaner: keep - "//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep + ":pallas_triton", + # TODO(slebedev): Add :pallas_mosaic_gpu once it is ready. + ], +) + +pytype_strict_library( + name = "pallas_triton", + srcs = [ + "experimental/pallas/gpu.py", + "experimental/pallas/triton.py", + ], + visibility = [ + ":pallas_gpu_users", + ], + deps = [ + ":deprecations", "//jax/_src/pallas/triton:core", "//jax/_src/pallas/triton:pallas_call_registration", # build_cleaner: keep "//jax/_src/pallas/triton:primitives", ], ) +pytype_strict_library( + name = "pallas_mosaic_gpu", + srcs = ["experimental/pallas/mosaic_gpu.py"], + visibility = [ + ":mosaic_gpu_users", + ], + deps = [ + "//jax/_src/pallas/mosaic_gpu:core", + "//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep + "//jax/_src/pallas/mosaic_gpu:primitives", + ], +) + # This target only supports sm_90 GPUs. py_library( name = "mosaic_gpu", @@ -851,6 +883,7 @@ pytype_strict_library( ":dtypes", ":effects", ":pretty_printer", + ":traceback_util", ":tree_util", ":typing", ":util", diff --git a/jax/_src/api.py b/jax/_src/api.py index 6d2fc4143066..d2ac5465eded 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2213,7 +2213,8 @@ def _check_sharding(aval, s): def device_put( x, device: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None, - *, src: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None): + *, src: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None, + donate: bool | Any = False, may_alias: bool | None | Any = None): """Transfers ``x`` to ``device``. Args: @@ -2222,6 +2223,16 @@ def device_put( (nested) :py:class:`Sharding` in standard Python container (must be a tree prefix of ``x``), representing the device(s) to which ``x`` should be transferred. If given, then the result is committed to the device(s). + src: The (optional) :py:class:`Device`, :py:class:`Sharding`, or a (nested) + :py:class:`Sharding` in standard Python container (must be a tree prefix + of ``x``), representing the device(s) on which ``x`` belongs. + donate: bool or a (nested) bool in standard Python container (must be a tree + prefix of ``x``). If True, ``x`` can be overwritten and marked deleted in + the caller. This is best effort. JAX will donate if possible, otherwise it + won't. The input buffer (in the future) will always be deleted if donated. + may_alias: bool or None or a (nested) bool in standard Python container + (must be a tree prefix of ``x``). If False, `x` will be copied. If true, + `x` may be aliased depending on the runtime's implementation. Returns: A copy of ``x`` that resides on ``device``. @@ -2251,11 +2262,35 @@ def device_put( src_flat = flatten_axes("device_put source", treedef, src) src_flat = list(map(_infer_src_sharding, src_flat, x_flat)) - for xf, d in zip(x_flat, device_flat): + if isinstance(donate, bool): + donate_flat = [donate] * len(x_flat) + else: + donate_flat = flatten_axes("device_put donate", treedef, donate) + + if isinstance(may_alias, bool): + may_alias_flat = [may_alias] * len(x_flat) + else: + may_alias_flat = flatten_axes("device_put may_alias", treedef, may_alias) + + copy_semantics = [] + for m, d in zip(may_alias_flat, donate_flat): + if m and d: + raise ValueError('may_alias and donate cannot be True at the same time.') + if m is None: + m = not d + if m and not d: + copy_semantics.append(dispatch.CopySemantics.ALIAS) + elif not m and d: + copy_semantics.append(dispatch.CopySemantics.DONATE) + else: + assert not m and not d + copy_semantics.append(dispatch.CopySemantics.COPY) + + for xf, d in zip(x_flat, device_flat): # type: ignore _check_sharding(shaped_abstractify(xf), d) out_flat = dispatch.device_put_p.bind( - *x_flat, devices=device_flat, srcs=src_flat - ) + *x_flat, devices=device_flat, srcs=src_flat, + copy_semantics=copy_semantics) return tree_unflatten(treedef, out_flat) @@ -2740,6 +2775,8 @@ def clear_backends(): def clean_up(): if xb._default_backend is not None: clear_backends() + clear_caches() + # Shut down distributed system if it exists. Otherwise, this is a no-op. distributed.shutdown() diff --git a/jax/_src/array.py b/jax/_src/array.py index 83be3d418c50..4e0cd3d16875 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -1029,7 +1029,8 @@ def make_array_from_single_device_arrays( def _get_aval_array(self): if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding): return self.aval.update(sharding=NamedSharding( - self.sharding.mesh.abstract_mesh, self.sharding.spec)) + self.sharding.mesh.abstract_mesh, + self.sharding.normalized_spec(self.ndim))) else: return self.aval api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 3a18dcdfa2ac..5fcd2c4c4260 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -22,6 +22,7 @@ import jax from jax._src import core +from jax._src import deprecations from jax._src import dispatch from jax._src import dtypes from jax._src import effects @@ -31,13 +32,18 @@ from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir +from jax._src.lax import lax from jax._src.lax.control_flow.loops import map as lax_map from jax._src.lib import xla_client as xc from jax._src.sharding_impls import SingleDeviceSharding +from jax._src.typing import DeprecatedArg import numpy as np logger = logging.getLogger(__name__) +# TODO(dfm): Remove after 6 months. +# Added Oct 1, 2024 +deprecations.register("jax-callback-vectorized") # `pure_callback_p` is the main primitive for staging out Python pure callbacks. pure_callback_p = core.Primitive("pure_callback") @@ -45,6 +51,7 @@ dispatch.prim_requires_devices_during_lowering.add(pure_callback_p) map, unsafe_map = util.safe_map, map +zip, unsafe_zip = util.safe_zip, zip @dataclasses.dataclass(frozen=True) @@ -69,9 +76,10 @@ def pure_callback_impl( result_avals, callback: _FlatCallback, sharding: SingleDeviceSharding | None, - vectorized: bool, + vectorized: bool | DeprecatedArg, + vmap_method: str | None, ): - del sharding, vectorized, result_avals + del sharding, vectorized, vmap_method, result_avals try: cpu_device, *_ = jax.local_devices(backend="cpu") except RuntimeError as e: @@ -99,9 +107,10 @@ def pure_callback_abstract_eval( callback: _FlatCallback, result_avals, sharding: SingleDeviceSharding | None, - vectorized: bool, + vectorized: bool | DeprecatedArg, + vmap_method: str | None, ): - del avals, callback, sharding, vectorized + del avals, callback, sharding, vectorized, vmap_method return result_avals @@ -129,25 +138,51 @@ def callback_batching_rule( args, dims, *, - vectorized: bool, + vectorized: bool | None | DeprecatedArg, + vmap_method: str | None, result_avals: Sequence[core.ShapedArray], **kwargs: Any, ): - axis_size = next(a.shape[d] for a, d in zip(args, dims) - if d is not batching.not_mapped) + if isinstance(vectorized, DeprecatedArg) and vmap_method is None: + deprecations.warn( + "jax-callback-vectorized", + f"The default behavior of {prim.name} under vmap will soon " + "change. Currently, the default behavior is to generate a sequential " + "vmap (i.e. a loop), but in the future the default will be to raise " + "an error. To keep the current default, set vmap_method='sequential'.", + stacklevel=6) + vmap_method = "sequential" + + axis_size, = {a.shape[d] for a, d in zip(args, dims) + if d is not batching.not_mapped} new_args = [arg if dim is batching.not_mapped else batching.moveaxis(arg, dim, 0) for arg, dim in zip(args, dims)] - if vectorized: - result_avals = tuple( - core.unmapped_aval(axis_size, core.no_axis_name, 0, aval) # type: ignore - for aval in result_avals) + batched_result_avals = tuple( + core.unmapped_aval(axis_size, core.no_axis_name, 0, aval) + for aval in result_avals) + if vmap_method == "legacy_vectorized": + # This method is kept to support the behavior that was previously exposed + # when using `vectorized=True`. outvals = prim.bind( *new_args, vectorized=vectorized, - result_avals=result_avals, + vmap_method=vmap_method, + result_avals=batched_result_avals, **kwargs, ) - else: + elif vmap_method == "broadcast" or vmap_method == "broadcast_fullrank": + size = axis_size if vmap_method == "broadcast_fullrank" else 1 + bcast_args = [ + lax.broadcast(x, (size,)) if d is batching.not_mapped else x + for x, d in zip(new_args, dims)] + outvals = prim.bind( + *bcast_args, + vectorized=vectorized, + vmap_method=vmap_method, + result_avals=batched_result_avals, + **kwargs, + ) + elif vmap_method == "sequential": is_batched = [d is not batching.not_mapped for d in dims] unbatched_args, batched_args = util.partition_list(is_batched, new_args) def _batch_fun(batched_args): @@ -156,9 +191,15 @@ def _batch_fun(batched_args): *merged_args, result_avals=result_avals, vectorized=vectorized, + vmap_method=vmap_method, **kwargs, ) outvals = lax_map(_batch_fun, batched_args) + else: + raise NotImplementedError( + f"vmap is only supported for the {prim.name} primitive when vmap_method " + "is one of 'sequential', 'broadcast', 'broadcast_fullrank', or " + "'legacy_vectorized'.") return tuple(outvals), (0,) * len(outvals) @@ -261,7 +302,8 @@ def pure_callback( result_shape_dtypes: Any, *args: Any, sharding: SingleDeviceSharding | None = None, - vectorized: bool = False, + vectorized: bool | None | DeprecatedArg = DeprecatedArg(), + vmap_method: str | None = None, **kwargs: Any, ): """Calls a pure Python callback. Works under :func:`jit`/:func:`~vmap`/etc. @@ -279,17 +321,25 @@ def pure_callback( `jit`-decorated function has no data dependence on its value. Pure callbacks may also be reordered if data-dependence allows. - When `vmap`-ed the behavior will depend on the value of the - ``vectorized`` keyword argument. When ``vectorized`` is ``True``, the callback - is assumed to obey - ``jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs])``. - Therefore, the callback will be called directly on batched inputs (where the - batch axes are the leading dimensions). Additionally, the callbacks should - return outputs that have corresponding leading batch axes. If not vectorized - ``callback`` will be mapped sequentially across the batched axis. - For example, if ``callback = lambda x, y: np.matmul(x, y)``, then we are free - to set ``vectorized=True`` because the ``np.matmul`` function handles - arbitrary leading batch dimensions. + When `vmap`-ed the behavior will depend on the value of the ``vmap_method``. + + * Calling :func:`~jax.vmap` on a callback without an explicit ``vmap_method`` + is deprecated and it will eventually raise ``NotImplementedError``. + * ``vmap_method="sequential"`` uses :func:`~jax.lax.map` to loop over + the batched arugments, calling ``callback`` once for each batch element. + * ``vmap_method="broadcast"`` calls ``callback`` with new axes of size ``1`` + added as the leading dimension unbatched inputs. + * ``vmap_method="broadcast_fullrank"`` behaves like ``broadcast``, but the + inputs are tiled to the expected batched shape. + + If necessary, the legacy behavior provided by the deprecated + ``vectorized=True`` argument can be recovered using + ``vmap_method="legacy_vectorized"``. + + The current default behavior is to use ``vmap_method="sequential"`` when + not specified, but this behavior is deprecated, and in the future, the + default will be to raise a ``NotImplementedError`` unless ``vmap_method`` is + explicitly specified. Args: callback: function to execute on the host. The callback is assumed to be a pure @@ -303,8 +353,8 @@ def pure_callback( *args: arguments to be passed to the callback function sharding: optional sharding that specifies the device from which the callback should be invoked. - vectorized: boolean specifying whether the callback function can operate in a - vectorized manner. + vmap_method: string specifying how the callback transforms under + :func:`~jax.vmap` as described above. **kwargs: keyword arguments to be passed to the callback function Returns: @@ -316,8 +366,62 @@ def pure_callback( - :func:`jax.debug.callback`: callback designed for general-purpose debugging. - :func:`jax.debug.print`: callback designed for printing. + Examples: + The behavior of ``pure_callback`` under :func:`~jax.vmap` is controlled by + the ``vmap_method`` argument as described above. It is useful to consider + some explicit examples that demonstrate the semantics. For example, + consider the following function: + + >>> def callback(x, y): + ... print(jnp.shape(x), jnp.shape(y)) + ... return x + y + + >>> def fun(x, y, *, vmap_method): + ... shape = jnp.broadcast_shapes(jnp.shape(x), jnp.shape(y)) + ... dtype = jnp.result_type(x, y) + ... out_type = jax.ShapeDtypeStruct(shape, dtype) + ... return jax.pure_callback(callback, out_type, x, y, + ... vmap_method=vmap_method) + + Calling this with ``vmap_method="broadcast"`` adds a new axis of size ``1`` + to ``y``: + + >>> from functools import partial + >>> x = jnp.arange(4) + >>> y = 1.0 + >>> jax.vmap(partial(fun, vmap_method="broadcast"), in_axes=(0, None))(x, y) + (4,) (1,) + Array([1., 2., 3., 4.], dtype=float32) + + Whereas, ``vmap_method="broadcast_fullrank"`` adds an axis of size ``4`` to + ``y``: + + >>> jax.vmap(partial(fun, vmap_method="broadcast_fullrank"), + ... in_axes=(0, None))(x, y) + (4,) (4,) + Array([1., 2., 3., 4.], dtype=float32) + .. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html """ + if not isinstance(vectorized, DeprecatedArg) and not vectorized is None: + deprecations.warn( + "jax-callback-vectorized", + "The vectorized argument of jax.pure_callback is deprecated and setting " + "it will soon raise an error. To avoid an error in the future, and to " + "suppress this warning, please use the vmap_method argument instead.", + stacklevel=2) + if vmap_method is not None: + raise ValueError( + "the vectorized and vmap_method arguments of jax.pure_callback cannot " + "be used together. Please use the vmap_method argument.") + vmap_method = "legacy_vectorized" if vectorized else "sequential" + allowed_vmap_methods = ["sequential", "broadcast", "broadcast_fullrank", + "legacy_vectorized", None] + if vmap_method not in allowed_vmap_methods: + raise ValueError( + f"vmap_method must be on of the allowed methods {allowed_vmap_methods}, " + f"but got: {vmap_method}") + flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) tree_util.tree_map(_check_shape_dtype, result_shape_dtypes) result_avals = tree_util.tree_map( @@ -329,6 +433,7 @@ def pure_callback( result_avals=tuple(flat_result_avals), sharding=sharding, vectorized=vectorized, + vmap_method=vmap_method, ) return tree_util.tree_unflatten(out_tree, out_flat) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 32cc4feb9054..944bf303b8f6 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -27,6 +27,7 @@ from jax.experimental import shard_map from jax._src import api +from jax._src import ad_checkpoint from jax._src import linear_util as lu from jax._src import config from jax._src import core @@ -933,6 +934,19 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, error_checks[pjit.pjit_p] = pjit_error_check +def remat_error_check(error, enabled_errors, *vals_in, jaxpr, **params): + err_vals, err_tree = jtu.tree_flatten(error) + new_vals_in = [*err_vals, *vals_in] + in_avals = tuple(map(get_shaped_aval, new_vals_in)) + checked_jaxpr_, out_tree, _ = jaxpr_to_checkify_jaxpr( + pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals) + checked_jaxpr, () = checked_jaxpr_.jaxpr, checked_jaxpr_.consts + err_and_out = ad_checkpoint.remat_p.bind(*new_vals_in, jaxpr=checked_jaxpr, + **params) + return tree_unflatten(out_tree, err_and_out) +error_checks[ad_checkpoint.remat_p] = remat_error_check + + def shard_map_error_check( error, enabled_errors, *vals_in, jaxpr, in_names, out_names, **kwargs ): @@ -950,12 +964,10 @@ def shard_map_error_check( raise ValueError(f'Unsupported aval type: {type(v)}') in_avals[i] = sharder(mesh, new_in_names[i], v) - if not isinstance(jaxpr, core.ClosedJaxpr): - jaxpr = core.ClosedJaxpr(jaxpr, ()) with core.extend_axis_env_nd(mesh.shape.items()): # jaxpr to checked_jaxpr checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr( - jaxpr, enabled_errors, err_tree, *in_avals + pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals ) num_out_error_vals = out_tree.num_leaves - len(out_names) @@ -1197,7 +1209,11 @@ def checked_fun(*args, **kwargs): return error, jtu.tree_unflatten(out_tree(), out_flat) return checked_fun -def check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None: +def check(pred: Bool, msg: str, + *fmt_args, + debug: bool = False, + **fmt_kwargs, + ) -> None: """Check a predicate, add an error with msg if predicate is False. This is an effectful operation, and can't be staged (jitted/scanned/...). @@ -1206,6 +1222,9 @@ def check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None: Args: pred: if False, a FailedCheckError error is added. msg: error message if error is added. Can be a format string. + debug: Whether to turn on debugging mode. If True, check will be removed + during execution. If False, the the check must be functionalized using + checkify.checkify. fmt_args, fmt_kwargs: Positional and keyword formatting arguments for `msg`, eg.: ``check(.., "check failed on values {} and {named_arg}", x, named_arg=y)`` @@ -1230,7 +1249,7 @@ def check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None: jax._src.checkify.JaxRuntimeError: -3. needs to be positive! """ - _check(pred, msg, False, *fmt_args, **fmt_kwargs) + _check(pred, msg, debug, *fmt_args, **fmt_kwargs) def _check(pred, msg, debug, *fmt_args, **fmt_kwargs): if not is_scalar_pred(pred): diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index 6033e1bbb928..c7665da961af 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -77,8 +77,8 @@ def cloud_tpu_init() -> None: running_in_cloud_tpu_vm = True os.environ.setdefault('GRPC_VERBOSITY', 'ERROR') - os.environ['TPU_ML_PLATFORM'] = 'JAX' - os.environ['TPU_ML_PLATFORM_VERSION'] = version.__version__ + os.environ.setdefault('TPU_ML_PLATFORM', 'JAX') + os.environ.setdefault('TPU_ML_PLATFORM_VERSION', version.__version__) os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1') os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true" diff --git a/jax/_src/clusters/cloud_tpu_cluster.py b/jax/_src/clusters/cloud_tpu_cluster.py index c85abb2f83de..02aea2cd64d5 100644 --- a/jax/_src/clusters/cloud_tpu_cluster.py +++ b/jax/_src/clusters/cloud_tpu_cluster.py @@ -44,7 +44,7 @@ def get_metadata(key): while retry_count < 6: api_resp = requests.get( f'{gce_metadata_endpoint}/computeMetadata/v1/instance/attributes/{key}', - headers={'Metadata-Flavor': 'Google'}) + headers={'Metadata-Flavor': 'Google'}, timeout=60) if api_resp.status_code == 200: break retry_count += 1 diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 108741b5f8fd..6c40f06d1352 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -200,6 +200,9 @@ def get_compile_options( debug_options.xla_llvm_disable_expensive_passes = True debug_options.xla_test_all_input_layouts = False + if not config.enable_remat_opt_pass.value: + debug_options.xla_disable_hlo_passes = "rematerialization" + # XLA-AutoFDO profile version: precedence order is: # 1. Whatever --jax_xla_profile_version is set to. # 2. If --jax_xla_profile_version is not set (i.e., 0), call the function diff --git a/jax/_src/config.py b/jax/_src/config.py index b21d2f35f9a4..5a0c80a4f6c5 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1537,6 +1537,14 @@ def _update_disable_jit_thread_local(val): default=True, help=('Enables using optimization-barrier op for lowering remat.')) +enable_remat_opt_pass = bool_state( + name='jax_compiler_enable_remat_pass', + default=True, + help=('Config to enable / disable the rematerialization HLO pass. ' + 'Useful to allow XLA to automatically trade off memory and ' + 'compute when encountering OOM errors. However, you are ' + 'likely to get better results manually with jax.checkpoint')) + # TODO(sharadmv,mattjj): set default to True, then remove eager_pmap = bool_state( name='jax_eager_pmap', diff --git a/jax/_src/core.py b/jax/_src/core.py index 9ef19fbeccdc..a2d243de9ea5 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2906,6 +2906,7 @@ def write(v: Var, a: AbstractValue) -> None: # Check each eqn. sentinel = object() in_idx = {v: i for i, v in enumerate(it.chain(jaxpr.constvars, jaxpr.invars))} + mut_arrays = set() for eqn_idx, eqn in enumerate(jaxpr.eqns): prim = eqn.primitive try: @@ -2930,6 +2931,7 @@ def write(v: Var, a: AbstractValue) -> None: if prim is mutable_array_p: outvar, = eqn.outvars in_idx[outvar] = None # type: ignore + mut_arrays.add(outvar) if eqn.effects != eqn_effects: raise JaxprTypeError("Inferred effects do not match equation effects. " f"Equation effects: {eqn.effects}. " @@ -2937,6 +2939,8 @@ def write(v: Var, a: AbstractValue) -> None: for eff in eqn.effects: if isinstance(eff, effects.JaxprInputEffect): eqn_invar = eqn.invars[eff.input_index] + if eqn_invar in mut_arrays: + continue if (jaxpr_index := in_idx.get(eqn_invar, sentinel)) is sentinel: raise JaxprTypeError( "Invalid `JaxprInputEffect`: must correspond to a jaxpr invar") diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index 5f1d132bcbb3..962244a321a9 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -133,3 +133,4 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: register('jax-numpy-linalg-pinv-rcond') register('jax-numpy-quantile-interpolation') register('jax-numpy-trimzeros-not-1d-array') +register('pallas-gpu-triton') diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 59739f4130f3..97680bd0fcc3 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -16,9 +16,10 @@ from __future__ import annotations import atexit -from collections.abc import Callable, Iterator, Sequence +from collections.abc import Callable, Sequence import contextlib import dataclasses +import enum from functools import partial import itertools import time @@ -33,6 +34,7 @@ from jax._src import config from jax._src import core from jax._src import api +from jax._src import array from jax._src import dtypes from jax._src import source_info_util from jax._src import traceback_util @@ -205,6 +207,7 @@ def jaxpr_has_primitive(jaxpr: core.Jaxpr, prim_name: str) -> bool: # stablehlo is oblivious of physical devices. prim_requires_devices_during_lowering: set[core.Primitive] = set() +@util.weakref_lru_cache def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr) -> bool: for eqn in jaxpr.eqns: if eqn.primitive in prim_requires_devices_during_lowering: @@ -220,23 +223,24 @@ class SourceInfo(NamedTuple): eqn_name: str +@util.weakref_lru_cache def get_intermediate_shardings( - jaxpr: core.Jaxpr, -) -> Iterator[tuple[Sharding, SourceInfo]]: + jaxpr: core.Jaxpr) -> Sequence[tuple[Sharding, SourceInfo]]: from jax._src import pjit from jax.experimental import shard_map + out = [] for eqn in jaxpr.eqns: if eqn.primitive is pjit.sharding_constraint_p: s = eqn.params['sharding'] if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh): continue source_info = SourceInfo(eqn.source_info, eqn.primitive.name) - yield (s, source_info) + out.append((s, source_info)) elif eqn.primitive is pjit.pjit_p: source_info = SourceInfo(eqn.source_info, eqn.primitive.name) - yield from ((i, source_info) for i in eqn.params['in_shardings']) - yield from ((o, source_info) for o in eqn.params['out_shardings']) + out.extend((i, source_info) for i in eqn.params['in_shardings']) + out.extend((o, source_info) for o in eqn.params['out_shardings']) elif eqn.primitive is shard_map.shard_map_p: if not eqn.params['mesh']._is_jax_device_mesh: continue @@ -244,14 +248,15 @@ def get_intermediate_shardings( def _names_to_pspec(names): ndmin = max(names) + 1 if names else 0 return PartitionSpec(*(names.get(i) for i in range(ndmin))) - yield from ((NamedSharding(eqn.params['mesh'], _names_to_pspec(names)), source_info) - for names in [*eqn.params['in_names'], *eqn.params['out_names']]) + out.extend((NamedSharding(eqn.params['mesh'], _names_to_pspec(names)), source_info) + for names in [*eqn.params['in_names'], *eqn.params['out_names']]) elif eqn.primitive is device_put_p: source_info = SourceInfo(eqn.source_info, eqn.primitive.name) - yield from ((s, source_info) for s in eqn.params['devices'] - if isinstance(s, Sharding) and s.memory_kind is not None) + out.extend((s, source_info) for s in eqn.params['devices'] + if isinstance(s, Sharding) and s.memory_kind is not None) for subjaxpr in core.subjaxprs(jaxpr): - yield from get_intermediate_shardings(subjaxpr) + out.extend(get_intermediate_shardings(subjaxpr)) + return out def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool: @@ -327,18 +332,22 @@ def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None: if config.debug_infs.value and np.any(np.isinf(np.asarray(buf))): raise FloatingPointError(f"invalid value (inf) encountered in {name}") +class CopySemantics(enum.Enum): + ALIAS = enum.auto() + COPY = enum.auto() + DONATE = enum.auto() def _identity_fn(x): return x -def _different_device_order_reshard(x, target_sharding): - from jax._src import api, array - +def _different_device_order_reshard(x, target_sharding, copy: CopySemantics): x._check_if_deleted() inp_sharding = x.sharding + donate_argnums = 0 if copy == CopySemantics.DONATE else None if inp_sharding._device_assignment == target_sharding._device_assignment: - return api.jit(_identity_fn, out_shardings=target_sharding)(x) + return api.jit(_identity_fn, out_shardings=target_sharding, + donate_argnums=donate_argnums)(x) if inp_sharding.device_set != target_sharding.device_set: inp_ids = [d.id for d in inp_sharding._device_assignment] @@ -381,7 +390,8 @@ def _different_device_order_reshard(x, target_sharding): memory_kind=target_sharding.memory_kind), x._arrays, ) - return api.jit(_identity_fn, out_shardings=target_sharding)(new_x) + return api.jit(_identity_fn, out_shardings=target_sharding, + donate_argnums=donate_argnums)(new_x) @dataclasses.dataclass(frozen=True) @@ -403,26 +413,26 @@ def result_handler(self): return pxla.global_aval_to_result_handler(self.aval, self.s, self.committed) -def _device_put_sharding_impl(x, aval, device): - from jax._src import array +def _device_put_sharding_impl(x, aval, device, copy): from jax.experimental import multihost_utils if isinstance(device, Sharding): s = device - if getattr(x, 'sharding', None) == s and getattr(x, '_committed', False): + if (getattr(x, 'sharding', None) == s and getattr(x, '_committed', False) + and copy == CopySemantics.ALIAS): return x if (not s.is_fully_addressable and isinstance(x, array.ArrayImpl) and not x.is_fully_addressable): assert isinstance(s, Sharding) - return _different_device_order_reshard(x, s) + return _different_device_order_reshard(x, s, copy) if (s.is_fully_addressable and isinstance(x, array.ArrayImpl) and x.is_fully_addressable and s.num_devices > 1 and s._internal_device_list != x.sharding._internal_device_list and # pytype: disable=attribute-error s.device_set == x.sharding.device_set): assert isinstance(s, Sharding) - return _different_device_order_reshard(x, s) + return _different_device_order_reshard(x, s, copy) if not s.is_fully_addressable: if ((isinstance(x, array.ArrayImpl) and not x._committed) or @@ -433,7 +443,9 @@ def _device_put_sharding_impl(x, aval, device): f"{type(x)} passed to device_put is not the same on each" " process. Make sure you are passing the same value of" f" {type(x)} on each process.")) - return api.jit(_identity_fn, out_shardings=s)(x) + return api.jit( + _identity_fn, out_shardings=s, + donate_argnums=(0 if copy == CopySemantics.DONATE else None))(x) # TODO(yashkatariya,mattjj): Link to a doc about McJAX and jax.Array. raise ValueError( "device_put's second argument must be a Device or a Sharding which" @@ -447,9 +459,10 @@ def _device_put_sharding_impl(x, aval, device): raise ValueError( "device_put's first argument must be a fully addressable array, but " f"got value with devices {x.devices()}") - if device is None: + if device is None and copy == CopySemantics.ALIAS: return x elif is_single_device_sharding(x.sharding): + device = x.sharding._device_assignment[0] if device is None else device return pxla.batched_device_put(aval, SingleDeviceSharding(device), [x], [device]) @@ -459,11 +472,8 @@ def _device_put_sharding_impl(x, aval, device): def _device_put_impl( - x, - *, - device: Device | Sharding | Layout | None, - src: Device | Sharding | Layout | None, -): + x, *, device: Device | Sharding | Layout | None, + src: Device | Sharding | Layout | None, copy: CopySemantics): if (isinstance(device, TransferToMemoryKind) or isinstance(src, TransferToMemoryKind)): raise ValueError( @@ -482,30 +492,33 @@ def _device_put_impl( dll = l.device_local_layout x_dll = x.layout.device_local_layout if hasattr(x, 'layout') else None if dll is None and l.sharding is None: - return _device_put_sharding_impl(x, aval, l.sharding) + return _device_put_sharding_impl(x, aval, l.sharding, copy) if (not isinstance(l.sharding, Sharding) or not isinstance(dll, (DeviceLocalLayout, type(None)))): raise ValueError( "sharding and device_local_layout in `Layout` instance should be" f" concrete. Got layout: {l} for input {aval.str_short()}") - if getattr(x, 'layout', None) == l and getattr(x, '_committed', False): + if (getattr(x, 'layout', None) == l and getattr(x, '_committed', False) and + copy == CopySemantics.ALIAS): return x if x_dll is None and dll is None: - return _device_put_sharding_impl(x, aval, l.sharding) - return api.jit(_identity_fn, out_shardings=l)(x) + return _device_put_sharding_impl(x, aval, l.sharding, copy) + return api.jit( + _identity_fn, out_shardings=l, + donate_argnums=(0 if copy == CopySemantics.DONATE else None))(x) - return _device_put_sharding_impl(x, aval, device) + return _device_put_sharding_impl(x, aval, device, copy) def _batched_device_put_impl( *xs, devices: Sequence[Device | Sharding | Layout | None], srcs: Sequence[Device | Sharding | Layout | None], -): + copy_semantics: Sequence[CopySemantics]): ys = [] shard_arg_indices, shard_arg_xs, shard_arg_shardings = [], [], [] - for i, (x, device, src) in enumerate(zip(xs, devices, srcs)): - y = _device_put_impl(x, device=device, src=src) + for i, (x, device, src, cp) in enumerate(zip(xs, devices, srcs, copy_semantics)): + y = _device_put_impl(x, device=device, src=src, copy=cp) if isinstance(y, _DeferredShardArg): shard_arg_indices.append(i) shard_arg_xs.append(y.x) @@ -529,17 +542,29 @@ def _batched_device_put_impl( device_put_p = core.Primitive('device_put') device_put_p.multiple_results = True device_put_p.def_impl(_batched_device_put_impl) -device_put_p.def_abstract_eval(lambda *xs, devices, srcs: xs) +device_put_p.def_abstract_eval(lambda *xs, devices, srcs, copy_semantics: xs) -def _device_put_transpose(cts, *_, devices, srcs): +def _device_put_transpose(cts, *_, devices, srcs, copy_semantics): results = [None] * len(cts) dp_args = [] - for i, (ct, device, src) in enumerate(zip(cts, devices, srcs)): + for i, (ct, device, src, cp) in enumerate(zip(cts, devices, srcs, copy_semantics)): if type(ct) is not ad.Zero: - dp_args.append((i, ct, device, src)) + dp_args.append((i, ct, device, src, cp)) if dp_args: - indices, args, devices, srcs = list(zip(*dp_args)) - ys = device_put_p.bind(*args, devices=srcs, srcs=devices) + indices, args, devices, srcs, copy_semantics = list(zip(*dp_args)) + new_copy_semantics = [] + for cp in copy_semantics: + if cp == CopySemantics.DONATE: + raise ValueError( + "donate=True is not allowed during tranposition of device_put." + " Please file an issue if you want this to be supported.") + elif cp == CopySemantics.ALIAS: + new_copy_semantics.append(CopySemantics.COPY) + else: + assert cp == CopySemantics.COPY + new_copy_semantics.append(CopySemantics.COPY) + ys = device_put_p.bind(*args, devices=srcs, srcs=devices, + copy_semantics=new_copy_semantics) for i, y in zip(indices, ys): results[i] = y return results @@ -554,21 +579,27 @@ def _device_put_batcher(batched_args, batch_dims, **params): return device_put_p.bind(*batched_args, **params), batch_dims batching.primitive_batchers[device_put_p] = _device_put_batcher -def _tpu_gpu_device_put_lowering(ctx, *xs, devices, srcs): +def _tpu_gpu_device_put_lowering(ctx, *xs, devices, srcs, copy_semantics): # TODO(yashkatariya): Maybe we should add the custom calls anyways if it's # being used inside jit? Atleast for now, this preserves the old behavior. if ctx.module_context.all_default_mem_kind: return xs - def lower(x, device, src, aval, out_aval): + def lower(x, device, aval, out_aval): if (isinstance(device, (Sharding, TransferToMemoryKind)) and device.memory_kind is not None): if isinstance(device, Sharding): - x = mlir.wrap_with_sharding_op( - ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto()) + if config.use_shardy_partitioner.value: + x = mlir.wrap_with_sharding_op( + ctx, x, out_aval, + device._to_sdy_sharding(aval.ndim)) + else: + x = mlir.wrap_with_sharding_op( + ctx, x, out_aval, + device._to_xla_hlo_sharding(aval.ndim).to_proto()) x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval) return x return x - return list(map(lower, xs, devices, srcs, ctx.avals_in, ctx.avals_out)) + return list(map(lower, xs, devices, ctx.avals_in, ctx.avals_out)) mlir.register_lowering( device_put_p, _tpu_gpu_device_put_lowering, platform='tpu') @@ -576,11 +607,11 @@ def lower(x, device, src, aval, out_aval): device_put_p, _tpu_gpu_device_put_lowering, platform='gpu') -def _common_device_put_lowering(ctx, *xs, devices, srcs): +def _common_device_put_lowering(ctx, *xs, devices, srcs, copy_semantics): return xs mlir.register_lowering(device_put_p, _common_device_put_lowering) -def _propagate_mem_kind_dp(*xm, devices=None, srcs=None): +def _propagate_mem_kind_dp(*xm, devices, srcs, copy_semantics): memory_kinds = [] for device in devices: if isinstance(device, (Sharding, TransferToMemoryKind)): diff --git a/jax/_src/environment_info.py b/jax/_src/environment_info.py index 32d34db254fd..4abfdeaa0f14 100644 --- a/jax/_src/environment_info.py +++ b/jax/_src/environment_info.py @@ -21,7 +21,7 @@ from jax import version from jax._src import lib -from jax._src import xla_bridge +from jax._src import xla_bridge as xb import numpy as np def try_nvidia_smi() -> str | None: @@ -41,19 +41,15 @@ def print_environment_info(return_string: bool = False) -> str | None: """ # TODO(jakevdp): should we include other info, e.g. jax.config.values? python_version = sys.version.replace('\n', ' ') - with np.printoptions(threshold=4, edgeitems=2): - devices_short = str(np.array(xla_bridge.devices())).replace('\n', '') - info = textwrap.dedent( - f"""\ + info = textwrap.dedent(f"""\ jax: {version.__version__} jaxlib: {lib.version_str} numpy: {np.__version__} python: {python_version} - jax.devices ({xla_bridge.device_count()} total, {xla_bridge.local_device_count()} local): {devices_short} - process_count: {xla_bridge.process_count()} + device info: {xb.devices()[0].device_kind}-{xb.device_count()}, {xb.local_device_count()} local devices" + process_count: {xb.process_count()} platform: {platform.uname()} -""" - ) +""") nvidia_smi = try_nvidia_smi() if nvidia_smi: info += '\n\n$ nvidia-smi\n' + nvidia_smi diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 7f7773acbd39..774953ed9b97 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -223,15 +223,6 @@ def __str__(self): # do not want the entire serialized module to end up in locations. return f"Exported(fun_name={self.fun_name}, ...)" - # For backwards compatibility - # TODO(necula): remove after September 2024. - @property - def in_shardings(self): - return self.in_shardings_hlo - @property - def out_shardings(self): - return self.out_shardings_hlo - def in_shardings_jax( self, mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]: @@ -936,6 +927,7 @@ def _check_lowering(lowering) -> None: "lapack_sgeev_ffi", "lapack_dgeev_ffi", "lapack_cgeev_ffi", "lapack_zgeev_ffi", "lapack_sgesdd_ffi", "lapack_dgesdd_ffi", "lapack_cgesdd_ffi", "lapack_zgesdd_ffi", "lapack_sgetrf_ffi", "lapack_dgetrf_ffi", "lapack_cgetrf_ffi", "lapack_zgetrf_ffi", + "lapack_sgehrd_ffi", "lapack_dgehrd_ffi", "lapack_cgehrd_ffi", "lapack_zgehrd_ffi", ] # These are the JAX custom call target names that are guaranteed to be stable. # Their backwards compatibility is tested by back_compat_test.py. @@ -946,40 +938,30 @@ def _check_lowering(lowering) -> None: "__gpu$xla.gpu.triton", # Pallas call on GPU # cholesky on CPU "lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf", - # eigh on CPU - "lapack_ssyevd", "lapack_dsyevd", "lapack_cheevd", "lapack_zheevd", - # eigh on GPU - "cusolver_syevj", "cusolver_syevd", - "hipsolver_syevj", "hipsolver_syevd", # eigh on TPU "Eigh", # eig on CPU "lapack_sgeev", "lapack_dgeev", "lapack_cgeev", "lapack_zgeev", - # qr on CPU - "lapack_sgeqrf", "lapack_dgeqrf", "lapack_cgeqrf", "lapack_zgeqrf", - # householder product on CPU - "lapack_sorgqr", "lapack_dorgqr", "lapack_cungqr", "lapack_zungqr", # svd on CPU "lapack_sgesdd", "lapack_dgesdd", "lapack_cgesdd", "lapack_zgesdd", - # qr on GPU - "cusolver_geqrf", "cublas_geqrf_batched", - "cusolver_orgqr", - "hipsolver_geqrf", "hipblas_geqrf_batched", - "hipsolver_orgqr", # qr and svd on TPU "Qr", "ProductOfElementaryHouseholderReflectors", # triangular_solve on CPU "blas_strsm", "blas_dtrsm", "blas_ctrsm", "blas_ztrsm", - # TODO(atondwal, necula): add back_compat tests for lu on CPU/GPU - # lu on CPU - "lapack_sgetrf", "lapack_dgetrf", "lapack_cgetrf", "lapack_zgetrf", # schur on CPU "lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees", + # hessenberg on CPU + "lapack_sgehrd", "lapack_dgehrd", "lapack_cgehrd", "lapack_zgehrd", # lu on GPU - "cu_lu_pivots_to_permutation", - # "cublas_getrf_batched", "cusolver_getrf", - # "hipblas_getrf_batched", "hipsolver_getrf", - "cusolver_getrf_ffi", + "cu_lu_pivots_to_permutation", "cusolver_getrf_ffi", + "hip_lu_pivots_to_permutation", "hipsolver_getrf_ffi", + "cu_lu_pivots_to_permutation", "cusolver_getrf_ffi", + # qr on GPU + "cusolver_geqrf_ffi", "cusolver_orgqr_ffi", + "hipsolver_geqrf_ffi", "hipsolver_orgqr_ffi", + # eigh on GPU + "cusolver_syevd_ffi", "hipsolver_syevd_ffi", + # svd on GPU # lu on TPU "LuDecomposition", # ApproxTopK on TPU diff --git a/jax/_src/extend/ffi.py b/jax/_src/extend/ffi.py index 833ac4f615a8..6bbba0cbd88e 100644 --- a/jax/_src/extend/ffi.py +++ b/jax/_src/extend/ffi.py @@ -20,8 +20,12 @@ import os from typing import Any +import numpy as np + from jax._src import core +from jax._src import deprecations from jax._src import dispatch +from jax._src import effects from jax._src import util from jax._src.callback import _check_shape_dtype, callback_batching_rule from jax._src.interpreters import ad @@ -31,7 +35,8 @@ from jax._src.lib import jaxlib from jax._src.lib import xla_client from jax._src.lib.mlir import ir -from jax._src.typing import Array, ArrayLike, DuckTypedArray, Shape +from jax._src.typing import (Array, ArrayLike, DeprecatedArg, DuckTypedArray, + Shape) map, unsafe_map = util.safe_map, map FfiLayoutOptions = Sequence[int] | DeviceLocalLayout | None @@ -196,22 +201,22 @@ def ffi_call( target_name: str, result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata], *args: ArrayLike, - vectorized: bool = False, + has_side_effect: bool = False, + vmap_method: str | None = None, + vectorized: bool | DeprecatedArg = DeprecatedArg(), **kwargs: Any, ) -> Array | list[Array]: """Call a foreign function interface (FFI) target. Like :func:`~jax.pure_callback`, the behavior of ``ffi_call`` under - :func:`~jax.vmap` depends on the value of ``vectorized``. When ``vectorized`` - is ``True``, the FFI target is assumed to satisfy: ``ffi_call(xs) == - jnp.stack([ffi_call(x) for x in xs])``. In other words, calling the FFI target - with an extra leading dimension should return the same result as calling it - within a loop and stacking along the zeroth axis. Therefore, the FFI target - will be called directly on batched inputs (where the batch axes are the - leading dimensions). Additionally, the callbacks should return outputs that - have corresponding leading batch axes. If ``vectorized`` is ``False`` (the - default behavior), transforming this ``ffi_call`` under :func:`~jax.vmap` will - result in a :func:`~jax.lax.scan` with the ``ffi_call`` in the body. + :func:`~jax.vmap` depends on the value of ``vmap_method``. See the + :func:`~jax.pure_callback` documenation for more details about the allowed + values and examples of their behavior. + + The current default behavior is to use ``vmap_method="sequential"`` when + not specified, but this behavior is deprecated, and in the future, the + default will be to raise a ``NotImplementedError`` unless ``vmap_method`` is + explicitly specified. Args: target_name: the name of the XLA FFI custom call target that was registered @@ -222,8 +227,11 @@ def ffi_call( used to define the elements of ``result_shape_dtypes``. ``jax.core.abstract_token`` may be used to represent a token-typed output. *args: the arguments passed to the custom call. - vectorized: boolean specifying whether the callback function can operate in - a vectorized manner, as described above. + has_side_effect: boolean specifying whether the custom call has side + effects. When ``True``, the FFI call will be executed even when the + outputs are not used. + vmap_method: string specifying how the FFI call transforms under + :func:`~jax.vmap` as described above. **kwargs: keyword arguments that are passed as named attributes to the custom call using XLA's FFI interface. @@ -231,6 +239,25 @@ def ffi_call( One or more :class:`~jax.Array` objects whose shapes and dtypes match ``result_shape_dtypes``. """ + if not isinstance(vectorized, DeprecatedArg) and not vectorized is None: + deprecations.warn( + "jax-callback-vectorized", + "The vectorized argument of ffi_call is deprecated and setting " + "it will soon raise an error. To avoid an error in the future, and to " + "suppress this warning, please use the vmap_method argument instead.", + stacklevel=2) + if vmap_method is not None: + raise ValueError( + "the vectorized and vmap_method arguments of ffi_call cannot " + "be used together. Please use the vmap_method argument.") + vmap_method = "legacy_vectorized" if vectorized else "sequential" + allowed_vmap_methods = ["sequential", "broadcast", "broadcast_fullrank", + "legacy_vectorized", None] + if vmap_method not in allowed_vmap_methods: + raise ValueError( + f"vmap_method must be on of the allowed methods {allowed_vmap_methods}, " + f"but got: {vmap_method}") + if isinstance(result_shape_dtypes, Sequence): multiple_results = True result_avals = _result_avals(result_shape_dtypes) @@ -241,8 +268,10 @@ def ffi_call( *args, result_avals=result_avals, vectorized=vectorized, + vmap_method=vmap_method, target_name=target_name, - **kwargs, + has_side_effect=has_side_effect, + **_wrap_kwargs_hashable(kwargs), ) if multiple_results: return results @@ -250,15 +279,98 @@ def ffi_call( return results[0] +# ffi_call must support some small non-hashable input arguments, like np.arrays +# and dicts, to support calling FFI targets with array inputs or user defined +# structs. Since these arguments will eventually be embedded in the HLO as +# dense attributes, we assume that they are small and hash by making an +# immutable copy and hashing by value. +def _wrap_kwargs_hashable(kwargs: dict[str, Any]) -> dict[str, Any]: + hashable_kwargs: dict[str, Any] = {} + for k, v in kwargs.items(): + if isinstance(v, np.ndarray): + hashable_kwargs[k] = HashableArray(v) + elif isinstance(v, dict): + hashable_kwargs[k] = HashableDict(v) + else: + try: + hash(v) + except TypeError as e: + raise TypeError( + f"Non-hashable keyword argument to ffi_call {k}: {v}") from e + else: + hashable_kwargs[k] = v + return hashable_kwargs + + +def _unwrap_kwargs_hashable(kwargs: dict[str, Any]) -> dict[str, Any]: + unwrapped_kwargs: dict[str, Any] = {} + for k, v in kwargs.items(): + if isinstance(v, HashableArray): + unwrapped_kwargs[k] = v.val + elif isinstance(v, HashableDict): + unwrapped_kwargs[k] = dict(v.val) + else: + unwrapped_kwargs[k] = v + return unwrapped_kwargs + + +class HashableArray: + __slots__ = ["val"] + + def __init__(self, val): + assert isinstance(val, np.ndarray) + self.val = np.copy(val) + self.val.setflags(write=False) + + def __repr__(self): + return f"HashableArray({self.val})" + + def __hash__(self): + return hash((self.val.shape, self.val.dtype, self.val.tobytes())) + + def __eq__(self, other): + return isinstance(other, HashableArray) and np.array_equal(self.val, other.val) + + +class HashableDict: + __slots__ = ["val"] + + def __init__(self, val): + assert isinstance(val, dict) + self.val = tuple(sorted(val.items())) + + def __repr__(self): + return f"HashableDict({dict(self.val)})" + + def __hash__(self): + return hash(self.val) + + def __eq__(self, other): + return isinstance(other, HashableDict) and self.val == other.val + + +class FfiEffect(effects.Effect): + def __str__(self): + return "FFI" + + +_FfiEffect = FfiEffect() +effects.lowerable_effects.add_type(FfiEffect) +effects.control_flow_allowed_effects.add_type(FfiEffect) + + def ffi_call_abstract_eval( *avals_in, result_avals: tuple[core.AbstractValue, ...], target_name: str, - vectorized: bool, + vectorized: bool | DeprecatedArg, + vmap_method: str | None, + has_side_effect: bool, **kwargs: Any, ): - del avals_in, target_name, vectorized, kwargs - return result_avals + del avals_in, target_name, vectorized, vmap_method, kwargs + effects = {_FfiEffect} if has_side_effect else core.no_effects + return result_avals, effects def ffi_call_jvp(*args, target_name, **kwargs): @@ -280,17 +392,20 @@ def ffi_call_lowering( *operands: ir.Value, result_avals: tuple[core.AbstractValue, ...], target_name: str, - vectorized: bool, + vectorized: bool | DeprecatedArg, + vmap_method: str | None, + has_side_effect: bool, **kwargs: Any, ) -> Sequence[ir.Value]: - del result_avals, vectorized - return ffi_lowering(target_name)(ctx, *operands, **kwargs) + del result_avals, vectorized, vmap_method + rule = ffi_lowering(target_name, has_side_effect=has_side_effect) + return rule(ctx, *operands, **_unwrap_kwargs_hashable(kwargs)) ffi_call_p = core.Primitive("ffi_call") ffi_call_p.multiple_results = True -ffi_call_p.def_impl(functools.partial(dispatch.apply_primitive, ffi_call_p)) -ffi_call_p.def_abstract_eval(ffi_call_abstract_eval) +dispatch.simple_impl(ffi_call_p) +ffi_call_p.def_effectful_abstract_eval(ffi_call_abstract_eval) ad.primitive_jvps[ffi_call_p] = ffi_call_jvp ad.primitive_transposes[ffi_call_p] = ffi_call_transpose batching.primitive_batchers[ffi_call_p] = functools.partial( diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_hessenberg_lapack_gehrd.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_hessenberg_lapack_gehrd.py new file mode 100644 index 000000000000..204af8f55396 --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_hessenberg_lapack_gehrd.py @@ -0,0 +1,526 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa + +import datetime +from numpy import array, float32, complex64 + +data_2024_08_30 = {} + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_30["c128"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_zgehrd'], + serialized_date=datetime.date(2024, 8, 30), + inputs=(), + expected_outputs=(array([[[ 0.7137638961069523 +2.4533812415320035e+00j, + -0.3272236912989258 -3.2003874808591863e+00j, + -3.065817294924296 +1.6978219378771007e+00j, + -3.3971558164664 +2.6931967836060400e-01j], + [ 6.346214936866542 +0.0000000000000000e+00j, + 2.083218259144673 -1.2191838498692813e+00j, + 1.9552582313969427 -3.3216313521481879e+00j, + 2.7451664155727293 +2.5460553490974451e+00j], + [-0.16133388943502391 +3.6906265775683444e-01j, + -4.698636849217318 +0.0000000000000000e+00j, + 2.5396292124414077 -3.3038474840573420e+00j, + 2.5410992366186456 +4.1958389320867528e-01j], + [ 0.47396123039280513 +3.9524384493417053e-03j, + 0.058880409351504966-7.8934332132630333e-02j, + 0.9469634796174572 +0.0000000000000000e+00j, + -3.130422531669044 -8.8070401977461810e-01j]], + + [[-6.7065483048969465 -4.1981401054281309e-01j, + -0.21813268822330256 -3.8602920478381799e+00j, + -0.8248337528620167 -2.9073223456990824e+00j, + -3.597231249446879 +2.7626541679004930e+00j], + [-6.812126638479044 +0.0000000000000000e+00j, + -0.20651586628458585 -1.0948249928988512e+00j, + -1.6675586608354327 +4.2553627621795744e+00j, + -2.410110723267707 +3.6065122124698634e-01j], + [ 0.038235817369200516-3.7823713529009173e-01j, + -8.508141062606947 +0.0000000000000000e+00j, + 4.260708077719245 -6.8052584397204630e-02j, + 5.345997177836541 -1.1955161503390279e+00j], + [-0.18541509608158574 -1.2016051097247168e-01j, + -0.02698777746917469 -4.4847463691672246e-01j, + 6.149305574585603 +0.0000000000000000e+00j, + -2.483131585236393 +2.8524912589603817e+00j]]]), array([[1.2286220194325557+0.5121060656500841j , + 1.9529937219183482-0.23299856112387676j, + 1.5940499664125072-0.8044281430962614j ], + [1.6682114302246909-0.11372755955977935j, + 1.4075913155446236-0.6008708461880701j , + 1.5086928152468893-0.8609480935086589j ]])), + mlir_module_text=r""" +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<[[[(0.71376389610695234,2.4533812415320035), (-1.0686093138739379,-1.885041510645256), (3.2629529488994033,-0.87160041258342402), (2.4332168907311504,3.4960248990882183)], [(-1.450884474619478,-3.249935163088522), (0.53920035905924757,-5.0056840575116066), (0.13157186736298554,2.5015499854549939), (-1.2451270607408882,0.24345856951924827)], [(2.457366083193417,-2.3532935513245605), (-0.37595429769485644,1.5729223427874068), (3.5877693970448052,-0.30904304334212157), (-1.685615117470264,2.6148811836470265)], [(-3.6826776618664727,-1.5711608241015744), (-0.12407609317204518,-4.7137561145212281), (1.3298255603911306,-1.6739172003954141), (-2.6345448161870149,-0.089008252847513236)]], [[(-6.7065483048969465,-0.41981401054281309), (-2.1586544949255457,0.34815132010709054), (-5.1462488701272413,3.440817752555807), (1.0301804086076078,-0.6994760434270566)], [(4.551940883969797,-0.77472653800638502), (4.4485186470774796,-0.0024458890677252756), (0.66610302132250898,2.5976571401862039), (-5.0693248202533674,-5.7405538897950699)], [(0.14148406399087146,-4.3279346473525058), (-2.353557113110897,2.0880432773400326), (-3.2524452107293618,-0.42398740171508631), (3.7200566224095519,-0.56951559566037058)], [(-2.2001612082232613,-1.2218661647417151), (0.72437359623190833,8.6381970213061301), (0.72314820631775734,0.058458198280771749), (0.37498718985014962,2.1160469724471378)]]]> : tensor<2x4x4xcomplex> loc(#loc) + %c = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) + %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_2 = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_3 = stablehlo.constant dense<2> : tensor loc(#loc2) + %c_4 = stablehlo.constant dense<4288> : tensor loc(#loc2) + %0:4 = stablehlo.custom_call @lapack_zgehrd(%c, %c_0, %c_1, %c_2, %c_3, %c_4, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x3xcomplex>, tensor<2xi32>, tensor<4288xcomplex>) loc(#loc2) + %c_5 = stablehlo.constant dense<0> : tensor loc(#loc2) + %1 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) + %2 = stablehlo.compare EQ, %0#2, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) + %cst_6 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) + %4 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) + %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) + %6 = stablehlo.select %5, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) + %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) + %cst_7 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) + %8 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc2) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc2) + %10 = stablehlo.select %9, %0#1, %8 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc2) + return %6, %10 : tensor<2x4x4xcomplex>, tensor<2x3xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":697:13) +#loc2 = loc("jit(func)/jit(main)/hessenberg"(#loc1)) +""", + mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xef\xa19\x01W\x0f\x0b\x07\x0b\x13\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03K\x0f\x0b\x0b\x0b\x0bo/\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b&\x10\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\'\x0f\x17\x1bO\x1f\x0f\x0b\x0b/OoO\x01\x05\x0b\x0f\x035\x0f\x1b\x07\x0b\x17\x07\x07\x0f\x07\x13\x17\x07\x17\x13\x13\x13\x13\x13\x13\x1b\x13\x1b\x13\x17\x17\x13\x02\xce\x0f\x1d-/\x05\x15\x1f\x05\x17\x03\x03\x03w\x03\x03\x07\x93\x11\x03\x05\x05\x19\x03\x03\x07\x99\x03\x03\x03\x9b\x03\t\x17\x19\x1b\r\x1d\r\x0f\x1f\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b#Y%e\'g\x0fq)s\x05#\x05%\x05\'\x05)\x03\x03\x03u\x05+\x171\xe6\n\x1b\x05-\x03\x03\x03y\x03\x03\x03{\x03\x03\x03}\x03\x11;\x7f=\x81?\x83AYC\x85E\x87G\x89I\x8d\x05/\x051\x053\x055\x057\x059\x05;\x05=\x03\x03\x03\x91\x03\x05O\x95Q\x97\x05?\x05A\x03\x03\x07\x9d\x03\x03\x07\x9f\x1f\x1f\x01\x03\x01\x1dC\x1dE\x1dG\x1f!1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f%\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x19\x03\x05im\r\x05[k]_\x1dI\r\x05[o]_\x1dK\x1dM\x1dO\x1f\x07\x02\x08p\t\xdba\'\xd7\xe6?\xa8\xff\'X\x86\xa0\x03@\x0c\xa2t\x14\x06\x19\xf1\xbfT.}I!)\xfe\xbf\x0fG_\x13\x87\x1a\n@\xae:g\x8c&\xe4\xeb\xbf\xeb\x1e\xcej:w\x03@N\xaf\xfc\xe6\xdb\xf7\x0b@\x9f<\x8c\xa3\xd26\xf7\xbf^\xaf\xbc\x01\xde\xff\t\xc0b\xd4\x84\x1c!A\xe1?\xd6{\xa4\n\xd2\x05\x14\xc0\xf0\xe6\xb2\xd1X\xd7\xc0?2\xb5\x86\xa3,\x03\x04@\x91\xf2SZ\n\xec\xf3\xbf\x04\x10\x02\x81\xa6)\xcf?8\xec\x8c\x8c\xaf\xa8\x03@\r\x9d\xc6\x91\x8b\xd3\x02\xc0\xb0\xf6X\x9d\xa2\x0f\xd8\xbf\xbd\xb6V\x9e\xb0*\xf9?7-\x0fq\xc0\xb3\x0c@{|\ry\\\xc7\xd3\xbf\x04\xd9\xb2\x8eG\xf8\xfa\xbf\x9b\x84u\xd3F\xeb\x04@\xf4h\xbb\xb4\x1fv\r\xc0\xdc\\D\x88y#\xf9\xbf\x9a\xaecjs\xc3\xbf\xbf<\xc1\x04\xe2\xe2\xda\x12\xc0\x89<\xb4*\xf7F\xf5?\x1b\x90\xfef]\xc8\xfa\xbf\xdc\xf4\x8a;\x8c\x13\x05\xc0\xf8\xdd\r\xaf>\xc9\xb6\xbfvN\x1af\x81\xd3\x1a\xc0Z\xc6k\x95;\xde\xda\xbf\x87\x8c\xd8\xa5\xecD\x01\xc0\xdd\xd3zy\x1cH\xd6?\x04\x18\x89C\xc2\x95\x14\xc0\x8c\xc95u\xcb\x86\x0b@\x881\xbfs\x9e{\xf0?\x92Y[\x95\x1bb\xe6\xbf\x06\xe7\xb7\xfd/5\x12@L\x95\x02O\x8f\xca\xe8\xbf2`\xe3xH\xcb\x11@>\xda\xc6\xb1f\td\xbfZ\x1a\x8bH\xb7P\xe5?\xa8\x90zw\x00\xc8\x04@<(\xef\x15\xfdF\x14\xc0\xb4aF\xc2S\xf6\x16\xc0\xc1{\xdfY&\x1c\xc2?\xcfj\xa6\x19\xceO\x11\xc0\xc4\xa2p\xc0\x15\xd4\x02\xc0\xfcv\xa6\x08P\xb4\x00@^\xea\xa0\xfe\x01\x05\n\xc0^\x11\x12\x0e\x9c"\xdb\xbfR#\xe4\x0b\xad\xc2\r@F\x8b=\xc5x9\xe2\xbfZ\xf9\x99\x1e\xee\x99\x01\xc0My\x1a\x89\xc3\x8c\xf3\xbf\xd1\xdc<\x89\x11.\xe7?2\xd4\x8d\xc2\xc1F!@mw\t\xb5\x07$\xe7?G\x16\x99\xa3;\xee\xad?M\xd24E\xca\xff\xd7?\xa2\xae\xfb\x08\xaa\xed\x00@\x1f\x05\t\x04\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\xc0\x10\x00\x00\x0b\x05\x1dQ\x1dS\x05\x01\x03\x0fWWWWWWa\x03\x03\x8b\x15\x03\x01\x19\x01\x03\ta\x8fcc\x1f#!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x1f\'\x01\t\x07\x07\x01\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x13!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f11\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x15)\x07\t\x11\x11\x0b\x01\x03\x1b)\x05\t\r\x0b\x13\x1d)\x01\x0b\x1b)\x03\t\x15\x11\x01\x05\x07\r\x0b)\x03\x02\x86\x0b)\x03\x01\x0f)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x11)\x03\t\t)\x07\t\x05\x05\t)\x03\x05\x11)\x07\t\x11\x11\t)\x03\r\x11)\x05\t\x05\t)\x05\t\r\t)\x03\t\x11\x04\xde\x02\x05\x01\x11\x05\x15\x07\x03\x01\x05\t\x11\x05!\x07\x031Y\x03\x03\x05+\x03\x07\x03\x03\x01\t\x03\x05\x03\x03\x013\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x015\x03\x05\x03\x03\x017\x03\x05\x0b\x07\x019\t\x07\r\x17\x1d\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01K\x03\x05\x05\x07\x01\x0b\x03\x17\x03\x17\r\x07\x01M\x03)\x05\x13\x19\x05\x07\x01\x11\x03+\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\x07\x03\x1f\x05\x07\x01S\x03/\x03\x1d\x07\x06\x01\x03\x07\x07#\x0f!\x05\x07\x01\x11\x033\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\r\x03)\x05\x07\x01U\x035\x03\'\x07\x06\x01\x03\r\x07-\x11+\x0f\x04\x05\x05%/\x06\x03\x01\x05\x01\x00\xf2\tU\x1d\x03\x0f\x0b\t\t\x11#!+\x1b\x1f/!!)#\x1f\x19i?\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/hessenberg\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_zgehrd\x00', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_30["c64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_cgehrd'], + serialized_date=datetime.date(2024, 8, 30), + inputs=(), + expected_outputs=(array([[[ 5.2023945 -0.878671j , -2.8841915 -0.47488597j , + 1.3024182 +0.6651789j , 4.9291854 -1.9147056j ], + [ 6.3457894 +0.j , 1.6869383 -4.6557646j , + 0.88955224-1.7617276j , 2.9149916 +4.342665j ], + [-0.2465725 -0.5776757j , -5.3007755 +0.j , + -0.9786545 -0.0633831j , -1.3690261 -1.5921416j ], + [ 0.35462287+0.35993803j , -0.38403815-0.46558398j , + 2.8020499 +0.j , 0.5636822 -6.218306j ]], + + [[ 1.0687767 -3.88293j , -4.0144 -2.5885587j , + 5.3900986 -0.8850739j , 2.079677 +3.5515747j ], + [ 7.5675693 +0.j , 0.5971966 -3.6699948j , + 2.246994 -1.0858283j , -0.8870981 -0.022960603j], + [-0.2183232 +0.10552277j , 5.860886 +0.j , + -5.091036 +6.2841997j , 5.008773 +1.8765848j ], + [ 0.1378771 +0.427895j , 0.63263524-0.3470098j , + 6.4528017 +0.j , -4.233642 -0.84165764j ]]], + dtype=complex64), array([[1.0933675-0.3605358j , 1.1987956+0.5659744j , + 1.9999101-0.013409062j], + [1.4504763-0.44363326j , 1.3110259-0.07426627j , + 1.227255 +0.97383535j ]], dtype=complex64)), + mlir_module_text=r""" +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<[[[(5.20239449,-0.87867099), (-0.211780012,-0.923053801), (-5.25181627,1.90887547), (-1.61342144,-1.98000157)], [(-5.924900e-01,2.28788424), (-1.74142945,-3.25563216), (3.08765078,-3.25260139), (-3.35189271,-0.571629047)], [(3.032444,3.44394636), (1.22205484,0.808871626), (2.58686161,-7.47011566), (1.9139297,-2.57945323)], [(-3.28396916,-1.68601465), (2.62759161,-0.953538239), (-2.78763294,-0.0429570749), (0.426534384,-0.211706176)]], [[(1.06877673,-3.882930e+00), (-0.0192247611,5.96663713), (1.15329504,-5.0599103), (-1.76508892,-1.98541296)], [(-3.40901089,3.35722542), (-6.13531398,2.55851483), (-4.8095789,0.164206699), (-0.247624069,-3.13545418)], [(2.04217815,-1.89123917), (-1.18974173,-1.69466627), (-2.28673625,-0.487834573), (3.01541853,-1.85637176)], [(-2.9499588,-4.23393869), (8.44624137,5.57274485), (-1.09048736,2.4864223), (-0.305431545,-0.298133373)]]]> : tensor<2x4x4xcomplex> loc(#loc) + %c = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) + %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_2 = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_3 = stablehlo.constant dense<2> : tensor loc(#loc2) + %c_4 = stablehlo.constant dense<4288> : tensor loc(#loc2) + %0:4 = stablehlo.custom_call @lapack_cgehrd(%c, %c_0, %c_1, %c_2, %c_3, %c_4, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x3xcomplex>, tensor<2xi32>, tensor<4288xcomplex>) loc(#loc2) + %c_5 = stablehlo.constant dense<0> : tensor loc(#loc2) + %1 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) + %2 = stablehlo.compare EQ, %0#2, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) + %cst_6 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) + %4 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) + %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) + %6 = stablehlo.select %5, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) + %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) + %cst_7 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) + %8 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc2) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc2) + %10 = stablehlo.select %9, %0#1, %8 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc2) + return %6, %10 : tensor<2x4x4xcomplex>, tensor<2x3xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":697:13) +#loc2 = loc("jit(func)/jit(main)/hessenberg"(#loc1)) +""", + mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xef\xa19\x01W\x0f\x0b\x07\x0b\x13\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03K\x0f\x0b\x0b\x0b\x0bo/\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b&\x08\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\'\x0f\x17\x1bO\x1f\x0f\x0b\x0b//oO\x01\x05\x0b\x0f\x035\x0f\x1b\x07\x0b\x17\x07\x07\x0f\x07\x13\x17\x07\x17\x13\x13\x13\x13\x13\x13\x1b\x13\x1b\x13\x17\x17\x13\x02\xae\x0b\x1d-/\x05\x15\x1f\x05\x17\x03\x03\x03w\x03\x03\x07\x93\x11\x03\x05\x05\x19\x03\x03\x07\x99\x03\x03\x03\x9b\x03\t\x17\x19\x1b\r\x1d\r\x0f\x1f\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b#Y%e\'g\x0fq)s\x05#\x05%\x05\'\x05)\x03\x03\x03u\x05+\x171\xe6\n\x1b\x05-\x03\x03\x03y\x03\x03\x03{\x03\x03\x03}\x03\x11;\x7f=\x81?\x83AYC\x85E\x87G\x89I\x8d\x05/\x051\x053\x055\x057\x059\x05;\x05=\x03\x03\x03\x91\x03\x05O\x95Q\x97\x05?\x05A\x03\x03\x07\x9d\x03\x03\x07\x9f\x1f\x1f\x01\x03\x01\x1dC\x1dE\x1dG\x1f!1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f%\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x19\x03\x05im\r\x05[k]_\x1dI\r\x05[o]_\x1dK\x1dM\x1dO\x1f\x07\x02\x04\x04z\xa6@\x95\xf0`\xbf\xdc\xdcX\xbeAMl\xbf\xe1\x0e\xa8\xc0\x08V\xf4?\x98\x84\xce\xbf\xb1p\xfd\xbfm\xad\x17\xbf\xb2l\x12@)\xe7\xde\xbfG\\P\xc0\x12\x9cE@\x9f*P\xc0i\x85V\xc0HV\x12\xbf\x90\x13B@\x9ei\\@Kl\x9c?6\x12O?$\x8f%@0\x0b\xef\xc0\xa6\xfb\xf4?\xc3\x15%\xc0\x8d,R\xc0T\xcf\xd7\xbfv*(@\x15\x1bt\xbf\x94h2\xc0\xc2\xf3/\xbd\xb7b\xda>\x81\xc9X\xbe\xad\xcd\x88?\xed\x81x\xc0?}\x9d\xbc\xb1\xee\xbe@,\x9f\x93?\xc9\xea\xa1\xc0o\xee\xe1\xbf\x03"\xfe\xbf<-Z\xc0\xc8\xdcV@~T\xc4\xc0\xb5\xbe#@\x12\xe8\x99\xc0\xcd%(>*\x91}\xbeH\xabH\xc0\x0c\xb3\x02@ \x14\xf2\xbfuI\x98\xbf\xd3\xea\xd8\xbf\xe3Y\x12\xc0t\xc5\xf9\xbe\x9e\xfc@@\x97\x9d\xed\xbf \xcc<\xc0m|\x87\xc0\xce#\x07A\xedS\xb2@\x17\x95\x8b\xbf\x8b!\x1f@\x86a\x9c\xbe\xf0\xa4\x98\xbe\x1f\x05\t\x04\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\xc0\x10\x00\x00\x0b\x05\x1dQ\x1dS\x05\x01\x03\x0fWWWWWWa\x03\x03\x8b\x15\x03\x01\x19\x01\x03\ta\x8fcc\x1f#!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x1f\'\x01\t\x07\x07\x01\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x13\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f11\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x15)\x07\t\x11\x11\x0b\x01\x03\x1b)\x05\t\r\x0b\x13\x1d)\x01\x0b\x1b)\x03\t\x15\x11\x01\x05\x07\r\t)\x03\x02\x86\x0b)\x03\x01\x0f)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x11)\x03\t\t)\x07\t\x05\x05\t)\x03\x05\x11)\x07\t\x11\x11\t)\x03\r\x11)\x05\t\x05\t)\x05\t\r\t)\x03\t\x11\x04\xde\x02\x05\x01\x11\x05\x15\x07\x03\x01\x05\t\x11\x05!\x07\x031Y\x03\x03\x05+\x03\x07\x03\x03\x01\t\x03\x05\x03\x03\x013\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x015\x03\x05\x03\x03\x017\x03\x05\x0b\x07\x019\t\x07\r\x17\x1d\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01K\x03\x05\x05\x07\x01\x0b\x03\x17\x03\x17\r\x07\x01M\x03)\x05\x13\x19\x05\x07\x01\x11\x03+\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\x07\x03\x1f\x05\x07\x01S\x03/\x03\x1d\x07\x06\x01\x03\x07\x07#\x0f!\x05\x07\x01\x11\x033\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\r\x03)\x05\x07\x01U\x035\x03\'\x07\x06\x01\x03\r\x07-\x11+\x0f\x04\x05\x05%/\x06\x03\x01\x05\x01\x00\xf2\tU\x1d\x03\x0f\x0b\t\t\x11#!+\x1b\x1f/!!)#\x1f\x19i?\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/hessenberg\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_cgehrd\x00', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_30["f32"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_sgehrd'], + serialized_date=datetime.date(2024, 8, 30), + inputs=(), + expected_outputs=(array([[[-3.5237675 , -6.1161256 , -0.549011 , -4.7706876 ], + [ 5.8401766 , 3.424213 , 0.3059119 , 2.3492367 ], + [ 0.63135445 , 2.7238827 , -0.106214404, -0.82470125 ], + [-0.27146497 , 0.09917235 , 0.2545611 , -0.5113605 ]], + + [[ 4.297168 , -1.8758869 , 0.33528137 , 5.867136 ], + [-7.129698 , -3.3118155 , -1.3492918 , -2.8959117 ], + [-0.7266852 , -3.506432 , 4.77164 , -4.0780373 ], + [ 0.14084078 , 0.3389384 , 2.3910007 , -0.79807365 ]]], + dtype=float32), array([[1.3584172, 1.9805213, 0. ], + [1.2920669, 1.7939165, 0. ]], dtype=float32)), + mlir_module_text=r""" +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x4x4xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<[[[-3.52376747, -0.758410036, 4.85795927, -6.0243597], [-2.09321976, -1.27957773, -0.956288218, -1.11928439], [-5.00878525, 0.51314038, 3.53047514, -2.91282868], [2.15363932, 0.635739565, -0.21264787, 0.555740714]], [[4.29716778, -3.86209464, -2.39021468, 4.17441607], [2.08234859, -1.03958249, 4.09025383, 5.22586823], [-6.69425774, 3.43749118, -0.691099107, 1.59547663], [1.29743183, -2.00156212, 3.08750296, 2.39243269]]]> : tensor<2x4x4xf32> loc(#loc) + %c = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) + %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_2 = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_3 = stablehlo.constant dense<2> : tensor loc(#loc2) + %c_4 = stablehlo.constant dense<4288> : tensor loc(#loc2) + %0:4 = stablehlo.custom_call @lapack_sgehrd(%c, %c_0, %c_1, %c_2, %c_3, %c_4, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf32>) -> (tensor<2x4x4xf32>, tensor<2x3xf32>, tensor<2xi32>, tensor<4288xf32>) loc(#loc2) + %c_5 = stablehlo.constant dense<0> : tensor loc(#loc2) + %1 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) + %2 = stablehlo.compare EQ, %0#2, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) + %cst_6 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) + %4 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor) -> tensor<2x4x4xf32> loc(#loc2) + %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) + %6 = stablehlo.select %5, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc2) + %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) + %cst_7 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) + %8 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor) -> tensor<2x3xf32> loc(#loc2) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc2) + %10 = stablehlo.select %9, %0#1, %8 : tensor<2x3xi1>, tensor<2x3xf32> loc(#loc2) + return %6, %10 : tensor<2x4x4xf32>, tensor<2x3xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":697:13) +#loc2 = loc("jit(func)/jit(main)/hessenberg"(#loc1)) +""", + mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xed\xa17\x01W\x0f\x0b\x07\x0b\x13\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03K\x0f\x0b\x0b\x0b\x0bo/\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b&\x04\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\'\x0f\x17\x1bO\x1f\x0f\x0b\x0b/\x1foO\x01\x05\x0b\x0f\x033\x0f\x1b\x07\x07\x17\x07\x07\x0f\x07\x13\x17\x17\x13\x13\x13\x13\x13\x13\x1b\x13\x1b\x13\x17\x17\x13\x02\x96\t\x1d-/\x05\x15\x1f\x05\x17\x03\x03\x03w\x03\x03\x07\x93\x11\x03\x05\x05\x19\x03\x03\x07\x99\x03\x03\x03\x9b\x03\t\x17\x19\x1b\r\x1d\r\x0f\x1f\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b#Y%e\'g\x0fq)s\x05#\x05%\x05\'\x05)\x03\x03\x03u\x05+\x171\xe6\n\x1b\x05-\x03\x03\x03y\x03\x03\x03{\x03\x03\x03}\x03\x11;\x7f=\x81?\x83AYC\x85E\x87G\x89I\x8d\x05/\x051\x053\x055\x057\x059\x05;\x05=\x03\x03\x03\x91\x03\x05O\x95Q\x97\x05?\x05A\x03\x03\x07\x9d\x03\x03\x07\x9f\x1f\x1d\x01\x03\x01\x1dC\x1dE\x1dG\x1f\x1f1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x19\x03\x05im\r\x05[k]_\x1dI\r\x05[o]_\x1dK\x1dM\x1dO\x1f\x07\x02\x02h\x85a\xc0)\'B\xbfgt\x9b@\x8e\xc7\xc0\xc0P\xf7\x05\xc04\xc9\xa3\xbfN\xcft\xbf\xb6D\x8f\xbf\xf8G\xa0\xc0+]\x03?N\xf3a@\xc9k:\xc0:\xd5\t@\xd4\xbf"?]\xc0Y\xbe\x06E\x0e?f\x82\x89@\x8f,w\xc0G\xf9\x18\xc0\xd1\x94\x85@3E\x05@\n\x11\x85\xbf\\\xe3\x82@P:\xa7@\\7\xd6\xc0\xdb\xff[@\xdf\xeb0\xbf\x948\xcc??\x12\xa6?\x98\x19\x00\xc0\xa6\x99E@\x9e\x1d\x19@\x1f\x05\t\x04\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\xc0\x10\x00\x00\x0b\x05\x1dQ\x1dS\x05\x01\x03\x0fWWWWWWa\x03\x03\x8b\x15\x03\x01\x19\x01\x03\ta\x8fcc\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x1f%\x01\t\x07\x07\x01\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x13\t\x00\x00\xc0\x7f\x1f/1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x15)\x07\t\x11\x11\x0b\x01\t)\x05\t\r\x0b\x13\x1d)\x01\x0b\x1b)\x03\t\x15\x11\x01\x05\x07\r)\x03\x02\x86\x0b)\x03\x01\x0f)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x11)\x03\t\t)\x07\t\x05\x05\t)\x03\x05\x11)\x07\t\x11\x11\t)\x03\r\x11)\x05\t\x05\t)\x05\t\r\t)\x03\t\x11\x04\xde\x02\x05\x01\x11\x05\x15\x07\x03\x01\x05\t\x11\x05!\x07\x031Y\x03\x03\x05+\x03\x07\x03\x03\x01\t\x03\x05\x03\x03\x013\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x015\x03\x05\x03\x03\x017\x03\x05\x0b\x07\x019\t\x07\r\x17\x1b\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01K\x03\x05\x05\x07\x01\x0b\x03\x17\x03\x17\r\x07\x01M\x03\'\x05\x13\x19\x05\x07\x01\x11\x03)\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\x07\x03\x1f\x05\x07\x01S\x03-\x03\x1d\x07\x06\x01\x03\x07\x07#\x0f!\x05\x07\x01\x11\x031\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\r\x03)\x05\x07\x01U\x033\x03\'\x07\x06\x01\x03\r\x07-\x11+\x0f\x04\x05\x05%/\x06\x03\x01\x05\x01\x00\xf2\tU\x1d\x03\x0f\x0b\t\t\x11#!+\x1b\x1f/!!)#\x1f\x19i?\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/hessenberg\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_sgehrd\x00', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_30["f64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_dgehrd'], + serialized_date=datetime.date(2024, 8, 30), + inputs=(), + expected_outputs=(array([[[ 0.9307390587491866 , -0.35692982324474015 , + -0.1271353200176119 , -0.43952156917870067 ], + [ 2.2633695323673964 , 0.9965090965971986 , + -1.3244131008423046 , 1.7324542351344163 ], + [ 0.24558316247256504 , 2.922776762811796 , + 3.630059093036474 , 1.4330664619737252 ], + [-0.2856727718012896 , -0.4601276537179077 , + -2.8602148466873802 , 1.9928744545245372 ]], + + [[-0.5351339571818844 , 5.753313169426148 , + 0.1385440281649789 , 2.8445493054193807 ], + [ 4.676815781213274 , 2.920688567170204 , + -2.610159425457712 , 4.0359806870679655 ], + [-0.16963242599901043 , -2.342935131066633 , + 4.179999589709703 , -0.6810604472011716 ], + [ 0.030645999613174775, -0.2271804227402005 , + -2.2755242550977153 , 0.7136684502626782 ]]]), array([[1.751436143556826 , 1.6505497938190505, 0. ], + [1.9422862513069978, 1.9018440331997255, 0. ]])), + mlir_module_text=r""" +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x4x4xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<[[[0.93073905874918661, 0.18483901505653183, -0.11804347408930886, -0.53725392025434981], [-1.700777672846173, 1.3531570270421245, -2.4375034855727518, 2.2945174202226699], [-0.97352780716312858, -0.8319788592736328, 2.4986640885328582, -2.8118637941861766], [1.1324489199416958, -1.9301638714393787, 1.5523821278819048, 2.7676215285832253]], [[-0.53513395718188439, -5.2137633671981938, 2.9644475919777618, 2.2891023676266191], [-4.4068992105328642, 1.2751848926168665, -2.8947257279736456, -2.6817410994805888], [1.5408926111334784, -0.85423691880254915, 6.4217874587762065, -0.43997818045540715], [-0.27837952612324207, 1.1509460853774549, -0.21686805683301608, 0.11738425574951133]]]> : tensor<2x4x4xf64> loc(#loc) + %c = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) + %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_2 = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_3 = stablehlo.constant dense<2> : tensor loc(#loc2) + %c_4 = stablehlo.constant dense<4288> : tensor loc(#loc2) + %0:4 = stablehlo.custom_call @lapack_dgehrd(%c, %c_0, %c_1, %c_2, %c_3, %c_4, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf64>) -> (tensor<2x4x4xf64>, tensor<2x3xf64>, tensor<2xi32>, tensor<4288xf64>) loc(#loc2) + %c_5 = stablehlo.constant dense<0> : tensor loc(#loc2) + %1 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) + %2 = stablehlo.compare EQ, %0#2, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) + %cst_6 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) + %4 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor) -> tensor<2x4x4xf64> loc(#loc2) + %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) + %6 = stablehlo.select %5, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc2) + %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) + %cst_7 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) + %8 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor) -> tensor<2x3xf64> loc(#loc2) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc2) + %10 = stablehlo.select %9, %0#1, %8 : tensor<2x3xi1>, tensor<2x3xf64> loc(#loc2) + return %6, %10 : tensor<2x4x4xf64>, tensor<2x3xf64> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":697:13) +#loc2 = loc("jit(func)/jit(main)/hessenberg"(#loc1)) +""", + mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xed\xa17\x01W\x0f\x0b\x07\x0b\x13\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03K\x0f\x0b\x0b\x0b\x0bo/\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b&\x08\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\'\x0f\x17\x1bO\x1f\x0f\x0b\x0b//oO\x01\x05\x0b\x0f\x033\x0f\x1b\x07\x07\x17\x07\x07\x0f\x07\x13\x17\x17\x13\x13\x13\x13\x13\x13\x1b\x13\x1b\x13\x17\x17\x13\x02\xa6\x0b\x1d-/\x05\x15\x1f\x05\x17\x03\x03\x03w\x03\x03\x07\x93\x11\x03\x05\x05\x19\x03\x03\x07\x99\x03\x03\x03\x9b\x03\t\x17\x19\x1b\r\x1d\r\x0f\x1f\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b#Y%e\'g\x0fq)s\x05#\x05%\x05\'\x05)\x03\x03\x03u\x05+\x171\xe6\n\x1b\x05-\x03\x03\x03y\x03\x03\x03{\x03\x03\x03}\x03\x11;\x7f=\x81?\x83AYC\x85E\x87G\x89I\x8d\x05/\x051\x053\x055\x057\x059\x05;\x05=\x03\x03\x03\x91\x03\x05O\x95Q\x97\x05?\x05A\x03\x03\x07\x9d\x03\x03\x07\x9f\x1f\x1d\x01\x03\x01\x1dC\x1dE\x1dG\x1f\x1f1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x19\x03\x05im\r\x05[k]_\x1dI\r\x05[o]_\x1dK\x1dM\x1dO\x1f\x07\x02\x04\xa6\x00NG\x9d\xc8\xed?\xf2\xa8X\n\xce\xa8\xc7?#E\xb8\xdc\x188\xbe\xbf\xb8|$"/1\xe1\xbf\xc4B*\xa6b6\xfb\xbf\xe8\xf9\x97\xfb\x87\xa6\xf5?)^\xd3\xd3\x01\x80\x03\xc0T\xab\xff\xf2+[\x02@4d\xb0\xc9#\'\xef\xbf~e\xf1 \x92\x9f\xea\xbf\x96\x81\xff\x98C\xfd\x03@W\xb0\xe6q\xb2~\x06\xc0F\xa48\xc2\x82\x1e\xf2?\xcc\x0b\xfc\x82\xf3\xe1\xfe\xbf\xdc\\b\xa4\x8e\xd6\xf8?\x8c\xc3\x87\xc1\x16$\x06@\x83h\xa2?\xd1\x1f\xe1\xbf\xdc\xcb\xbc\xc8\xe4\xda\x14\xc0\xe6\x00\x92L0\xb7\x07@Q8\xf1\xe6\x14P\x02@\t\x07\xc8/\xaa\xa0\x11\xc0\x8eH"F(g\xf4?\xf5Jd\xf6e(\x07\xc0\x9e\xddt\xad4t\x05\xc0\x1cv\xb7\x02\x7f\xa7\xf8?B^\xa9\xa9\xe8U\xeb\xbf\x1e:5\r\xe9\xaf\x19@\xa2\x9c\x00>\x9a(\xdc\xbf\xc1\xd1$\\\xf8\xd0\xd1\xbf}|BqFj\xf2?6\x8b\xd2\x1dU\xc2\xcb\xbfdk\x82\x03\xe5\x0c\xbe?\x1f\x05\t\x04\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\xc0\x10\x00\x00\x0b\x05\x1dQ\x1dS\x05\x01\x03\x0fWWWWWWa\x03\x03\x8b\x15\x03\x01\x19\x01\x03\ta\x8fcc\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x1f%\x01\t\x07\x07\x01\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x13\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f/1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x15)\x07\t\x11\x11\x0b\x01\x0b)\x05\t\r\x0b\x13\x1d)\x01\x0b\x1b)\x03\t\x15\x11\x01\x05\x07\r)\x03\x02\x86\x0b)\x03\x01\x0f)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x11)\x03\t\t)\x07\t\x05\x05\t)\x03\x05\x11)\x07\t\x11\x11\t)\x03\r\x11)\x05\t\x05\t)\x05\t\r\t)\x03\t\x11\x04\xde\x02\x05\x01\x11\x05\x15\x07\x03\x01\x05\t\x11\x05!\x07\x031Y\x03\x03\x05+\x03\x07\x03\x03\x01\t\x03\x05\x03\x03\x013\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x015\x03\x05\x03\x03\x017\x03\x05\x0b\x07\x019\t\x07\r\x17\x1b\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01K\x03\x05\x05\x07\x01\x0b\x03\x17\x03\x17\r\x07\x01M\x03\'\x05\x13\x19\x05\x07\x01\x11\x03)\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\x07\x03\x1f\x05\x07\x01S\x03-\x03\x1d\x07\x06\x01\x03\x07\x07#\x0f!\x05\x07\x01\x11\x031\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\r\x03)\x05\x07\x01U\x033\x03\'\x07\x06\x01\x03\r\x07-\x11+\x0f\x04\x05\x05%/\x06\x03\x01\x05\x01\x00\xf2\tU\x1d\x03\x0f\x0b\t\t\x11#!+\x1b\x1f/!!)#\x1f\x19i?\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/hessenberg\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_dgehrd\x00', + xla_call_module_version=9, + nr_devices=1, +) # End paste + +data_2024_08_31 = {} + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_31["c128"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_zgehrd_ffi'], + serialized_date=datetime.date(2024, 8, 30), + inputs=(), + expected_outputs=(array([[[ 0.7137638961069523 +2.4533812415320035e+00j, + -0.3272236912989258 -3.2003874808591863e+00j, + -3.065817294924296 +1.6978219378771007e+00j, + -3.3971558164664 +2.6931967836060400e-01j], + [ 6.346214936866542 +0.0000000000000000e+00j, + 2.083218259144673 -1.2191838498692813e+00j, + 1.9552582313969427 -3.3216313521481879e+00j, + 2.7451664155727293 +2.5460553490974451e+00j], + [-0.16133388943502391 +3.6906265775683444e-01j, + -4.698636849217318 +0.0000000000000000e+00j, + 2.5396292124414077 -3.3038474840573420e+00j, + 2.5410992366186456 +4.1958389320867528e-01j], + [ 0.47396123039280513 +3.9524384493417053e-03j, + 0.058880409351504966-7.8934332132630333e-02j, + 0.9469634796174572 +0.0000000000000000e+00j, + -3.130422531669044 -8.8070401977461810e-01j]], + + [[-6.7065483048969465 -4.1981401054281309e-01j, + -0.21813268822330256 -3.8602920478381799e+00j, + -0.8248337528620167 -2.9073223456990824e+00j, + -3.597231249446879 +2.7626541679004930e+00j], + [-6.812126638479044 +0.0000000000000000e+00j, + -0.20651586628458585 -1.0948249928988512e+00j, + -1.6675586608354327 +4.2553627621795744e+00j, + -2.410110723267707 +3.6065122124698634e-01j], + [ 0.038235817369200516-3.7823713529009173e-01j, + -8.508141062606947 +0.0000000000000000e+00j, + 4.260708077719245 -6.8052584397204630e-02j, + 5.345997177836541 -1.1955161503390279e+00j], + [-0.18541509608158574 -1.2016051097247168e-01j, + -0.02698777746917469 -4.4847463691672246e-01j, + 6.149305574585603 +0.0000000000000000e+00j, + -2.483131585236393 +2.8524912589603817e+00j]]]), array([[1.2286220194325557+0.5121060656500841j , + 1.9529937219183482-0.23299856112387676j, + 1.5940499664125072-0.8044281430962614j ], + [1.6682114302246909-0.11372755955977935j, + 1.4075913155446236-0.6008708461880701j , + 1.5086928152468893-0.8609480935086589j ]])), + mlir_module_text=r""" +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<[[[(0.71376389610695234,2.4533812415320035), (-1.0686093138739379,-1.885041510645256), (3.2629529488994033,-0.87160041258342402), (2.4332168907311504,3.4960248990882183)], [(-1.450884474619478,-3.249935163088522), (0.53920035905924757,-5.0056840575116066), (0.13157186736298554,2.5015499854549939), (-1.2451270607408882,0.24345856951924827)], [(2.457366083193417,-2.3532935513245605), (-0.37595429769485644,1.5729223427874068), (3.5877693970448052,-0.30904304334212157), (-1.685615117470264,2.6148811836470265)], [(-3.6826776618664727,-1.5711608241015744), (-0.12407609317204518,-4.7137561145212281), (1.3298255603911306,-1.6739172003954141), (-2.6345448161870149,-0.089008252847513236)]], [[(-6.7065483048969465,-0.41981401054281309), (-2.1586544949255457,0.34815132010709054), (-5.1462488701272413,3.440817752555807), (1.0301804086076078,-0.6994760434270566)], [(4.551940883969797,-0.77472653800638502), (4.4485186470774796,-0.0024458890677252756), (0.66610302132250898,2.5976571401862039), (-5.0693248202533674,-5.7405538897950699)], [(0.14148406399087146,-4.3279346473525058), (-2.353557113110897,2.0880432773400326), (-3.2524452107293618,-0.42398740171508631), (3.7200566224095519,-0.56951559566037058)], [(-2.2001612082232613,-1.2218661647417151), (0.72437359623190833,8.6381970213061301), (0.72314820631775734,0.058458198280771749), (0.37498718985014962,2.1160469724471378)]]]> : tensor<2x4x4xcomplex> loc(#loc) + %0:3 = stablehlo.custom_call @lapack_zgehrd_ffi(%cst) {mhlo.backend_config = {high = 4 : i32, low = 1 : i32}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x3xcomplex>, tensor<2xi32>) loc(#loc2) + %c = stablehlo.constant dense<0> : tensor loc(#loc2) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) + %2 = stablehlo.compare EQ, %0#2, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) + %cst_0 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) + %4 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) + %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) + %6 = stablehlo.select %5, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) + %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) + %cst_1 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) + %8 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc2) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc2) + %10 = stablehlo.select %9, %0#1, %8 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc2) + return %6, %10 : tensor<2x4x4xcomplex>, tensor<2x3xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":697:13) +#loc2 = loc("jit(func)/jit(main)/hessenberg"(#loc1)) +""", + mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xe5\x9b5\x01Q\x0f\x07\x0b\x0b\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0b\x13\x1b\x0b\x0b\x13\x13\x03K\x0b\x0b\x0b\x0bo\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b&\x10\x0b\x0b\x0b\x0b\x1b\x0b\x0f\x0b\x0f\x0f\x0f\x17\x17O/\x1f\x0f\x0b\x0b/OoO\x01\x05\x0b\x0f\x031\x1b\x07\x17\x07\x07\x0b\x07\x0f\x13\x0f\x17\x07\x13\x13\x13\x13\x13\x1b\x13\x1b\x13\x17\x17\x13\x02"\x0f\x1d?A\x1f\x05\x15\x05\x17\x03\x03\x05\x8d\x11\x03\x05\x05\x19\x03\x03\x05\x93\x03\x03\x07\x95\x03\t\x15\x17\x19\x0b\x1b\x0b\r\x1d\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b!Q#[%]\rg\'i\x05#\x05%\x05\'\x05)\x03\x03\x07k\x03\x13-m/o1q3Q5s7u9\x7f;\x81=\x85\x05+\x05-\x05/\x051\x053\x055\x057\x059\x05;\x05=\x17C\xe6\n\x1b\x05?\x03\x03\x07\x8b\x03\x05I\x8fK\x91\x05A\x05C\x03\x03\x05\x97\x03\x03\x05\x99\x03\x01\x1dE\x1dG\x1dI\x1f\x1d1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00#\x19\x03\x05_c\r\x05SaUW\x1dK\r\x05SeUW\x1dM\x1dO\x1dQ\x1f\x05\x02\x08p\t\xdba\'\xd7\xe6?\xa8\xff\'X\x86\xa0\x03@\x0c\xa2t\x14\x06\x19\xf1\xbfT.}I!)\xfe\xbf\x0fG_\x13\x87\x1a\n@\xae:g\x8c&\xe4\xeb\xbf\xeb\x1e\xcej:w\x03@N\xaf\xfc\xe6\xdb\xf7\x0b@\x9f<\x8c\xa3\xd26\xf7\xbf^\xaf\xbc\x01\xde\xff\t\xc0b\xd4\x84\x1c!A\xe1?\xd6{\xa4\n\xd2\x05\x14\xc0\xf0\xe6\xb2\xd1X\xd7\xc0?2\xb5\x86\xa3,\x03\x04@\x91\xf2SZ\n\xec\xf3\xbf\x04\x10\x02\x81\xa6)\xcf?8\xec\x8c\x8c\xaf\xa8\x03@\r\x9d\xc6\x91\x8b\xd3\x02\xc0\xb0\xf6X\x9d\xa2\x0f\xd8\xbf\xbd\xb6V\x9e\xb0*\xf9?7-\x0fq\xc0\xb3\x0c@{|\ry\\\xc7\xd3\xbf\x04\xd9\xb2\x8eG\xf8\xfa\xbf\x9b\x84u\xd3F\xeb\x04@\xf4h\xbb\xb4\x1fv\r\xc0\xdc\\D\x88y#\xf9\xbf\x9a\xaecjs\xc3\xbf\xbf<\xc1\x04\xe2\xe2\xda\x12\xc0\x89<\xb4*\xf7F\xf5?\x1b\x90\xfef]\xc8\xfa\xbf\xdc\xf4\x8a;\x8c\x13\x05\xc0\xf8\xdd\r\xaf>\xc9\xb6\xbfvN\x1af\x81\xd3\x1a\xc0Z\xc6k\x95;\xde\xda\xbf\x87\x8c\xd8\xa5\xecD\x01\xc0\xdd\xd3zy\x1cH\xd6?\x04\x18\x89C\xc2\x95\x14\xc0\x8c\xc95u\xcb\x86\x0b@\x881\xbfs\x9e{\xf0?\x92Y[\x95\x1bb\xe6\xbf\x06\xe7\xb7\xfd/5\x12@L\x95\x02O\x8f\xca\xe8\xbf2`\xe3xH\xcb\x11@>\xda\xc6\xb1f\td\xbfZ\x1a\x8bH\xb7P\xe5?\xa8\x90zw\x00\xc8\x04@<(\xef\x15\xfdF\x14\xc0\xb4aF\xc2S\xf6\x16\xc0\xc1{\xdfY&\x1c\xc2?\xcfj\xa6\x19\xceO\x11\xc0\xc4\xa2p\xc0\x15\xd4\x02\xc0\xfcv\xa6\x08P\xb4\x00@^\xea\xa0\xfe\x01\x05\n\xc0^\x11\x12\x0e\x9c"\xdb\xbfR#\xe4\x0b\xad\xc2\r@F\x8b=\xc5x9\xe2\xbfZ\xf9\x99\x1e\xee\x99\x01\xc0My\x1a\x89\xc3\x8c\xf3\xbf\xd1\xdc<\x89\x11.\xe7?2\xd4\x8d\xc2\xc1F!@mw\t\xb5\x07$\xe7?G\x16\x99\xa3;\xee\xad?M\xd24E\xca\xff\xd7?\xa2\xae\xfb\x08\xaa\xed\x00@\x0b\x03\x1dS\x1dU\x05\x01\r\x05wy{}\x1dW\x13\x0b\x11\x1dY\x13\x0b\x05\x03\x03Y\x03\x03\x83\x15\x03\x01\x01\x01\x03\x07Y\x87\x89\x1f\x1f!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f!\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x17\t\x00\x00\x00\x00\x1f#\x01\t\x07\x07\x01\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x13!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f-1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x0f\x01)\x05\t\r\x0f\x1b\x1d\x03\x1b\x13)\x01\x0f)\x03\t\x0b)\x01\x0b\x11\x01\x05\x05\t\x0b)\x03\r\x11)\x03\t\x11)\x03\x05\x11)\x03\x01\r)\x03\t\x07)\x07\t\x05\x05\x07)\x03\x05\r)\x07\t\x11\x11\x07)\x03\r\r)\x05\t\x05\x07)\x05\t\r\x07)\x03\t\r\x042\x02\x05\x01\x11\x03\x13\x07\x03\x01\x05\t\x11\x03\x1f\x07\x03#A\x05\x03\x03)\x03\x05\x0b\x07\x01+\x07\x05\t\x15\x03\x01\x05\x03\x01E\x03\x17\x03\x07\x01\t\x03\x15\x03\t\r\x07\x01G\x03%\x05\x07\x0b\x03\x07\x01\x0f\x03\'\x03\r\x05\x03\x01\x11\x03\x13\x03\x07\x01\t\x03\x05\x03\x11\x03\x07\x01M\x03+\x03\x0f\x07\x06\x01\x03\x05\x07\x15\x03\x13\x03\x07\x01\x0f\x03/\x03\r\x05\x03\x01\x11\x03\x13\x03\x07\x01\t\x03\t\x03\x1b\x03\x07\x01O\x031\x03\x19\x07\x06\x01\x03\t\x07\x1f\x05\x1d\x0f\x04\x03\x05\x17!\x06\x03\x01\x05\x01\x00\x82\n[\t\x0b%\x03\x0f\x0b\t\t\x11#!+\x1bi?\x1f/!)!)#\x1f\x19\x1f\x15\x1d\x15\x13%)9\x13\r+\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00broadcast_dimensions\x00value\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit(func)/jit(main)/hessenberg\x00third_party/py/jax/tests/export_back_compat_test.py\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_zgehrd_ffi\x00high\x00low\x00', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_31["c64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_cgehrd_ffi'], + serialized_date=datetime.date(2024, 8, 30), + inputs=(), + expected_outputs=(array([[[ 5.2023945 -0.878671j , -2.8841915 -0.47488597j , + 1.3024182 +0.6651789j , 4.9291854 -1.9147056j ], + [ 6.3457894 +0.j , 1.6869383 -4.6557646j , + 0.88955224-1.7617276j , 2.9149916 +4.342665j ], + [-0.2465725 -0.5776757j , -5.3007755 +0.j , + -0.9786545 -0.0633831j , -1.3690261 -1.5921416j ], + [ 0.35462287+0.35993803j , -0.38403815-0.46558398j , + 2.8020499 +0.j , 0.5636822 -6.218306j ]], + + [[ 1.0687767 -3.88293j , -4.0144 -2.5885587j , + 5.3900986 -0.8850739j , 2.079677 +3.5515747j ], + [ 7.5675693 +0.j , 0.5971966 -3.6699948j , + 2.246994 -1.0858283j , -0.8870981 -0.022960603j], + [-0.2183232 +0.10552277j , 5.860886 +0.j , + -5.091036 +6.2841997j , 5.008773 +1.8765848j ], + [ 0.1378771 +0.427895j , 0.63263524-0.3470098j , + 6.4528017 +0.j , -4.233642 -0.84165764j ]]], + dtype=complex64), array([[1.0933675-0.3605358j , 1.1987956+0.5659744j , + 1.9999101-0.013409062j], + [1.4504763-0.44363326j , 1.3110259-0.07426627j , + 1.227255 +0.97383535j ]], dtype=complex64)), + mlir_module_text=r""" +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<[[[(5.20239449,-0.87867099), (-0.211780012,-0.923053801), (-5.25181627,1.90887547), (-1.61342144,-1.98000157)], [(-5.924900e-01,2.28788424), (-1.74142945,-3.25563216), (3.08765078,-3.25260139), (-3.35189271,-0.571629047)], [(3.032444,3.44394636), (1.22205484,0.808871626), (2.58686161,-7.47011566), (1.9139297,-2.57945323)], [(-3.28396916,-1.68601465), (2.62759161,-0.953538239), (-2.78763294,-0.0429570749), (0.426534384,-0.211706176)]], [[(1.06877673,-3.882930e+00), (-0.0192247611,5.96663713), (1.15329504,-5.0599103), (-1.76508892,-1.98541296)], [(-3.40901089,3.35722542), (-6.13531398,2.55851483), (-4.8095789,0.164206699), (-0.247624069,-3.13545418)], [(2.04217815,-1.89123917), (-1.18974173,-1.69466627), (-2.28673625,-0.487834573), (3.01541853,-1.85637176)], [(-2.9499588,-4.23393869), (8.44624137,5.57274485), (-1.09048736,2.4864223), (-0.305431545,-0.298133373)]]]> : tensor<2x4x4xcomplex> loc(#loc) + %0:3 = stablehlo.custom_call @lapack_cgehrd_ffi(%cst) {mhlo.backend_config = {high = 4 : i32, low = 1 : i32}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x3xcomplex>, tensor<2xi32>) loc(#loc2) + %c = stablehlo.constant dense<0> : tensor loc(#loc2) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) + %2 = stablehlo.compare EQ, %0#2, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) + %cst_0 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) + %4 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) + %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) + %6 = stablehlo.select %5, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) + %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) + %cst_1 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) + %8 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc2) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc2) + %10 = stablehlo.select %9, %0#1, %8 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc2) + return %6, %10 : tensor<2x4x4xcomplex>, tensor<2x3xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":697:13) +#loc2 = loc("jit(func)/jit(main)/hessenberg"(#loc1)) +""", + mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xe5\x9b5\x01Q\x0f\x07\x0b\x0b\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0b\x13\x1b\x0b\x0b\x13\x13\x03K\x0b\x0b\x0b\x0bo\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b&\x08\x0b\x0b\x0b\x0b\x1b\x0b\x0f\x0b\x0f\x0f\x0f\x17\x17O/\x1f\x0f\x0b\x0b//oO\x01\x05\x0b\x0f\x031\x1b\x07\x17\x07\x07\x0b\x07\x0f\x13\x0f\x17\x07\x13\x13\x13\x13\x13\x1b\x13\x1b\x13\x17\x17\x13\x02\x02\x0b\x1d?A\x1f\x05\x15\x05\x17\x03\x03\x05\x8d\x11\x03\x05\x05\x19\x03\x03\x05\x93\x03\x03\x07\x95\x03\t\x15\x17\x19\x0b\x1b\x0b\r\x1d\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b!Q#[%]\rg\'i\x05#\x05%\x05\'\x05)\x03\x03\x07k\x03\x13-m/o1q3Q5s7u9\x7f;\x81=\x85\x05+\x05-\x05/\x051\x053\x055\x057\x059\x05;\x05=\x17C\xe6\n\x1b\x05?\x03\x03\x07\x8b\x03\x05I\x8fK\x91\x05A\x05C\x03\x03\x05\x97\x03\x03\x05\x99\x03\x01\x1dE\x1dG\x1dI\x1f\x1d1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00#\x19\x03\x05_c\r\x05SaUW\x1dK\r\x05SeUW\x1dM\x1dO\x1dQ\x1f\x05\x02\x04\x04z\xa6@\x95\xf0`\xbf\xdc\xdcX\xbeAMl\xbf\xe1\x0e\xa8\xc0\x08V\xf4?\x98\x84\xce\xbf\xb1p\xfd\xbfm\xad\x17\xbf\xb2l\x12@)\xe7\xde\xbfG\\P\xc0\x12\x9cE@\x9f*P\xc0i\x85V\xc0HV\x12\xbf\x90\x13B@\x9ei\\@Kl\x9c?6\x12O?$\x8f%@0\x0b\xef\xc0\xa6\xfb\xf4?\xc3\x15%\xc0\x8d,R\xc0T\xcf\xd7\xbfv*(@\x15\x1bt\xbf\x94h2\xc0\xc2\xf3/\xbd\xb7b\xda>\x81\xc9X\xbe\xad\xcd\x88?\xed\x81x\xc0?}\x9d\xbc\xb1\xee\xbe@,\x9f\x93?\xc9\xea\xa1\xc0o\xee\xe1\xbf\x03"\xfe\xbf<-Z\xc0\xc8\xdcV@~T\xc4\xc0\xb5\xbe#@\x12\xe8\x99\xc0\xcd%(>*\x91}\xbeH\xabH\xc0\x0c\xb3\x02@ \x14\xf2\xbfuI\x98\xbf\xd3\xea\xd8\xbf\xe3Y\x12\xc0t\xc5\xf9\xbe\x9e\xfc@@\x97\x9d\xed\xbf \xcc<\xc0m|\x87\xc0\xce#\x07A\xedS\xb2@\x17\x95\x8b\xbf\x8b!\x1f@\x86a\x9c\xbe\xf0\xa4\x98\xbe\x0b\x03\x1dS\x1dU\x05\x01\r\x05wy{}\x1dW\x13\x0b\x11\x1dY\x13\x0b\x05\x03\x03Y\x03\x03\x83\x15\x03\x01\x01\x01\x03\x07Y\x87\x89\x1f\x1f!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f!\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x17\t\x00\x00\x00\x00\x1f#\x01\t\x07\x07\x01\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x13\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f-1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x0f\x01)\x05\t\r\x0f\x1b\x1d\x03\x1b\x13)\x01\x0f)\x03\t\x0b)\x01\x0b\x11\x01\x05\x05\t\t)\x03\r\x11)\x03\t\x11)\x03\x05\x11)\x03\x01\r)\x03\t\x07)\x07\t\x05\x05\x07)\x03\x05\r)\x07\t\x11\x11\x07)\x03\r\r)\x05\t\x05\x07)\x05\t\r\x07)\x03\t\r\x042\x02\x05\x01\x11\x03\x13\x07\x03\x01\x05\t\x11\x03\x1f\x07\x03#A\x05\x03\x03)\x03\x05\x0b\x07\x01+\x07\x05\t\x15\x03\x01\x05\x03\x01E\x03\x17\x03\x07\x01\t\x03\x15\x03\t\r\x07\x01G\x03%\x05\x07\x0b\x03\x07\x01\x0f\x03\'\x03\r\x05\x03\x01\x11\x03\x13\x03\x07\x01\t\x03\x05\x03\x11\x03\x07\x01M\x03+\x03\x0f\x07\x06\x01\x03\x05\x07\x15\x03\x13\x03\x07\x01\x0f\x03/\x03\r\x05\x03\x01\x11\x03\x13\x03\x07\x01\t\x03\t\x03\x1b\x03\x07\x01O\x031\x03\x19\x07\x06\x01\x03\t\x07\x1f\x05\x1d\x0f\x04\x03\x05\x17!\x06\x03\x01\x05\x01\x00\x82\n[\t\x0b%\x03\x0f\x0b\t\t\x11#!+\x1bi?\x1f/!)!)#\x1f\x19\x1f\x15\x1d\x15\x13%)9\x13\r+\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00broadcast_dimensions\x00value\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit(func)/jit(main)/hessenberg\x00third_party/py/jax/tests/export_back_compat_test.py\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_cgehrd_ffi\x00high\x00low\x00', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_31["f32"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_sgehrd_ffi'], + serialized_date=datetime.date(2024, 8, 30), + inputs=(), + expected_outputs=(array([[[-3.5237675 , -6.1161256 , -0.549011 , -4.7706876 ], + [ 5.8401766 , 3.424213 , 0.3059119 , 2.3492367 ], + [ 0.63135445 , 2.7238827 , -0.106214404, -0.82470125 ], + [-0.27146497 , 0.09917235 , 0.2545611 , -0.5113605 ]], + + [[ 4.297168 , -1.8758869 , 0.33528137 , 5.867136 ], + [-7.129698 , -3.3118155 , -1.3492918 , -2.8959117 ], + [-0.7266852 , -3.506432 , 4.77164 , -4.0780373 ], + [ 0.14084078 , 0.3389384 , 2.3910007 , -0.79807365 ]]], + dtype=float32), array([[1.3584172, 1.9805213, 0. ], + [1.2920669, 1.7939165, 0. ]], dtype=float32)), + mlir_module_text=r""" +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x4x4xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<[[[-3.52376747, -0.758410036, 4.85795927, -6.0243597], [-2.09321976, -1.27957773, -0.956288218, -1.11928439], [-5.00878525, 0.51314038, 3.53047514, -2.91282868], [2.15363932, 0.635739565, -0.21264787, 0.555740714]], [[4.29716778, -3.86209464, -2.39021468, 4.17441607], [2.08234859, -1.03958249, 4.09025383, 5.22586823], [-6.69425774, 3.43749118, -0.691099107, 1.59547663], [1.29743183, -2.00156212, 3.08750296, 2.39243269]]]> : tensor<2x4x4xf32> loc(#loc) + %0:3 = stablehlo.custom_call @lapack_sgehrd_ffi(%cst) {mhlo.backend_config = {high = 4 : i32, low = 1 : i32}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xf32>) -> (tensor<2x4x4xf32>, tensor<2x3xf32>, tensor<2xi32>) loc(#loc2) + %c = stablehlo.constant dense<0> : tensor loc(#loc2) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) + %2 = stablehlo.compare EQ, %0#2, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) + %cst_0 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) + %4 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor) -> tensor<2x4x4xf32> loc(#loc2) + %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) + %6 = stablehlo.select %5, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc2) + %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) + %cst_1 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) + %8 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor) -> tensor<2x3xf32> loc(#loc2) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc2) + %10 = stablehlo.select %9, %0#1, %8 : tensor<2x3xi1>, tensor<2x3xf32> loc(#loc2) + return %6, %10 : tensor<2x4x4xf32>, tensor<2x3xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":697:13) +#loc2 = loc("jit(func)/jit(main)/hessenberg"(#loc1)) +""", + mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xe3\x9b3\x01Q\x0f\x07\x0b\x0b\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0b\x13\x1b\x0b\x0b\x13\x13\x03K\x0b\x0b\x0b\x0bo\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b&\x04\x0b\x0b\x0b\x0b\x1b\x0b\x0f\x0b\x0f\x0f\x0f\x17\x17O/\x1f\x0f\x0b\x0b/\x1foO\x01\x05\x0b\x0f\x03/\x1b\x07\x17\x07\x07\x07\x07\x0f\x13\x0f\x17\x13\x13\x13\x13\x13\x1b\x13\x1b\x13\x17\x17\x13\x02\xea\x08\x1d?A\x1f\x05\x15\x05\x17\x03\x03\x05\x8d\x11\x03\x05\x05\x19\x03\x03\x05\x93\x03\x03\x07\x95\x03\t\x15\x17\x19\x0b\x1b\x0b\r\x1d\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b!Q#[%]\rg\'i\x05#\x05%\x05\'\x05)\x03\x03\x07k\x03\x13-m/o1q3Q5s7u9\x7f;\x81=\x85\x05+\x05-\x05/\x051\x053\x055\x057\x059\x05;\x05=\x17C\xe6\n\x1b\x05?\x03\x03\x07\x8b\x03\x05I\x8fK\x91\x05A\x05C\x03\x03\x05\x97\x03\x03\x05\x99\x03\x01\x1dE\x1dG\x1dI\x1f\x1b1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00#\x19\x03\x05_c\r\x05SaUW\x1dK\r\x05SeUW\x1dM\x1dO\x1dQ\x1f\x05\x02\x02h\x85a\xc0)\'B\xbfgt\x9b@\x8e\xc7\xc0\xc0P\xf7\x05\xc04\xc9\xa3\xbfN\xcft\xbf\xb6D\x8f\xbf\xf8G\xa0\xc0+]\x03?N\xf3a@\xc9k:\xc0:\xd5\t@\xd4\xbf"?]\xc0Y\xbe\x06E\x0e?f\x82\x89@\x8f,w\xc0G\xf9\x18\xc0\xd1\x94\x85@3E\x05@\n\x11\x85\xbf\\\xe3\x82@P:\xa7@\\7\xd6\xc0\xdb\xff[@\xdf\xeb0\xbf\x948\xcc??\x12\xa6?\x98\x19\x00\xc0\xa6\x99E@\x9e\x1d\x19@\x0b\x03\x1dS\x1dU\x05\x01\r\x05wy{}\x1dW\x13\x0b\x11\x1dY\x13\x0b\x05\x03\x03Y\x03\x03\x83\x15\x03\x01\x01\x01\x03\x07Y\x87\x89\x1f\x1d!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x17\t\x00\x00\x00\x00\x1f!\x01\t\x07\x07\x01\x1f\'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x13\t\x00\x00\xc0\x7f\x1f+1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f1!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x0f\x01)\x05\t\r\x0f\x1b\x1d\t\x13)\x01\x0f)\x03\t\x0b)\x01\x0b\x11\x01\x05\x05\t)\x03\r\x11)\x03\t\x11)\x03\x05\x11)\x03\x01\r)\x03\t\x07)\x07\t\x05\x05\x07)\x03\x05\r)\x07\t\x11\x11\x07)\x03\r\r)\x05\t\x05\x07)\x05\t\r\x07)\x03\t\r\x042\x02\x05\x01\x11\x03\x13\x07\x03\x01\x05\t\x11\x03\x1f\x07\x03#A\x05\x03\x03)\x03\x05\x0b\x07\x01+\x07\x05\t\x15\x03\x01\x05\x03\x01E\x03\x17\x03\x07\x01\t\x03\x15\x03\t\r\x07\x01G\x03#\x05\x07\x0b\x03\x07\x01\x0f\x03%\x03\r\x05\x03\x01\x11\x03\x13\x03\x07\x01\t\x03\x05\x03\x11\x03\x07\x01M\x03)\x03\x0f\x07\x06\x01\x03\x05\x07\x15\x03\x13\x03\x07\x01\x0f\x03-\x03\r\x05\x03\x01\x11\x03\x13\x03\x07\x01\t\x03\t\x03\x1b\x03\x07\x01O\x03/\x03\x19\x07\x06\x01\x03\t\x07\x1f\x05\x1d\x0f\x04\x03\x05\x17!\x06\x03\x01\x05\x01\x00\x82\n[\t\x0b%\x03\x0f\x0b\t\t\x11#!+\x1bi?\x1f/!)!)#\x1f\x19\x1f\x15\x1d\x15\x13%)9\x13\r+\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00broadcast_dimensions\x00value\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit(func)/jit(main)/hessenberg\x00third_party/py/jax/tests/export_back_compat_test.py\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_sgehrd_ffi\x00high\x00low\x00', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_31["f64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_dgehrd_ffi'], + serialized_date=datetime.date(2024, 8, 30), + inputs=(), + expected_outputs=(array([[[ 0.9307390587491866 , -0.35692982324474015 , + -0.1271353200176119 , -0.43952156917870067 ], + [ 2.2633695323673964 , 0.9965090965971986 , + -1.3244131008423046 , 1.7324542351344163 ], + [ 0.24558316247256504 , 2.922776762811796 , + 3.630059093036474 , 1.4330664619737252 ], + [-0.2856727718012896 , -0.4601276537179077 , + -2.8602148466873802 , 1.9928744545245372 ]], + + [[-0.5351339571818844 , 5.753313169426148 , + 0.1385440281649789 , 2.8445493054193807 ], + [ 4.676815781213274 , 2.920688567170204 , + -2.610159425457712 , 4.0359806870679655 ], + [-0.16963242599901043 , -2.342935131066633 , + 4.179999589709703 , -0.6810604472011716 ], + [ 0.030645999613174775, -0.2271804227402005 , + -2.2755242550977153 , 0.7136684502626782 ]]]), array([[1.751436143556826 , 1.6505497938190505, 0. ], + [1.9422862513069978, 1.9018440331997255, 0. ]])), + mlir_module_text=r""" +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x4x4xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<[[[0.93073905874918661, 0.18483901505653183, -0.11804347408930886, -0.53725392025434981], [-1.700777672846173, 1.3531570270421245, -2.4375034855727518, 2.2945174202226699], [-0.97352780716312858, -0.8319788592736328, 2.4986640885328582, -2.8118637941861766], [1.1324489199416958, -1.9301638714393787, 1.5523821278819048, 2.7676215285832253]], [[-0.53513395718188439, -5.2137633671981938, 2.9644475919777618, 2.2891023676266191], [-4.4068992105328642, 1.2751848926168665, -2.8947257279736456, -2.6817410994805888], [1.5408926111334784, -0.85423691880254915, 6.4217874587762065, -0.43997818045540715], [-0.27837952612324207, 1.1509460853774549, -0.21686805683301608, 0.11738425574951133]]]> : tensor<2x4x4xf64> loc(#loc) + %0:3 = stablehlo.custom_call @lapack_dgehrd_ffi(%cst) {mhlo.backend_config = {high = 4 : i32, low = 1 : i32}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xf64>) -> (tensor<2x4x4xf64>, tensor<2x3xf64>, tensor<2xi32>) loc(#loc2) + %c = stablehlo.constant dense<0> : tensor loc(#loc2) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) + %2 = stablehlo.compare EQ, %0#2, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) + %cst_0 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) + %4 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor) -> tensor<2x4x4xf64> loc(#loc2) + %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) + %6 = stablehlo.select %5, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc2) + %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) + %cst_1 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) + %8 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor) -> tensor<2x3xf64> loc(#loc2) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc2) + %10 = stablehlo.select %9, %0#1, %8 : tensor<2x3xi1>, tensor<2x3xf64> loc(#loc2) + return %6, %10 : tensor<2x4x4xf64>, tensor<2x3xf64> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":697:13) +#loc2 = loc("jit(func)/jit(main)/hessenberg"(#loc1)) +""", + mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xe3\x9b3\x01Q\x0f\x07\x0b\x0b\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0b\x13\x1b\x0b\x0b\x13\x13\x03K\x0b\x0b\x0b\x0bo\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b&\x08\x0b\x0b\x0b\x0b\x1b\x0b\x0f\x0b\x0f\x0f\x0f\x17\x17O/\x1f\x0f\x0b\x0b//oO\x01\x05\x0b\x0f\x03/\x1b\x07\x17\x07\x07\x07\x07\x0f\x13\x0f\x17\x13\x13\x13\x13\x13\x1b\x13\x1b\x13\x17\x17\x13\x02\xfa\n\x1d?A\x1f\x05\x15\x05\x17\x03\x03\x05\x8d\x11\x03\x05\x05\x19\x03\x03\x05\x93\x03\x03\x07\x95\x03\t\x15\x17\x19\x0b\x1b\x0b\r\x1d\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b!Q#[%]\rg\'i\x05#\x05%\x05\'\x05)\x03\x03\x07k\x03\x13-m/o1q3Q5s7u9\x7f;\x81=\x85\x05+\x05-\x05/\x051\x053\x055\x057\x059\x05;\x05=\x17C\xe6\n\x1b\x05?\x03\x03\x07\x8b\x03\x05I\x8fK\x91\x05A\x05C\x03\x03\x05\x97\x03\x03\x05\x99\x03\x01\x1dE\x1dG\x1dI\x1f\x1b1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00#\x19\x03\x05_c\r\x05SaUW\x1dK\r\x05SeUW\x1dM\x1dO\x1dQ\x1f\x05\x02\x04\xa6\x00NG\x9d\xc8\xed?\xf2\xa8X\n\xce\xa8\xc7?#E\xb8\xdc\x188\xbe\xbf\xb8|$"/1\xe1\xbf\xc4B*\xa6b6\xfb\xbf\xe8\xf9\x97\xfb\x87\xa6\xf5?)^\xd3\xd3\x01\x80\x03\xc0T\xab\xff\xf2+[\x02@4d\xb0\xc9#\'\xef\xbf~e\xf1 \x92\x9f\xea\xbf\x96\x81\xff\x98C\xfd\x03@W\xb0\xe6q\xb2~\x06\xc0F\xa48\xc2\x82\x1e\xf2?\xcc\x0b\xfc\x82\xf3\xe1\xfe\xbf\xdc\\b\xa4\x8e\xd6\xf8?\x8c\xc3\x87\xc1\x16$\x06@\x83h\xa2?\xd1\x1f\xe1\xbf\xdc\xcb\xbc\xc8\xe4\xda\x14\xc0\xe6\x00\x92L0\xb7\x07@Q8\xf1\xe6\x14P\x02@\t\x07\xc8/\xaa\xa0\x11\xc0\x8eH"F(g\xf4?\xf5Jd\xf6e(\x07\xc0\x9e\xddt\xad4t\x05\xc0\x1cv\xb7\x02\x7f\xa7\xf8?B^\xa9\xa9\xe8U\xeb\xbf\x1e:5\r\xe9\xaf\x19@\xa2\x9c\x00>\x9a(\xdc\xbf\xc1\xd1$\\\xf8\xd0\xd1\xbf}|BqFj\xf2?6\x8b\xd2\x1dU\xc2\xcb\xbfdk\x82\x03\xe5\x0c\xbe?\x0b\x03\x1dS\x1dU\x05\x01\r\x05wy{}\x1dW\x13\x0b\x11\x1dY\x13\x0b\x05\x03\x03Y\x03\x03\x83\x15\x03\x01\x01\x01\x03\x07Y\x87\x89\x1f\x1d!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x17\t\x00\x00\x00\x00\x1f!\x01\t\x07\x07\x01\x1f\'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x13\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f+1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f1!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x0f\x01)\x05\t\r\x0f\x1b\x1d\x0b\x13)\x01\x0f)\x03\t\x0b)\x01\x0b\x11\x01\x05\x05\t)\x03\r\x11)\x03\t\x11)\x03\x05\x11)\x03\x01\r)\x03\t\x07)\x07\t\x05\x05\x07)\x03\x05\r)\x07\t\x11\x11\x07)\x03\r\r)\x05\t\x05\x07)\x05\t\r\x07)\x03\t\r\x042\x02\x05\x01\x11\x03\x13\x07\x03\x01\x05\t\x11\x03\x1f\x07\x03#A\x05\x03\x03)\x03\x05\x0b\x07\x01+\x07\x05\t\x15\x03\x01\x05\x03\x01E\x03\x17\x03\x07\x01\t\x03\x15\x03\t\r\x07\x01G\x03#\x05\x07\x0b\x03\x07\x01\x0f\x03%\x03\r\x05\x03\x01\x11\x03\x13\x03\x07\x01\t\x03\x05\x03\x11\x03\x07\x01M\x03)\x03\x0f\x07\x06\x01\x03\x05\x07\x15\x03\x13\x03\x07\x01\x0f\x03-\x03\r\x05\x03\x01\x11\x03\x13\x03\x07\x01\t\x03\t\x03\x1b\x03\x07\x01O\x03/\x03\x19\x07\x06\x01\x03\t\x07\x1f\x05\x1d\x0f\x04\x03\x05\x17!\x06\x03\x01\x05\x01\x00\x82\n[\t\x0b%\x03\x0f\x0b\t\t\x11#!+\x1bi?\x1f/!)!)#\x1f\x19\x1f\x15\x1d\x15\x13%)9\x13\r+\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00broadcast_dimensions\x00value\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit(func)/jit(main)/hessenberg\x00third_party/py/jax/tests/export_back_compat_test.py\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_dgehrd_ffi\x00high\x00low\x00', + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_eigh_cusolver_syev.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_eigh_cusolver_syev.py index 896ecad019e2..56479e82f9d9 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_eigh_cusolver_syev.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_eigh_cusolver_syev.py @@ -15,7 +15,7 @@ # ruff: noqa import datetime -from numpy import array, float32 +from numpy import array, float32, complex64 data_2023_03_17=dict( # Pasted from the test output (see back_compat_test.py module docstring) @@ -1409,3 +1409,342 @@ xla_call_module_version=4, ) # End paste ) + +data_2024_09_30 = {} + +data_2024_09_30["f32"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cusolver_syevd_ffi'], + serialized_date=datetime.date(2024, 9, 30), + inputs=(), + expected_outputs=(array([[ 0.7941186 , -0.3696443 , -0.40418202 , 0.26339266 ], + [ 0.3696443 , 0.7941186 , 0.26339266 , 0.4041819 ], + [-0.054829806, -0.47930413 , 0.6857606 , 0.5449713 ], + [-0.4793042 , 0.05482992 , -0.5449712 , 0.68576056 ]], + dtype=float32), array([-3.7082872e+00, -4.0793765e-07, 4.4458108e-07, 3.3708286e+01], + dtype=float32)), + mlir_module_text=r""" +#loc6 = loc("third_party/py/jax/tests/export_back_compat_test.py":274:27) +#loc13 = loc("jit()/jit(main)/pjit"(#loc6)) +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<4x4xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %cst_0 = stablehlo.constant dense<2.000000e+00> : tensor loc(#loc) + %0 = stablehlo.iota dim = 0 : tensor<16xf32> loc(#loc8) + %1 = stablehlo.reshape %0 : (tensor<16xf32>) -> tensor<4x4xf32> loc(#loc9) + %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<4x4xf32>) -> tensor<4x4xf32> loc(#loc10) + %3 = stablehlo.add %1, %2 : tensor<4x4xf32> loc(#loc11) + %4 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc12) + %5 = stablehlo.divide %3, %4 : tensor<4x4xf32> loc(#loc12) + %6 = call @tril(%5) : (tensor<4x4xf32>) -> tensor<4x4xf32> loc(#loc13) + %7:3 = stablehlo.custom_call @cusolver_syevd_ffi(%6) {mhlo.backend_config = {algorithm = 0 : ui8, lower = true}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xf32>) -> (tensor<4x4xf32>, tensor<4xf32>, tensor) loc(#loc14) + %8 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor loc(#loc14) + %9 = stablehlo.compare EQ, %7#2, %8, SIGNED : (tensor, tensor) -> tensor loc(#loc14) + %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc14) + %11 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc14) + %12 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc14) + %13 = stablehlo.select %12, %7#0, %11 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc14) + %14 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1xi1> loc(#loc14) + %15 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4xf32> loc(#loc14) + %16 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc14) + %17 = stablehlo.select %16, %7#1, %15 : tensor<4xi1>, tensor<4xf32> loc(#loc14) + return %13, %17 : tensor<4x4xf32>, tensor<4xf32> loc(#loc) + } loc(#loc) + func.func private @tril(%arg0: tensor<4x4xf32> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit"(#loc6))) -> (tensor<4x4xf32> {mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %0 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc15) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<4x4xi32> loc(#loc16) + %2 = stablehlo.add %0, %1 : tensor<4x4xi32> loc(#loc16) + %3 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc15) + %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc17) + %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc18) + %6 = stablehlo.select %4, %arg0, %5 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc19) + return %6 : tensor<4x4xf32> loc(#loc13) + } loc(#loc13) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":266:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":266:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:34) +#loc4 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:15) +#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:14) +#loc7 = loc("third_party/py/jax/tests/export_back_compat_test.py":274:11) +#loc8 = loc("jit()/jit(main)/iota"(#loc1)) +#loc9 = loc("jit()/jit(main)/reshape"(#loc2)) +#loc10 = loc("jit()/jit(main)/transpose"(#loc3)) +#loc11 = loc("jit()/jit(main)/add"(#loc4)) +#loc12 = loc("jit()/jit(main)/div"(#loc5)) +#loc14 = loc("jit()/jit(main)/eigh"(#loc7)) +#loc15 = loc("jit()/jit(main)/jit(tril)/iota"(#loc6)) +#loc16 = loc("jit()/jit(main)/jit(tril)/add"(#loc6)) +#loc17 = loc("jit()/jit(main)/jit(tril)/ge"(#loc6)) +#loc18 = loc("jit()/jit(main)/jit(tril)/broadcast_in_dim"(#loc6)) +#loc19 = loc("jit()/jit(main)/jit(tril)/select_n"(#loc6)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.5.0\x00\x01-\x05\x01\x05\x1d\x01\x03\x0b\x03\x1b\x0f\x13\x17\x1b\x1f#'+/37;?\x03\xfb\xb17\x01U\x0f\x07\x0b\x17\x0f\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0f\x0b\x17\x0f\x0b\x17\x0b\x17\x13\x0b\x0b\x17\x03]\x0f\x0b\x0b\x0b\x0b\x0f\x0b\x1f\x0f\x0bO\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x1f\x0f\x0b\x1f\x1fO\x1b\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x17/\x0f\x0bO/\x01\x05\x0b\x0f\x033\x17\x0f\x0f\x07\x07\x07\x13\x17\x07\x07\x17\x13\x17\x17\x13\x13\x07\x13\x13\x13\x0f\x17\x13\x13\x13\x02\xae\x06\x1dQS\x1f\x05!\x17\x05J\x047\x1d\x1f\x07\x11\x03\x05\x1d!\x07\x1d#\x07\x1dIK\x03\x07\x15\x17\x19\x0b\x1b\x0b\x05#\x11\x01\x00\x05%\x05'\x05)\x05+\x05-\x05/\x1d'\x07\x051\x1d+\x07\x053\x1d/\x07\x055\x1d35\x057\x17\x05*\x045\x1d9;\x059\x17\x05*\x04\x1d\x1d?A\x05;\x17\x052\x04E\x1dEG\x05=\x17\x052\x04\x1f\x05?\x17\x052\x04\x1d\x03\x03O\x8d\x05A\x05C\x17\x05J\x04\x17\x1f!\x01\x1dE\x1dG\x03\x01\x1dI\x03\x03{\x1dK\x1f\t\t\x00\x00\x00\x00\x13\x0b\x01\t\x07\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#\x1d\x03\x05os\r\x05]qWY\x1dM\r\x05]uWY\x1dO\x1dQ\x1dS\r\x03WY#\x1f\x1dU\x1f\x07\t\x00\x00\x00\x00\x13\x0b\x05\x07\x05\x1f\x07\t\x00\x00\xc0\x7f\x1f\x07\t\x00\x00\x00@\x1f\x1b!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x05\x8f\x91\x93\x95\x1dW\x13%\x00\x1dY\x05\x03\x0b\x03\x1d[\x1d]\x05\x01\x03\x03i\x03\x03\xa3\x15\x03\x01\x01\x01\x03\x07i\xa7\xa9\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f+\x01\x07\x01\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\x11\x11\x0f)\x01\x0f)\x01\x17\x1d\x01\t)\x03\x11\x0f)\x05\x11\x11\x17\x13\x1b)\x05\x11\x11\r)\x03\t\x0b\x11\x01\x05\x05\x11\x11\x03\x05\x03\x05)\x03\x01\x0b)\x03A\x0f!)\x03\t\x15)\x03\x05\x15)\x03\x01\x15)\x01\r)\x05\x05\x05\r)\x03\x05\r)\x03\x11\r)\x03\x05\x0b\x04b\x04\x05\x01Q\x03\x13\x01\x07\x04:\x04\x03\x01\t\x0bP\x03\x03\x07\x04\xba\x02\x03/Y\x05B\x03\x05\x03\x07\x05B\x03\x07\x03\t\x05B\x03\t\x03\x07\x07B1\x0b\x03#\x13\x067\x03\x05\x03\x07\x15F=\r\x03\x05\x03\t\r\x06C\x03\x05\x05\t\x0b\x03F\x11\x0f\x03\x05\x03\x05\x17\x06\x11\x03\x05\x05\r\x0f\x19F\t\x11\x03\x05\x03\x11\x1bG\x01M\x13\x07\x05\x11\t\x03\x13\x03F\x01\x0f\x03\t\x03\x03\x0fF\x01\x15\x03-\x05\x19\x1b\x03F\x01\x0f\x03/\x03\x1d\x03F\x01\x0f\x03\x05\x03\x01\x03F\x01\x17\x03\x19\x03\x1f\t\x06\x01\x03\x05\x07#\x15!\x03F\x01\x0f\x031\x03\x1d\x03F\x01\x0f\x03\x11\x03\x01\x03F\x01\x19\x033\x03'\t\x06\x01\x03\x11\x07+\x17)\x11\x04\x03\x05%-\x0bP\t\x1b\x07\x04\x9d\x03\x15+\x03\x0b\t\x00\x05B\x03\x1d\x03\x07\x05B\x03\x07\x03\t\x07B\r\x0b\x03\x13\x03F\x0f\x0f\x03\x13\x03\x05\r\x06\x0f\x03\x13\x05\x07\t\x07B\r\x1f\x03\x13\x0fF%!\x03\x19\x05\x0b\r\x03F)\x0f\x03\x05\x03\x03\t\x06-\x03\x05\x07\x0f\x01\x11\x11\x04\t\x03\x13\x06\x03\x01\x05\x01\x00\xe6\r_'\x03\r\x15\x11\x0f\x0b\t\t\x0b!\x11#;)99EA;WgKMO;\x1b%)9i\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit\x00jit()/jit(main)/jit(tril)/iota\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota\x00jit()/jit(main)/reshape\x00jit()/jit(main)/transpose\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00mhlo.backend_config\x00jit()/jit(main)/eigh\x00mhlo.layout_mode\x00default\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00algorithm\x00lower\x00\x00cusolver_syevd_ffi\x00\x08k#\x05;\x01\x0b[kmwy\x03\x87\x03c\x03\x89\x03e\x03\x8b\x03U\x03a\x11\x97\x99\x9b[\x9d\x9f\xa1\xa5\x05g\xab\x03\xad\x03\xaf\x0b_}_a\x7f\x03\x81\x03\x83\x05g\x85", + xla_call_module_version=9, + nr_devices=1, +) # End paste + +data_2024_09_30["f64"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cusolver_syevd_ffi'], + serialized_date=datetime.date(2024, 9, 30), + inputs=(), + expected_outputs=(array([[ 0.7941185704969033 , -0.36964433974346045, -0.4041819665640973 , + 0.2633926650306618 ], + [ 0.3696443397434605 , 0.7941185704969035 , 0.2633926650306616 , + 0.4041819665640974 ], + [-0.05482989100998295, -0.47930412176342563, 0.6857605696309688 , + 0.544971268097533 ], + [-0.47930412176342574, 0.05482989100998273, -0.544971268097533 , + 0.6857605696309688 ]]), array([-3.7082869338697053e+00, 7.7329581044653176e-17, + 8.6623770428558249e-16, 3.3708286933869694e+01])), + mlir_module_text=r""" +#loc6 = loc("third_party/py/jax/tests/export_back_compat_test.py":274:27) +#loc13 = loc("jit()/jit(main)/pjit"(#loc6)) +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<4x4xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %cst_0 = stablehlo.constant dense<2.000000e+00> : tensor loc(#loc) + %0 = stablehlo.iota dim = 0 : tensor<16xf64> loc(#loc8) + %1 = stablehlo.reshape %0 : (tensor<16xf64>) -> tensor<4x4xf64> loc(#loc9) + %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<4x4xf64>) -> tensor<4x4xf64> loc(#loc10) + %3 = stablehlo.add %1, %2 : tensor<4x4xf64> loc(#loc11) + %4 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor) -> tensor<4x4xf64> loc(#loc12) + %5 = stablehlo.divide %3, %4 : tensor<4x4xf64> loc(#loc12) + %6 = call @tril(%5) : (tensor<4x4xf64>) -> tensor<4x4xf64> loc(#loc13) + %7:3 = stablehlo.custom_call @cusolver_syevd_ffi(%6) {mhlo.backend_config = {algorithm = 0 : ui8, lower = true}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xf64>) -> (tensor<4x4xf64>, tensor<4xf64>, tensor) loc(#loc14) + %8 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor loc(#loc14) + %9 = stablehlo.compare EQ, %7#2, %8, SIGNED : (tensor, tensor) -> tensor loc(#loc14) + %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc14) + %11 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x4xf64> loc(#loc14) + %12 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc14) + %13 = stablehlo.select %12, %7#0, %11 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc14) + %14 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1xi1> loc(#loc14) + %15 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4xf64> loc(#loc14) + %16 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc14) + %17 = stablehlo.select %16, %7#1, %15 : tensor<4xi1>, tensor<4xf64> loc(#loc14) + return %13, %17 : tensor<4x4xf64>, tensor<4xf64> loc(#loc) + } loc(#loc) + func.func private @tril(%arg0: tensor<4x4xf64> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit"(#loc6))) -> (tensor<4x4xf64> {mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %0 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc15) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<4x4xi32> loc(#loc16) + %2 = stablehlo.add %0, %1 : tensor<4x4xi32> loc(#loc16) + %3 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc15) + %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc17) + %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x4xf64> loc(#loc18) + %6 = stablehlo.select %4, %arg0, %5 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc19) + return %6 : tensor<4x4xf64> loc(#loc13) + } loc(#loc13) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":266:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":266:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:34) +#loc4 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:15) +#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:14) +#loc7 = loc("third_party/py/jax/tests/export_back_compat_test.py":274:11) +#loc8 = loc("jit()/jit(main)/iota"(#loc1)) +#loc9 = loc("jit()/jit(main)/reshape"(#loc2)) +#loc10 = loc("jit()/jit(main)/transpose"(#loc3)) +#loc11 = loc("jit()/jit(main)/add"(#loc4)) +#loc12 = loc("jit()/jit(main)/div"(#loc5)) +#loc14 = loc("jit()/jit(main)/eigh"(#loc7)) +#loc15 = loc("jit()/jit(main)/jit(tril)/iota"(#loc6)) +#loc16 = loc("jit()/jit(main)/jit(tril)/add"(#loc6)) +#loc17 = loc("jit()/jit(main)/jit(tril)/ge"(#loc6)) +#loc18 = loc("jit()/jit(main)/jit(tril)/broadcast_in_dim"(#loc6)) +#loc19 = loc("jit()/jit(main)/jit(tril)/select_n"(#loc6)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.5.0\x00\x01-\x05\x01\x05\x1d\x01\x03\x0b\x03\x1b\x0f\x13\x17\x1b\x1f#'+/37;?\x03\xfb\xb17\x01U\x0f\x07\x0b\x17\x0f\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0f\x0b\x17\x0f\x0b\x17\x0b\x17\x13\x0b\x0b\x17\x03]\x0f\x0b\x0b\x0b\x0b\x0f\x0b\x1f\x0f\x0bO\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b/\x0f\x0b//O\x1b\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x17/\x0f\x0bO/\x01\x05\x0b\x0f\x033\x17\x0f\x0f\x07\x07\x07\x13\x17\x07\x07\x17\x13\x17\x17\x13\x13\x07\x13\x13\x13\x0f\x17\x13\x13\x13\x02\xde\x06\x1dQS\x1f\x05!\x17\x05J\x047\x1d\x1f\x07\x11\x03\x05\x1d!\x07\x1d#\x07\x1dIK\x03\x07\x15\x17\x19\x0b\x1b\x0b\x05#\x11\x01\x00\x05%\x05'\x05)\x05+\x05-\x05/\x1d'\x07\x051\x1d+\x07\x053\x1d/\x07\x055\x1d35\x057\x17\x05*\x045\x1d9;\x059\x17\x05*\x04\x1d\x1d?A\x05;\x17\x052\x04E\x1dEG\x05=\x17\x052\x04\x1f\x05?\x17\x052\x04\x1d\x03\x03O\x8d\x05A\x05C\x17\x05J\x04\x17\x1f!\x01\x1dE\x1dG\x03\x01\x1dI\x03\x03{\x1dK\x1f\t\t\x00\x00\x00\x00\x13\x0b\x01\t\x07\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#\x1d\x03\x05os\r\x05]qWY\x1dM\r\x05]uWY\x1dO\x1dQ\x1dS\r\x03WY#\x1f\x1dU\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00\x00\x13\x0b\x05\x07\x05\x1f\x07\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00@\x1f\x1b!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x05\x8f\x91\x93\x95\x1dW\x13%\x00\x1dY\x05\x03\x0b\x03\x1d[\x1d]\x05\x01\x03\x03i\x03\x03\xa3\x15\x03\x01\x01\x01\x03\x07i\xa7\xa9\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f+\x01\x07\x01\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\x11\x11\x0f)\x01\x0f)\x01\x17\x1d\x01\x0b)\x03\x11\x0f)\x05\x11\x11\x17\x13\x1b)\x05\x11\x11\r)\x03\t\x0b\x11\x01\x05\x05\x11\x11\x03\x05\x03\x05)\x03\x01\x0b)\x03A\x0f!)\x03\t\x15)\x03\x05\x15)\x03\x01\x15)\x01\r)\x05\x05\x05\r)\x03\x05\r)\x03\x11\r)\x03\x05\x0b\x04b\x04\x05\x01Q\x03\x13\x01\x07\x04:\x04\x03\x01\t\x0bP\x03\x03\x07\x04\xba\x02\x03/Y\x05B\x03\x05\x03\x07\x05B\x03\x07\x03\t\x05B\x03\t\x03\x07\x07B1\x0b\x03#\x13\x067\x03\x05\x03\x07\x15F=\r\x03\x05\x03\t\r\x06C\x03\x05\x05\t\x0b\x03F\x11\x0f\x03\x05\x03\x05\x17\x06\x11\x03\x05\x05\r\x0f\x19F\t\x11\x03\x05\x03\x11\x1bG\x01M\x13\x07\x05\x11\t\x03\x13\x03F\x01\x0f\x03\t\x03\x03\x0fF\x01\x15\x03-\x05\x19\x1b\x03F\x01\x0f\x03/\x03\x1d\x03F\x01\x0f\x03\x05\x03\x01\x03F\x01\x17\x03\x19\x03\x1f\t\x06\x01\x03\x05\x07#\x15!\x03F\x01\x0f\x031\x03\x1d\x03F\x01\x0f\x03\x11\x03\x01\x03F\x01\x19\x033\x03'\t\x06\x01\x03\x11\x07+\x17)\x11\x04\x03\x05%-\x0bP\t\x1b\x07\x04\x9d\x03\x15+\x03\x0b\t\x00\x05B\x03\x1d\x03\x07\x05B\x03\x07\x03\t\x07B\r\x0b\x03\x13\x03F\x0f\x0f\x03\x13\x03\x05\r\x06\x0f\x03\x13\x05\x07\t\x07B\r\x1f\x03\x13\x0fF%!\x03\x19\x05\x0b\r\x03F)\x0f\x03\x05\x03\x03\t\x06-\x03\x05\x07\x0f\x01\x11\x11\x04\t\x03\x13\x06\x03\x01\x05\x01\x00\xe6\r_'\x03\r\x15\x11\x0f\x0b\t\t\x0b!\x11#;)99EA;WgKMO;\x1b%)9i\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit\x00jit()/jit(main)/jit(tril)/iota\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota\x00jit()/jit(main)/reshape\x00jit()/jit(main)/transpose\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00mhlo.backend_config\x00jit()/jit(main)/eigh\x00mhlo.layout_mode\x00default\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00algorithm\x00lower\x00\x00cusolver_syevd_ffi\x00\x08k#\x05;\x01\x0b[kmwy\x03\x87\x03c\x03\x89\x03e\x03\x8b\x03U\x03a\x11\x97\x99\x9b[\x9d\x9f\xa1\xa5\x05g\xab\x03\xad\x03\xaf\x0b_}_a\x7f\x03\x81\x03\x83\x05g\x85", + xla_call_module_version=9, + nr_devices=1, +) # End paste + +data_2024_09_30["c64"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cusolver_syevd_ffi'], + serialized_date=datetime.date(2024, 9, 30), + inputs=(), + expected_outputs=(array([[ 0.79411864 +0.j, 0.3696443 +0.j, 0.40418214 +0.j, + -0.26339263 +0.j], + [ 0.3696443 +0.j, -0.7941186 +0.j, -0.26339272 +0.j, + -0.40418193 +0.j], + [-0.054829765+0.j, 0.47930422 +0.j, -0.6857606 +0.j, + -0.5449713 +0.j], + [-0.47930422 +0.j, -0.054829985+0.j, 0.5449712 +0.j, + -0.6857606 +0.j]], dtype=complex64), array([-3.7082872e+00, -2.9983883e-07, 3.5983098e-07, 3.3708286e+01], + dtype=float32)), + mlir_module_text=r""" +#loc7 = loc("third_party/py/jax/tests/export_back_compat_test.py":274:27) +#loc18 = loc("jit()/jit(main)/pjit"(#loc7)) +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %cst_0 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %cst_1 = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor> loc(#loc) + %0 = stablehlo.iota dim = 0 : tensor<16xcomplex> loc(#loc9) + %1 = stablehlo.reshape %0 : (tensor<16xcomplex>) -> tensor<4x4xcomplex> loc(#loc10) + %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<4x4xcomplex>) -> tensor<4x4xcomplex> loc(#loc11) + %3 = stablehlo.real %2 : (tensor<4x4xcomplex>) -> tensor<4x4xf32> loc(#loc12) + %4 = stablehlo.imag %2 : (tensor<4x4xcomplex>) -> tensor<4x4xf32> loc(#loc13) + %5 = stablehlo.negate %4 : tensor<4x4xf32> loc(#loc14) + %6 = stablehlo.complex %3, %5 : tensor<4x4xcomplex> loc(#loc15) + %7 = stablehlo.add %1, %6 : tensor<4x4xcomplex> loc(#loc16) + %8 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc17) + %9 = stablehlo.divide %7, %8 : tensor<4x4xcomplex> loc(#loc17) + %10 = call @tril(%9) : (tensor<4x4xcomplex>) -> tensor<4x4xcomplex> loc(#loc18) + %11:3 = stablehlo.custom_call @cusolver_syevd_ffi(%10) {mhlo.backend_config = {algorithm = 0 : ui8, lower = true}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor<4xf32>, tensor) loc(#loc19) + %12 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor loc(#loc19) + %13 = stablehlo.compare EQ, %11#2, %12, SIGNED : (tensor, tensor) -> tensor loc(#loc19) + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc19) + %15 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc19) + %16 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc19) + %17 = stablehlo.select %16, %11#0, %15 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc19) + %18 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi1> loc(#loc19) + %19 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4xf32> loc(#loc19) + %20 = stablehlo.broadcast_in_dim %18, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc19) + %21 = stablehlo.select %20, %11#1, %19 : tensor<4xi1>, tensor<4xf32> loc(#loc19) + return %17, %21 : tensor<4x4xcomplex>, tensor<4xf32> loc(#loc) + } loc(#loc) + func.func private @tril(%arg0: tensor<4x4xcomplex> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit"(#loc7))) -> (tensor<4x4xcomplex> {mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %0 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc20) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<4x4xi32> loc(#loc21) + %2 = stablehlo.add %0, %1 : tensor<4x4xi32> loc(#loc21) + %3 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc20) + %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc22) + %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc23) + %6 = stablehlo.select %4, %arg0, %5 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc24) + return %6 : tensor<4x4xcomplex> loc(#loc18) + } loc(#loc18) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":266:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":266:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:34) +#loc4 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:25) +#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:15) +#loc6 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:14) +#loc8 = loc("third_party/py/jax/tests/export_back_compat_test.py":274:11) +#loc9 = loc("jit()/jit(main)/iota"(#loc1)) +#loc10 = loc("jit()/jit(main)/reshape"(#loc2)) +#loc11 = loc("jit()/jit(main)/transpose"(#loc3)) +#loc12 = loc("jit()/jit(main)/real"(#loc4)) +#loc13 = loc("jit()/jit(main)/imag"(#loc4)) +#loc14 = loc("jit()/jit(main)/neg"(#loc4)) +#loc15 = loc("jit()/jit(main)/complex"(#loc4)) +#loc16 = loc("jit()/jit(main)/add"(#loc5)) +#loc17 = loc("jit()/jit(main)/div"(#loc6)) +#loc19 = loc("jit()/jit(main)/eigh"(#loc8)) +#loc20 = loc("jit()/jit(main)/jit(tril)/iota"(#loc7)) +#loc21 = loc("jit()/jit(main)/jit(tril)/add"(#loc7)) +#loc22 = loc("jit()/jit(main)/jit(tril)/ge"(#loc7)) +#loc23 = loc("jit()/jit(main)/jit(tril)/broadcast_in_dim"(#loc7)) +#loc24 = loc("jit()/jit(main)/jit(tril)/select_n"(#loc7)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.5.0\x00\x015\x05\x01\x05%\x01\x03\x0b\x03#\x0f\x13\x17\x1b\x1f#'+/37;?CGKO\x03*\x02\xc5=\x01g\x0f\x07\x0b\x17\x0f\x17\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0f\x0b\x17\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0b\x17\x13\x0b\x0b\x17\x03_\x0f\x0b\x0b\x0b\x0b\x0f\x0b\x1f\x0f\x0bO\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b/\x0f\x0b\x1f//O\x1b\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x17/\x0f\x0bO/\x01\x05\x0b\x0f\x039\x17\x0f\x0f\x07\x07\x07\x13\x17\x0b\x17\x07\x07\x17\x0f\x13\x17\x17\x13\x13\x07\x13\x13\x13\x0f\x17\x13\x13\x13\x02\x86\x07\x1dce\x1f\x05)\x17\x05J\x047\x1d!\x07\x17\x052\x043\x11\x03\x05\x1d#\x07\x1d%\x07\x1d[]\x03\x07\x17\x19\x1b\r\x1d\r\x05+\x11\x01\x00\x05-\x05/\x051\x053\x055\x057\x1d)\x07\x059\x1d-\x07\x05;\x1d1\x07\x05=\x1d57\x05?\x17\x05*\x045\x1d;=\x05A\x17\x05*\x04\x1d\x1dAC\x05C\x17\x052\x04E\x1dG\x0b\x05E\x1dK\x0b\x05G\x1dO\x0b\x05I\x1dS\x0b\x05K\x1dWY\x05M\x17\x052\x04\x1f\x05O\x17\x052\x04\x1d\x03\x03a\xa1\x05Q\x05S\x17\x05J\x04\x17\x1f'\x01\x1dU\x1dW\x03\x01\x1dY\x03\x03\x8d\x1d[\x1f\t\t\x00\x00\x00\x00\x13\x0b\x01\t\x07\x1f-!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00##\x03\x05\x81\x85\r\x05o\x83ik\x1d]\r\x05o\x87ik\x1d_\x1da\x1dc\r\x03ik#%\x1de\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00\x00\x13\x0b\x05\x07\x05\x1f\x1f\t\x00\x00\xc0\x7f\x1f\x07\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x07\x11\x00\x00\x00@\x00\x00\x00\x00\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x05\xa3\xa5\xa7\xa9\x1dg\x13+\x00\x1di\x05\x03\x0b\x03\x1dk\x1dm\x05\x01\x03\x03{\x03\x03\xb7\x15\x03\x01\x01\x01\x03\x07{\xbb\xbd\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f1\x01\x07\x01\x1f!!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f;\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\x11\x11\x15)\x01\x15)\x01\x1b\x1d\x01\t)\x03\x11\x0f)\x05\x11\x11\x1b\x03\x0f)\x05\x11\x11\x0f\x13\x1b)\x05\x11\x11\r)\x01\x0f)\x03\t\x0b\x11\x01\x05\x05\x11\x11\x03\x05\x03\x05)\x03\x01\x0b)\x03A\x15!)\x03\t\x19)\x03\x05\x19)\x03\x01\x19)\x01\r)\x05\x05\x05\r)\x03\x05\r)\x03\x11\r)\x03\x05\x0b\x04\xee\x04\x05\x01Q\x03\x15\x01\x07\x04\xc6\x04\x03\x01\t\x0bP\x03\x03\x07\x04F\x03\x039m\x05B\x03\x05\x03\x1f\x05B\x03\x07\x03\x07\x05B\x03\t\x03\t\x05B\x03\x0b\x03\x07\x07B3\r\x03)\x13\x069\x03\x05\x03\t\x15F?\x0f\x03\x05\x03\x0b\x17\x06E\x03\x17\x03\r\x19\x06I\x03\x17\x03\r\x1b\x06M\x03\x17\x03\x11\x1d\x06Q\x03\x05\x05\x0f\x13\r\x06U\x03\x05\x05\x0b\x15\x03F\x13\x11\x03\x05\x03\x07\x1f\x06\x13\x03\x05\x05\x17\x19!F\t\x13\x03\x05\x03\x1b#G\x01_\x15\x07\x05\x11\t\x03\x1d\x03F\x01\x11\x03\t\x03\x05\x0fF\x01\x17\x033\x05#%\x03F\x01\x11\x035\x03'\x03F\x01\x11\x03\x05\x03\x03\x03F\x01\x19\x03\x1d\x03)\t\x06\x01\x03\x05\x07-\x1f+\x03F\x01\x11\x037\x03'\x03F\x01\x11\x03\x11\x03\x01\x03F\x01\x1b\x039\x031\t\x06\x01\x03\x11\x075!3\x11\x04\x03\x05/7\x0bP\t\x1d\x07\x04\x9d\x03\x15+\x03\x0b\t\x00\x05B\x03\x1f\x03\x07\x05B\x03\t\x03\t\x07B\x0f\r\x03\x13\x03F\x11\x11\x03\x13\x03\x05\r\x06\x11\x03\x13\x05\x07\t\x07B\x0f!\x03\x13\x0fF'#\x03\x1d\x05\x0b\r\x03F+\x11\x03\x05\x03\x03\t\x06/\x03\x05\x07\x0f\x01\x11\x11\x04\t\x03\x13\x06\x03\x01\x05\x01\x00r\x10o'\x03\r\x15\x11\x0f\x0b\t\t\x0b!\x11#;)99A9;;EA;WgKMO;\x1b%)9i\x1f\x11\x15\x17\x15\x11\x11\x1b\x17\x15\x17\x0f\x11\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit\x00jit()/jit(main)/jit(tril)/iota\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota\x00jit()/jit(main)/reshape\x00jit()/jit(main)/transpose\x00jit()/jit(main)/real\x00jit()/jit(main)/imag\x00jit()/jit(main)/neg\x00jit()/jit(main)/complex\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00mhlo.backend_config\x00jit()/jit(main)/eigh\x00mhlo.layout_mode\x00default\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00algorithm\x00lower\x00\x00cusolver_syevd_ffi\x00\x08o%\x05?\x01\x0bm}\x7f\x89\x8b\x03\x99\x03\x9b\x03u\x03\x9d\x03w\x03\x9f\x03g\x03s\x11\xab\xad\xafm\xb1\xb3\xb5\xb9\x05y\xbf\x03\xc1\x03\xc3\x0bq\x8fqs\x91\x03\x93\x03\x95\x05y\x97", + xla_call_module_version=9, + nr_devices=1, +) # End paste + +data_2024_09_30["c128"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cusolver_syevd_ffi'], + serialized_date=datetime.date(2024, 9, 30), + inputs=(), + expected_outputs=(array([[ 0.7941185704969035 +0.j, 0.3696443397434604 +0.j, + 0.4041819665640972 +0.j, -0.2633926650306618 +0.j], + [ 0.3696443397434601 +0.j, -0.7941185704969035 +0.j, + -0.2633926650306616 +0.j, -0.4041819665640975 +0.j], + [-0.05482989100998286+0.j, 0.4793041217634256 +0.j, + -0.6857605696309689 +0.j, -0.5449712680975332 +0.j], + [-0.47930412176342574+0.j, -0.05482989100998264+0.j, + 0.5449712680975333 +0.j, -0.6857605696309688 +0.j]]), array([-3.7082869338697044e+00, 3.5411017930205070e-16, + 6.5803628062392796e-16, 3.3708286933869694e+01])), + mlir_module_text=r""" +#loc7 = loc("third_party/py/jax/tests/export_back_compat_test.py":274:27) +#loc18 = loc("jit()/jit(main)/pjit"(#loc7)) +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %cst_0 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %cst_1 = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor> loc(#loc) + %0 = stablehlo.iota dim = 0 : tensor<16xcomplex> loc(#loc9) + %1 = stablehlo.reshape %0 : (tensor<16xcomplex>) -> tensor<4x4xcomplex> loc(#loc10) + %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<4x4xcomplex>) -> tensor<4x4xcomplex> loc(#loc11) + %3 = stablehlo.real %2 : (tensor<4x4xcomplex>) -> tensor<4x4xf64> loc(#loc12) + %4 = stablehlo.imag %2 : (tensor<4x4xcomplex>) -> tensor<4x4xf64> loc(#loc13) + %5 = stablehlo.negate %4 : tensor<4x4xf64> loc(#loc14) + %6 = stablehlo.complex %3, %5 : tensor<4x4xcomplex> loc(#loc15) + %7 = stablehlo.add %1, %6 : tensor<4x4xcomplex> loc(#loc16) + %8 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc17) + %9 = stablehlo.divide %7, %8 : tensor<4x4xcomplex> loc(#loc17) + %10 = call @tril(%9) : (tensor<4x4xcomplex>) -> tensor<4x4xcomplex> loc(#loc18) + %11:3 = stablehlo.custom_call @cusolver_syevd_ffi(%10) {mhlo.backend_config = {algorithm = 0 : ui8, lower = true}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor<4xf64>, tensor) loc(#loc19) + %12 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor loc(#loc19) + %13 = stablehlo.compare EQ, %11#2, %12, SIGNED : (tensor, tensor) -> tensor loc(#loc19) + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc19) + %15 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc19) + %16 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc19) + %17 = stablehlo.select %16, %11#0, %15 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc19) + %18 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi1> loc(#loc19) + %19 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4xf64> loc(#loc19) + %20 = stablehlo.broadcast_in_dim %18, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc19) + %21 = stablehlo.select %20, %11#1, %19 : tensor<4xi1>, tensor<4xf64> loc(#loc19) + return %17, %21 : tensor<4x4xcomplex>, tensor<4xf64> loc(#loc) + } loc(#loc) + func.func private @tril(%arg0: tensor<4x4xcomplex> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit"(#loc7))) -> (tensor<4x4xcomplex> {mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %0 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc20) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<4x4xi32> loc(#loc21) + %2 = stablehlo.add %0, %1 : tensor<4x4xi32> loc(#loc21) + %3 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc20) + %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc22) + %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc23) + %6 = stablehlo.select %4, %arg0, %5 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc24) + return %6 : tensor<4x4xcomplex> loc(#loc18) + } loc(#loc18) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":266:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":266:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:34) +#loc4 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:25) +#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:15) +#loc6 = loc("third_party/py/jax/tests/export_back_compat_test.py":268:14) +#loc8 = loc("third_party/py/jax/tests/export_back_compat_test.py":274:11) +#loc9 = loc("jit()/jit(main)/iota"(#loc1)) +#loc10 = loc("jit()/jit(main)/reshape"(#loc2)) +#loc11 = loc("jit()/jit(main)/transpose"(#loc3)) +#loc12 = loc("jit()/jit(main)/real"(#loc4)) +#loc13 = loc("jit()/jit(main)/imag"(#loc4)) +#loc14 = loc("jit()/jit(main)/neg"(#loc4)) +#loc15 = loc("jit()/jit(main)/complex"(#loc4)) +#loc16 = loc("jit()/jit(main)/add"(#loc5)) +#loc17 = loc("jit()/jit(main)/div"(#loc6)) +#loc19 = loc("jit()/jit(main)/eigh"(#loc8)) +#loc20 = loc("jit()/jit(main)/jit(tril)/iota"(#loc7)) +#loc21 = loc("jit()/jit(main)/jit(tril)/add"(#loc7)) +#loc22 = loc("jit()/jit(main)/jit(tril)/ge"(#loc7)) +#loc23 = loc("jit()/jit(main)/jit(tril)/broadcast_in_dim"(#loc7)) +#loc24 = loc("jit()/jit(main)/jit(tril)/select_n"(#loc7)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.5.0\x00\x015\x05\x01\x05%\x01\x03\x0b\x03#\x0f\x13\x17\x1b\x1f#'+/37;?CGKO\x03*\x02\xc5=\x01g\x0f\x07\x0b\x17\x0f\x17\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0f\x0b\x17\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0b\x17\x13\x0b\x0b\x17\x03_\x0f\x0b\x0b\x0b\x0b\x0f\x0b\x1f\x0f\x0bO\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0bO\x0f\x0b/OOO\x1b\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x17/\x0f\x0bO/\x01\x05\x0b\x0f\x039\x17\x0f\x0f\x07\x07\x07\x13\x17\x0b\x17\x07\x07\x17\x0f\x13\x17\x17\x13\x13\x07\x13\x13\x13\x0f\x17\x13\x13\x13\x02\xf6\x07\x1dce\x1f\x05)\x17\x05J\x047\x1d!\x07\x17\x052\x043\x11\x03\x05\x1d#\x07\x1d%\x07\x1d[]\x03\x07\x17\x19\x1b\r\x1d\r\x05+\x11\x01\x00\x05-\x05/\x051\x053\x055\x057\x1d)\x07\x059\x1d-\x07\x05;\x1d1\x07\x05=\x1d57\x05?\x17\x05*\x045\x1d;=\x05A\x17\x05*\x04\x1d\x1dAC\x05C\x17\x052\x04E\x1dG\x0b\x05E\x1dK\x0b\x05G\x1dO\x0b\x05I\x1dS\x0b\x05K\x1dWY\x05M\x17\x052\x04\x1f\x05O\x17\x052\x04\x1d\x03\x03a\xa1\x05Q\x05S\x17\x05J\x04\x17\x1f'\x01\x1dU\x1dW\x03\x01\x1dY\x03\x03\x8d\x1d[\x1f\t\t\x00\x00\x00\x00\x13\x0b\x01\t\x07\x1f-!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00##\x03\x05\x81\x85\r\x05o\x83ik\x1d]\r\x05o\x87ik\x1d_\x1da\x1dc\r\x03ik#%\x1de\x1f\x07!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x13\x0b\x05\x07\x05\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x07!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x07!\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x05\xa3\xa5\xa7\xa9\x1dg\x13+\x00\x1di\x05\x03\x0b\x03\x1dk\x1dm\x05\x01\x03\x03{\x03\x03\xb7\x15\x03\x01\x01\x01\x03\x07{\xbb\xbd\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f1\x01\x07\x01\x1f!!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f;\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\x11\x11\x15)\x01\x15)\x01\x1b\x1d\x01\x0b)\x03\x11\x0f)\x05\x11\x11\x1b\x03\x0f)\x05\x11\x11\x0f\x13\x1b)\x05\x11\x11\r)\x01\x0f)\x03\t\x0b\x11\x01\x05\x05\x11\x11\x03\x05\x03\x05)\x03\x01\x0b)\x03A\x15!)\x03\t\x19)\x03\x05\x19)\x03\x01\x19)\x01\r)\x05\x05\x05\r)\x03\x05\r)\x03\x11\r)\x03\x05\x0b\x04\xee\x04\x05\x01Q\x03\x15\x01\x07\x04\xc6\x04\x03\x01\t\x0bP\x03\x03\x07\x04F\x03\x039m\x05B\x03\x05\x03\x1f\x05B\x03\x07\x03\x07\x05B\x03\t\x03\t\x05B\x03\x0b\x03\x07\x07B3\r\x03)\x13\x069\x03\x05\x03\t\x15F?\x0f\x03\x05\x03\x0b\x17\x06E\x03\x17\x03\r\x19\x06I\x03\x17\x03\r\x1b\x06M\x03\x17\x03\x11\x1d\x06Q\x03\x05\x05\x0f\x13\r\x06U\x03\x05\x05\x0b\x15\x03F\x13\x11\x03\x05\x03\x07\x1f\x06\x13\x03\x05\x05\x17\x19!F\t\x13\x03\x05\x03\x1b#G\x01_\x15\x07\x05\x11\t\x03\x1d\x03F\x01\x11\x03\t\x03\x05\x0fF\x01\x17\x033\x05#%\x03F\x01\x11\x035\x03'\x03F\x01\x11\x03\x05\x03\x03\x03F\x01\x19\x03\x1d\x03)\t\x06\x01\x03\x05\x07-\x1f+\x03F\x01\x11\x037\x03'\x03F\x01\x11\x03\x11\x03\x01\x03F\x01\x1b\x039\x031\t\x06\x01\x03\x11\x075!3\x11\x04\x03\x05/7\x0bP\t\x1d\x07\x04\x9d\x03\x15+\x03\x0b\t\x00\x05B\x03\x1f\x03\x07\x05B\x03\t\x03\t\x07B\x0f\r\x03\x13\x03F\x11\x11\x03\x13\x03\x05\r\x06\x11\x03\x13\x05\x07\t\x07B\x0f!\x03\x13\x0fF'#\x03\x1d\x05\x0b\r\x03F+\x11\x03\x05\x03\x03\t\x06/\x03\x05\x07\x0f\x01\x11\x11\x04\t\x03\x13\x06\x03\x01\x05\x01\x00r\x10o'\x03\r\x15\x11\x0f\x0b\t\t\x0b!\x11#;)99A9;;EA;WgKMO;\x1b%)9i\x1f\x11\x15\x17\x15\x11\x11\x1b\x17\x15\x17\x0f\x11\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit\x00jit()/jit(main)/jit(tril)/iota\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota\x00jit()/jit(main)/reshape\x00jit()/jit(main)/transpose\x00jit()/jit(main)/real\x00jit()/jit(main)/imag\x00jit()/jit(main)/neg\x00jit()/jit(main)/complex\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00mhlo.backend_config\x00jit()/jit(main)/eigh\x00mhlo.layout_mode\x00default\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00algorithm\x00lower\x00\x00cusolver_syevd_ffi\x00\x08o%\x05?\x01\x0bm}\x7f\x89\x8b\x03\x99\x03\x9b\x03u\x03\x9d\x03w\x03\x9f\x03g\x03s\x11\xab\xad\xafm\xb1\xb3\xb5\xb9\x05y\xbf\x03\xc1\x03\xc3\x0bq\x8fqs\x91\x03\x93\x03\x95\x05y\x97", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py index 82e5d2bd479a..be5c6e01f8d8 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py @@ -15,7 +15,7 @@ # ruff: noqa import datetime -from numpy import array, float32 +from numpy import array, float32, float64, complex64, complex128 data_2023_03_18 = {} @@ -155,3 +155,271 @@ mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\x96\x02\xff=\x01\x9f\x17\x0f\x0f\x0f\x07\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0bK\x0b\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0bK\x0b\x13\x1b\x13\x13\x13\x13\x0b\x03ao/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0fO\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0bO\x1f\x0b\x0b\x0f\x17\x1b\x0f\x0f\x0f\x0f\x0b\x0b\x13\x17\x1f\x0b/\x1fo\x03=\x1b\x07\x07\x07\x0f\x17\x0f\x07\x13\x07\x13\x1b\x17\x13\x1b\x17\x17\x13\x17\x13\x13\x1b\x07\x13\x13\x13\x17\x13\x1b\x13\x02\x1a\n\x17\x1d\n\x06\x01\x1d\x8f\x01\x1dK\x01\x1dy\x01\x1f\x05!\x03\x03\x0f\xd1\x05#\x05%\x05'\x05)\x05+\x05-\x05/\x051\x03\x03!\xcd\x053\x1dS\x01\x055\x057\x03\x03\x0b\xd9\x17\x1d\x06\x06\x01\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x11\xe5\x03\x03\x11\xe7\x03\x03\x11\xe9\x03\x03\x13E\x05I\x03\x0b\x15\xa3\x17\xb7\x19\xb9\x13\xc3\x1b\xc5\x03\x0b\x15\xa9\x17\xc9\x19\xa9\x13\xab\x1b\xcb\x05K\x1dO\x01\x05M\x03\x03\x0b\xcf\x05O\x03\x03!\xd3\x1dY\x01\x05Q\x03\x05%\xad'\xd5\x1d_\x01\x05S\x03\x03\x0f\xd7\x1de\x01\x05U\x1di\x01\x05W\x1dm\x01\x05Y\x1dq+\x05[\x1du+\x05]\x03\x11-\xaf/\xdb1\xdd3\xa35\xb17\xdf9\xb3;\xe3\x05_\x03\x03\x11\xeb\x1d\x7f\x01\x05a\x03\x07\x83\xa5\x85\xa5\x87\xa5\x05c\x05e\x05g\x1d\x8b\x01\x05i\x03\x11-\xaf/\xed1\xef3\xa35\xb17\xf19\xb3;\xf3\x05k\x03\x03\x0b\xf5\x03\x05%\xad'\xf7\x03\x03\x0f\xf9\x03\x03\x0b\xfb\x03\x03\x0f\xfd\x03\x03\x9d\xab\x05m\x1f/1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1do\x03\x03\xc7\x1dq\t\x07\x0b\x05\x05\x01\x03\x03\xe1\x1f1!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x05\xbb\xbf\r\x03\xa7\xbd\x1ds\r\x03\xa7\xc1\x1du\x1dw\x1dy\r\x01#!\x1d{\x13\x05\x01\x1f\r\t\xff\xff\xff\xff\x1f#\x01\x13\x05\x05\x07\x05\x1f'!\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\t\t\x00\x00\x00\x00\x1d}\x1d\x7f\x03\x03\x9f\x15\x03\x01\x01\x01\x03\t\x9f\xb5\xa1\xa1\x13\x03\x01\x13\x03\x05\x13\x03\t\x13\x03\r\x1d\x81\x1d\x83\x03\x05\x9f\xb5\x03\x07\x9f\xa1\xa1\x1f\r\t\x00\x00\x00\x00\x07\x01\x1f;\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\t\x00\x00\xc0\x7f\x1f\x1b1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00)\x07\t\r\r\x07\x1b\x1d\t)\x01\x07)\x05\r\r\x03)\x01\x03\x01)\x03A-\x13)\x03\t\x03)\x07\t\r\r\x0f)\x05\t\r\x07)\x03\r\x05)\x03\x04\x12\x08\x07\x11\x01\x05\x01\x01\x11\x03\x01\x03\x01)\x03\x01\x05)\x05\r\r\x0f)\x03\t\x05)\x03I\x07/\t\x01\x19\x11\x11\x17)\x03\r\x13)\x03\t\x13)\x03\x05\x13/\x07\x01\x15\x1d)\x03\t\x0f)\x07\t\x05\x05\x0f)\x03\x05\x05\x04r\x04\x05\x01\x11\tC\x07\x03\x01\t\x0b\x11\tG\x05\x03-]\t\x03o\x1f\x03)\x17\x06s\x03\x01\x03\x01\x13\x07\x07w\x03+\x03\x03\x05\x07\x07=\x03\x01\x03\x05\x05\x07\x07?\x03\x19\x03\x05\x05\x07\x07A\x03\x11\x03\x05\x05\x07\x07{\x03\x11\x03\x05\x07\x03})\x03\t\x19\x07\x89\x81\x03\x01\x05\x07\x0f\x13\x07\x03\x8d\x035\x05\x11\t\x05\x07\x03=\x03\x01\x03\x13\x05\x07\x03?\x03\x15\x03\x13\x05\x07\x03A\x03\x1d\x03\x13\x07\x03\x03\x91\x03\r\x03\x07\x03\r\x03\x15\x03\x1b\r\x07\x03\x93\x037\x05\x17\x1d\x03\x07\x03\x95\x039\x03\x1f\x07\x03\x03\x97\x03\t\x03\x07\x03\r\x03\x01\x03#\x03\x07\x03\x99\x03\x17\x03!\x0f\x06\x03\x03\x01\x07'\x15%\x1b\x07\x05\x9b\x03\x01\x03\x07\x11\x04\t\x05)+\x0b\x11\x05I\x05\x03\x17/\x03\x01\t\t\x03M\x1f\x03\x0b\x07\x03\x05Q\x03\r\x03\x07#\r\x03\x0b\x03\x05\x15\x06#\x03\x0b\x05\x03\x07\t\x03WU\x03\x0b\r\x07][\x03%\x05\t\x0b\x03\x07ca\x03\x17\x03\r\x07\x03\x05)\x03\t\x03\x07g\r\x03\x01\x03\x11\x0f\x06k\x03\x01\x07\x0f\x13\x01\x11\x04\x05\x03\x15\x06\x03\x01\x05\x01\x00Z\x1b\x85\x1f3+#\x11\x0f\x0b\t\t\x0b!\x0fY\x9d##%_=\x8b\x89W\xb9\xc1K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15\x13\r+\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11\x19+)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00get_tuple_element_v1\x00constant_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00broadcast_dimensions\x00index\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=(1, 2)]\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(18,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(2, 3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00callee\x00jax.result_info\x00triu\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x00cublas_geqrf_batched\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00 \x81\x00\x00\x00cusolver_orgqr\x00", xla_call_module_version=4, ) # End paste + +data_2024_09_26 = {} + +data_2024_09_26["f32"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cusolver_geqrf_ffi', 'cusolver_orgqr_ffi'], + serialized_date=datetime.date(2024, 9, 26), + inputs=(), + expected_outputs=(array([[[ 0. , 0.91287094, 0.40824842], + [-0.4472136 , 0.36514843, -0.8164965 ], + [-0.8944272 , -0.18257415, 0.40824836]], + + [[-0.42426407, 0.80828977, 0.4082495 ], + [-0.5656854 , 0.11547142, -0.8164964 ], + [-0.7071068 , -0.57735085, 0.4082474 ]]], dtype=float32), array([[[-6.7082038e+00, -8.0498447e+00, -9.3914852e+00], + [ 0.0000000e+00, 1.0954450e+00, 2.1908898e+00], + [ 0.0000000e+00, 0.0000000e+00, 4.8374091e-08]], + + [[-2.1213203e+01, -2.2910259e+01, -2.4607319e+01], + [ 0.0000000e+00, 3.4641042e-01, 6.9282258e-01], + [ 0.0000000e+00, 0.0000000e+00, 1.4548683e-06]]], dtype=float32)), + mlir_module_text=r""" +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x3x3xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %c = stablehlo.constant dense<-1> : tensor loc(#loc) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc) + %0 = stablehlo.iota dim = 0 : tensor<18xf32> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<18xf32>) -> tensor<2x3x3xf32> loc(#loc5) + %2:2 = stablehlo.custom_call @cusolver_geqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>]} : (tensor<2x3x3xf32>) -> (tensor<2x3x3xf32>, tensor<2x3xf32>) loc(#loc6) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<2x3x3xf32>, tensor) -> tensor<2x3x3xf32> loc(#loc7) + %4 = stablehlo.custom_call @cusolver_orgqr_ffi(%3, %2#1) {mhlo.backend_config = {}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>]} : (tensor<2x3x3xf32>, tensor<2x3xf32>) -> tensor<2x3x3xf32> loc(#loc8) + %5 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc9) + %6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc10) + %7 = stablehlo.add %5, %6 : tensor<3x3xi32> loc(#loc10) + %8 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc9) + %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc11) + %10 = stablehlo.broadcast_in_dim %9, dims = [1, 2] : (tensor<3x3xi1>) -> tensor<2x3x3xi1> loc(#loc12) + %11 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x3x3xf32> loc(#loc12) + %12 = stablehlo.select %10, %11, %2#0 : tensor<2x3x3xi1>, tensor<2x3x3xf32> loc(#loc13) + return %4, %12 : tensor<2x3x3xf32>, tensor<2x3x3xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":390:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":390:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":391:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) +#loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) +#loc7 = loc("jit()/jit(main)/pad"(#loc3)) +#loc8 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc9 = loc("jit()/jit(main)/iota"(#loc3)) +#loc10 = loc("jit()/jit(main)/add"(#loc3)) +#loc11 = loc("jit()/jit(main)/ge"(#loc3)) +#loc12 = loc("jit()/jit(main)/broadcast_in_dim"(#loc3)) +#loc13 = loc("jit()/jit(main)/select_n"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.5.0\x00\x01)\x05\x01\x05\x19\x01\x03\x0b\x03\x17\x0f\x13\x17\x1b\x1f#'+/37\x03\xc7\x89+\x01C\x17\x07\x0b\x0f\x0b\x13\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x17\x0f\x0b\x17\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0b\x0f\x0b\x03G\x0b/\x0b\x0b\x0b\x0f\x0b\x0b\x0b\x0fo\x13\x0f\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x1f\x1f\x0b\x0b\x0f\x17O\x0b\x0f\x13\x0f\x0b\x0bO\x01\x05\x0b\x0f\x03'\x1b\x07\x07\x17\x0f\x07\x0f\x07\x07\x17\x13\x17\x13\x13\x13\x13\x17\x1b\x13\x02F\x05\x17\x05\x1e\x06\x17\x1f\x05\x1d\x11\x03\x05\x05\x1f\x03\x03)q\x1d\t\x01\x1d7\x01\x1d=\x01\x03\x07\x15\x17\x19\x07\x1b\x07\x05!\x11\x01\x00\x05#\x05%\x05'\x1d\t!\x17\x05\x1a\x065\x1d%'\x05)\x17\x05\x1a\x06\x1d\x05+\x1d-\x01\x05-\x1d1\x01\x05/\x1d5\x01\x051\x053\x1d;\x01\x055\x057\x1dA\x01\x059\x03\x01\x1f!\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d;\x1d=\x1d?\x13\x07\x01\x0b\x03\x1dA\x05\x01\x03\x03W\x1f\x1d1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x05Wy\x1f#\x01#\x17\x03\x05ae\r\x05GcIK\x1dC\r\x05GgIK\x1dE\x1dG\x1dI\x1f\r\t\xff\xff\xff\xff\x1f\x11\t\x00\x00\x00\x00\r\x01\x1dK\x03\x03w\x15\x03\x01\x01\x01\x1f\x1f!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dM\x03\x03\x7f\x15\x01\x01\x01\x13\x07\x05\t\x07\x07\x05\x1f)!\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\r\r\t\x1d\t)\x05\r\r\x0f)\x01\x0f\x1b)\x01\t\x13\x01\x11\x01\x05\x05\x05)\x03I\t)\x05\t\r\t)\x03\r\x13)\x03\t\x13)\x03\r\x07)\x03\x01\x07)\x05\r\r\x15)\x07\t\r\r\x15)\x03\t\x07\x04F\x02\x05\x01Q\x03\x13\x01\x07\x04\x1e\x02\x03\x01\x05\x0bP\x03\x03\x07\x04\xfb\x03!A\x07B\x03\x05\x03\r\x07B\x03\x07\x03\x11\x03B\x1f\t\x03\x19\r\x06#\x03\x05\x03\x05\tG+\x0b\x0b\x05\x05\x1b\x03\x07\x0fF/\r\x03\x05\x05\t\x03\tG3\x0b\x0f\x03\x05\x05\r\x0b\x03B\r\t\x03\x0b\x05F\x0f\x11\x03\x0b\x03\x01\x11\x06\x0f\x03\x0b\x05\x11\x13\x03B\r\x13\x03\x0b\x13F9\x15\x03%\x05\x15\x17\x05F\x11\x17\x03'\x03\x19\x05F\x11\x11\x03\x05\x03\x03\x15\x06?\x03\x05\x07\x1b\x1d\t\x17\x04\x03\x05\x0f\x1f\x06\x03\x01\x05\x01\x00J\x0bO''\x0f\x0b\t\t\x03\x11#!CS79Y9=)A\x1b%)9;i\x15\x15\x17\x0f\x0f\x17\x11\x1f\x19)\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00iota_v1\x00broadcast_in_dim_v1\x00constant_v1\x00custom_call_v1\x00func_v1\x00reshape_v1\x00pad_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit()/jit(main)/iota\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/pad\x00jit()/jit(main)/householder_product\x00jit()/jit(main)/add\x00jit()/jit(main)/ge\x00jit()/jit(main)/broadcast_in_dim\x00jit()/jit(main)/select_n\x00jax.result_info\x00mhlo.layout_mode\x00default\x00\x00[0]\x00[1]\x00main\x00public\x00cusolver_geqrf_ffi\x00cusolver_orgqr_ffi\x00\x08_\x19\x05;\x01\x0bC]_ik\x03m\x03o\x03M\x11OQsCSUuY\x07EEE\x11OQ{CSY}U\x03[\x03\x81\x05\x83\x85\x03\x87", + xla_call_module_version=9, + nr_devices=1, +) # End paste + +data_2024_09_26["f64"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cusolver_geqrf_ffi', 'cusolver_orgqr_ffi'], + serialized_date=datetime.date(2024, 9, 26), + inputs=(), + expected_outputs=(array([[[ 0. , 0.9128709291752773 , + 0.408248290463862 ], + [-0.447213595499958 , 0.36514837167011005, + -0.8164965809277264 ], + [-0.894427190999916 , -0.1825741858350547 , + 0.40824829046386335]], + + [[-0.42426406871192857, 0.8082903768654768 , + 0.40824829046386124], + [-0.565685424949238 , 0.11547005383792364, + -0.8164965809277263 ], + [-0.7071067811865476 , -0.577350269189625 , + 0.4082482904638641 ]]]), array([[[-6.7082039324993694e+00, -8.0498447189992444e+00, + -9.3914855054991175e+00], + [ 0.0000000000000000e+00, 1.0954451150103344e+00, + 2.1908902300206661e+00], + [ 0.0000000000000000e+00, 0.0000000000000000e+00, + -1.7577018578317312e-15]], + + [[-2.1213203435596427e+01, -2.2910259710444144e+01, + -2.4607315985291855e+01], + [ 0.0000000000000000e+00, 3.4641016151377924e-01, + 6.9282032302755281e-01], + [ 0.0000000000000000e+00, 0.0000000000000000e+00, + -1.8103038069914667e-15]]])), + mlir_module_text=r""" +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x3x3xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3x3xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %c = stablehlo.constant dense<-1> : tensor loc(#loc) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc) + %0 = stablehlo.iota dim = 0 : tensor<18xf64> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<18xf64>) -> tensor<2x3x3xf64> loc(#loc5) + %2:2 = stablehlo.custom_call @cusolver_geqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>]} : (tensor<2x3x3xf64>) -> (tensor<2x3x3xf64>, tensor<2x3xf64>) loc(#loc6) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<2x3x3xf64>, tensor) -> tensor<2x3x3xf64> loc(#loc7) + %4 = stablehlo.custom_call @cusolver_orgqr_ffi(%3, %2#1) {mhlo.backend_config = {}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>]} : (tensor<2x3x3xf64>, tensor<2x3xf64>) -> tensor<2x3x3xf64> loc(#loc8) + %5 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc9) + %6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc10) + %7 = stablehlo.add %5, %6 : tensor<3x3xi32> loc(#loc10) + %8 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc9) + %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc11) + %10 = stablehlo.broadcast_in_dim %9, dims = [1, 2] : (tensor<3x3xi1>) -> tensor<2x3x3xi1> loc(#loc12) + %11 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x3x3xf64> loc(#loc12) + %12 = stablehlo.select %10, %11, %2#0 : tensor<2x3x3xi1>, tensor<2x3x3xf64> loc(#loc13) + return %4, %12 : tensor<2x3x3xf64>, tensor<2x3x3xf64> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":390:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":390:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":391:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) +#loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) +#loc7 = loc("jit()/jit(main)/pad"(#loc3)) +#loc8 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc9 = loc("jit()/jit(main)/iota"(#loc3)) +#loc10 = loc("jit()/jit(main)/add"(#loc3)) +#loc11 = loc("jit()/jit(main)/ge"(#loc3)) +#loc12 = loc("jit()/jit(main)/broadcast_in_dim"(#loc3)) +#loc13 = loc("jit()/jit(main)/select_n"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.5.0\x00\x01)\x05\x01\x05\x19\x01\x03\x0b\x03\x17\x0f\x13\x17\x1b\x1f#'+/37\x03\xc7\x89+\x01C\x17\x07\x0b\x0f\x0b\x13\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x17\x0f\x0b\x17\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0b\x0f\x0b\x03G\x0b/\x0b\x0b\x0b\x0f\x0b\x0b\x0b\x0fo\x13\x0f\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x1f/\x0b\x0b\x0f\x17O\x0b\x0f\x13\x0f\x0b\x0bO\x01\x05\x0b\x0f\x03'\x1b\x07\x07\x17\x0f\x07\x0f\x07\x07\x17\x13\x17\x13\x13\x13\x13\x17\x1b\x13\x02V\x05\x17\x05\x1e\x06\x17\x1f\x05\x1d\x11\x03\x05\x05\x1f\x03\x03)q\x1d\t\x01\x1d7\x01\x1d=\x01\x03\x07\x15\x17\x19\x07\x1b\x07\x05!\x11\x01\x00\x05#\x05%\x05'\x1d\t!\x17\x05\x1a\x065\x1d%'\x05)\x17\x05\x1a\x06\x1d\x05+\x1d-\x01\x05-\x1d1\x01\x05/\x1d5\x01\x051\x053\x1d;\x01\x055\x057\x1dA\x01\x059\x03\x01\x1f!\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d;\x1d=\x1d?\x13\x07\x01\x0b\x03\x1dA\x05\x01\x03\x03W\x1f\x1d1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x05Wy\x1f#\x01#\x17\x03\x05ae\r\x05GcIK\x1dC\r\x05GgIK\x1dE\x1dG\x1dI\x1f\r\t\xff\xff\xff\xff\x1f\x11\x11\x00\x00\x00\x00\x00\x00\x00\x00\r\x01\x1dK\x03\x03w\x15\x03\x01\x01\x01\x1f\x1f!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dM\x03\x03\x7f\x15\x01\x01\x01\x13\x07\x05\t\x07\x07\x05\x1f)!\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\r\r\t\x1d\x0b)\x05\r\r\x0f)\x01\x0f\x1b)\x01\t\x13\x01\x11\x01\x05\x05\x05)\x03I\t)\x05\t\r\t)\x03\r\x13)\x03\t\x13)\x03\r\x07)\x03\x01\x07)\x05\r\r\x15)\x07\t\r\r\x15)\x03\t\x07\x04F\x02\x05\x01Q\x03\x13\x01\x07\x04\x1e\x02\x03\x01\x05\x0bP\x03\x03\x07\x04\xfb\x03!A\x07B\x03\x05\x03\r\x07B\x03\x07\x03\x11\x03B\x1f\t\x03\x19\r\x06#\x03\x05\x03\x05\tG+\x0b\x0b\x05\x05\x1b\x03\x07\x0fF/\r\x03\x05\x05\t\x03\tG3\x0b\x0f\x03\x05\x05\r\x0b\x03B\r\t\x03\x0b\x05F\x0f\x11\x03\x0b\x03\x01\x11\x06\x0f\x03\x0b\x05\x11\x13\x03B\r\x13\x03\x0b\x13F9\x15\x03%\x05\x15\x17\x05F\x11\x17\x03'\x03\x19\x05F\x11\x11\x03\x05\x03\x03\x15\x06?\x03\x05\x07\x1b\x1d\t\x17\x04\x03\x05\x0f\x1f\x06\x03\x01\x05\x01\x00J\x0bO''\x0f\x0b\t\t\x03\x11#!CS79Y9=)A\x1b%)9;i\x15\x15\x17\x0f\x0f\x17\x11\x1f\x19)\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00iota_v1\x00broadcast_in_dim_v1\x00constant_v1\x00custom_call_v1\x00func_v1\x00reshape_v1\x00pad_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit()/jit(main)/iota\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/pad\x00jit()/jit(main)/householder_product\x00jit()/jit(main)/add\x00jit()/jit(main)/ge\x00jit()/jit(main)/broadcast_in_dim\x00jit()/jit(main)/select_n\x00jax.result_info\x00mhlo.layout_mode\x00default\x00\x00[0]\x00[1]\x00main\x00public\x00cusolver_geqrf_ffi\x00cusolver_orgqr_ffi\x00\x08_\x19\x05;\x01\x0bC]_ik\x03m\x03o\x03M\x11OQsCSUuY\x07EEE\x11OQ{CSY}U\x03[\x03\x81\x05\x83\x85\x03\x87", + xla_call_module_version=9, + nr_devices=1, +) # End paste + +data_2024_09_26["c64"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cusolver_geqrf_ffi', 'cusolver_orgqr_ffi'], + serialized_date=datetime.date(2024, 9, 26), + inputs=(), + expected_outputs=(array([[[ 0. +0.j, 0.91287094+0.j, 0.40824836+0.j], + [-0.4472136 +0.j, 0.36514843+0.j, -0.81649655+0.j], + [-0.8944272 +0.j, -0.18257417+0.j, 0.4082483 +0.j]], + + [[-0.42426407+0.j, 0.8082899 +0.j, 0.40824962+0.j], + [-0.5656854 +0.j, 0.11547136+0.j, -0.8164964 +0.j], + [-0.7071068 +0.j, -0.57735085+0.j, 0.4082474 +0.j]]], + dtype=complex64), array([[[-6.7082038e+00+0.j, -8.0498447e+00+0.j, -9.3914852e+00+0.j], + [ 0.0000000e+00+0.j, 1.0954450e+00+0.j, 2.1908898e+00+0.j], + [ 0.0000000e+00+0.j, 0.0000000e+00+0.j, 4.8374091e-08+0.j]], + + [[-2.1213203e+01+0.j, -2.2910259e+01+0.j, -2.4607319e+01+0.j], + [ 0.0000000e+00+0.j, 3.4641042e-01+0.j, 6.9282258e-01+0.j], + [ 0.0000000e+00+0.j, 0.0000000e+00+0.j, 1.5032538e-06+0.j]]], + dtype=complex64)), + mlir_module_text=r""" +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x3x3xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %c = stablehlo.constant dense<-1> : tensor loc(#loc) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc) + %0 = stablehlo.iota dim = 0 : tensor<18xcomplex> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<18xcomplex>) -> tensor<2x3x3xcomplex> loc(#loc5) + %2:2 = stablehlo.custom_call @cusolver_geqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>]} : (tensor<2x3x3xcomplex>) -> (tensor<2x3x3xcomplex>, tensor<2x3xcomplex>) loc(#loc6) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<2x3x3xcomplex>, tensor>) -> tensor<2x3x3xcomplex> loc(#loc7) + %4 = stablehlo.custom_call @cusolver_orgqr_ffi(%3, %2#1) {mhlo.backend_config = {}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>]} : (tensor<2x3x3xcomplex>, tensor<2x3xcomplex>) -> tensor<2x3x3xcomplex> loc(#loc8) + %5 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc9) + %6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc10) + %7 = stablehlo.add %5, %6 : tensor<3x3xi32> loc(#loc10) + %8 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc9) + %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc11) + %10 = stablehlo.broadcast_in_dim %9, dims = [1, 2] : (tensor<3x3xi1>) -> tensor<2x3x3xi1> loc(#loc12) + %11 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<2x3x3xcomplex> loc(#loc12) + %12 = stablehlo.select %10, %11, %2#0 : tensor<2x3x3xi1>, tensor<2x3x3xcomplex> loc(#loc13) + return %4, %12 : tensor<2x3x3xcomplex>, tensor<2x3x3xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":390:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":390:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":391:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) +#loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) +#loc7 = loc("jit()/jit(main)/pad"(#loc3)) +#loc8 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc9 = loc("jit()/jit(main)/iota"(#loc3)) +#loc10 = loc("jit()/jit(main)/add"(#loc3)) +#loc11 = loc("jit()/jit(main)/ge"(#loc3)) +#loc12 = loc("jit()/jit(main)/broadcast_in_dim"(#loc3)) +#loc13 = loc("jit()/jit(main)/select_n"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.5.0\x00\x01)\x05\x01\x05\x19\x01\x03\x0b\x03\x17\x0f\x13\x17\x1b\x1f#'+/37\x03\xc9\x89-\x01C\x17\x07\x0b\x0f\x0b\x13\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x17\x0f\x0b\x17\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0b\x0f\x0b\x03G\x0b/\x0b\x0b\x0b\x0f\x0b\x0b\x0b\x0fo\x13\x0f\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x1f/\x0b\x0b\x0f\x17O\x0b\x0f\x13\x0f\x0b\x0bO\x01\x05\x0b\x0f\x03)\x1b\x07\x0b\x17\x0f\x07\x0f\x07\x07\x17\x07\x13\x17\x13\x13\x13\x13\x17\x1b\x13\x02^\x05\x17\x05\x1e\x06\x17\x1f\x05\x1d\x11\x03\x05\x05\x1f\x03\x03)q\x1d\t\x01\x1d7\x01\x1d=\x01\x03\x07\x15\x17\x19\x07\x1b\x07\x05!\x11\x01\x00\x05#\x05%\x05'\x1d\t!\x17\x05\x1a\x065\x1d%'\x05)\x17\x05\x1a\x06\x1d\x05+\x1d-\x01\x05-\x1d1\x01\x05/\x1d5\x01\x051\x053\x1d;\x01\x055\x057\x1dA\x01\x059\x03\x01\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d;\x1d=\x1d?\x13\x07\x01\x0b\x03\x1dA\x05\x01\x03\x03W\x1f\x1f1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x05Wy\x1f%\x01#\x17\x03\x05ae\r\x05GcIK\x1dC\r\x05GgIK\x1dE\x1dG\x1dI\x1f\r\t\xff\xff\xff\xff\x1f\x11\x11\x00\x00\x00\x00\x00\x00\x00\x00\r\x01\x1dK\x03\x03w\x15\x03\x01\x01\x01\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dM\x03\x03\x7f\x15\x01\x01\x01\x13\x07\x05\t\x07\x07\x05\x1f+!\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\r\r\t\x1d\x03\x19)\x05\r\r\x0f)\x01\x0f\x1b)\x01\t\x13\x01\x11\x01\x05\x05\x05\t)\x03I\t)\x05\t\r\t)\x03\r\x13)\x03\t\x13)\x03\r\x07)\x03\x01\x07)\x05\r\r\x15)\x07\t\r\r\x15)\x03\t\x07\x04F\x02\x05\x01Q\x03\x13\x01\x07\x04\x1e\x02\x03\x01\x05\x0bP\x03\x03\x07\x04\xfb\x03!A\x07B\x03\x05\x03\r\x07B\x03\x07\x03\x11\x03B\x1f\t\x03\x1b\r\x06#\x03\x05\x03\x05\tG+\x0b\x0b\x05\x05\x1d\x03\x07\x0fF/\r\x03\x05\x05\t\x03\tG3\x0b\x0f\x03\x05\x05\r\x0b\x03B\r\t\x03\x0b\x05F\x0f\x11\x03\x0b\x03\x01\x11\x06\x0f\x03\x0b\x05\x11\x13\x03B\r\x13\x03\x0b\x13F9\x15\x03'\x05\x15\x17\x05F\x11\x17\x03)\x03\x19\x05F\x11\x11\x03\x05\x03\x03\x15\x06?\x03\x05\x07\x1b\x1d\t\x17\x04\x03\x05\x0f\x1f\x06\x03\x01\x05\x01\x00J\x0bO''\x0f\x0b\t\t\x03\x11#!CS79Y9=)A\x1b%)9;i\x15\x15\x17\x0f\x0f\x17\x11\x1f\x19)\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00iota_v1\x00broadcast_in_dim_v1\x00constant_v1\x00custom_call_v1\x00func_v1\x00reshape_v1\x00pad_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit()/jit(main)/iota\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/pad\x00jit()/jit(main)/householder_product\x00jit()/jit(main)/add\x00jit()/jit(main)/ge\x00jit()/jit(main)/broadcast_in_dim\x00jit()/jit(main)/select_n\x00jax.result_info\x00mhlo.layout_mode\x00default\x00\x00[0]\x00[1]\x00main\x00public\x00cusolver_geqrf_ffi\x00cusolver_orgqr_ffi\x00\x08_\x19\x05;\x01\x0bC]_ik\x03m\x03o\x03M\x11OQsCSUuY\x07EEE\x11OQ{CSY}U\x03[\x03\x81\x05\x83\x85\x03\x87", + xla_call_module_version=9, + nr_devices=1, +) # End paste + +data_2024_09_26["c128"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cusolver_geqrf_ffi', 'cusolver_orgqr_ffi'], + serialized_date=datetime.date(2024, 9, 26), + inputs=(), + expected_outputs=(array([[[ 0. +0.j, 0.9128709291752773 +0.j, + 0.4082482904638621 +0.j], + [-0.447213595499958 +0.j, 0.36514837167011005+0.j, + -0.8164965809277263 +0.j], + [-0.894427190999916 +0.j, -0.18257418583505472+0.j, + 0.40824829046386335+0.j]], + + [[-0.42426406871192857+0.j, 0.808290376865477 +0.j, + 0.4082482904638615 +0.j], + [-0.565685424949238 +0.j, 0.11547005383792353+0.j, + -0.8164965809277263 +0.j], + [-0.7071067811865476 +0.j, -0.577350269189625 +0.j, + 0.4082482904638641 +0.j]]]), array([[[-6.7082039324993694e+00+0.j, -8.0498447189992444e+00+0.j, + -9.3914855054991175e+00+0.j], + [ 0.0000000000000000e+00+0.j, 1.0954451150103344e+00+0.j, + 2.1908902300206661e+00+0.j], + [ 0.0000000000000000e+00+0.j, 0.0000000000000000e+00+0.j, + -1.7577018578317312e-15+0.j]], + + [[-2.1213203435596427e+01+0.j, -2.2910259710444144e+01+0.j, + -2.4607315985291855e+01+0.j], + [ 0.0000000000000000e+00+0.j, 3.4641016151377924e-01+0.j, + 6.9282032302755292e-01+0.j], + [ 0.0000000000000000e+00+0.j, 0.0000000000000000e+00+0.j, + -1.7201790115224914e-15+0.j]]])), + mlir_module_text=r""" +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x3x3xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %c = stablehlo.constant dense<-1> : tensor loc(#loc) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc) + %0 = stablehlo.iota dim = 0 : tensor<18xcomplex> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<18xcomplex>) -> tensor<2x3x3xcomplex> loc(#loc5) + %2:2 = stablehlo.custom_call @cusolver_geqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>]} : (tensor<2x3x3xcomplex>) -> (tensor<2x3x3xcomplex>, tensor<2x3xcomplex>) loc(#loc6) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<2x3x3xcomplex>, tensor>) -> tensor<2x3x3xcomplex> loc(#loc7) + %4 = stablehlo.custom_call @cusolver_orgqr_ffi(%3, %2#1) {mhlo.backend_config = {}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>]} : (tensor<2x3x3xcomplex>, tensor<2x3xcomplex>) -> tensor<2x3x3xcomplex> loc(#loc8) + %5 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc9) + %6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc10) + %7 = stablehlo.add %5, %6 : tensor<3x3xi32> loc(#loc10) + %8 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc9) + %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc11) + %10 = stablehlo.broadcast_in_dim %9, dims = [1, 2] : (tensor<3x3xi1>) -> tensor<2x3x3xi1> loc(#loc12) + %11 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<2x3x3xcomplex> loc(#loc12) + %12 = stablehlo.select %10, %11, %2#0 : tensor<2x3x3xi1>, tensor<2x3x3xcomplex> loc(#loc13) + return %4, %12 : tensor<2x3x3xcomplex>, tensor<2x3x3xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":390:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":390:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":391:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) +#loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) +#loc7 = loc("jit()/jit(main)/pad"(#loc3)) +#loc8 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc9 = loc("jit()/jit(main)/iota"(#loc3)) +#loc10 = loc("jit()/jit(main)/add"(#loc3)) +#loc11 = loc("jit()/jit(main)/ge"(#loc3)) +#loc12 = loc("jit()/jit(main)/broadcast_in_dim"(#loc3)) +#loc13 = loc("jit()/jit(main)/select_n"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.5.0\x00\x01)\x05\x01\x05\x19\x01\x03\x0b\x03\x17\x0f\x13\x17\x1b\x1f#'+/37\x03\xc9\x89-\x01C\x17\x07\x0b\x0f\x0b\x13\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x17\x0f\x0b\x17\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0b\x0f\x0b\x03G\x0b/\x0b\x0b\x0b\x0f\x0b\x0b\x0b\x0fo\x13\x0f\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x1fO\x0b\x0b\x0f\x17O\x0b\x0f\x13\x0f\x0b\x0bO\x01\x05\x0b\x0f\x03)\x1b\x07\x0b\x17\x0f\x07\x0f\x07\x07\x17\x07\x13\x17\x13\x13\x13\x13\x17\x1b\x13\x02~\x05\x17\x05\x1e\x06\x17\x1f\x05\x1d\x11\x03\x05\x05\x1f\x03\x03)q\x1d\t\x01\x1d7\x01\x1d=\x01\x03\x07\x15\x17\x19\x07\x1b\x07\x05!\x11\x01\x00\x05#\x05%\x05'\x1d\t!\x17\x05\x1a\x065\x1d%'\x05)\x17\x05\x1a\x06\x1d\x05+\x1d-\x01\x05-\x1d1\x01\x05/\x1d5\x01\x051\x053\x1d;\x01\x055\x057\x1dA\x01\x059\x03\x01\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d;\x1d=\x1d?\x13\x07\x01\x0b\x03\x1dA\x05\x01\x03\x03W\x1f\x1f1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x05Wy\x1f%\x01#\x17\x03\x05ae\r\x05GcIK\x1dC\r\x05GgIK\x1dE\x1dG\x1dI\x1f\r\t\xff\xff\xff\xff\x1f\x11!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x01\x1dK\x03\x03w\x15\x03\x01\x01\x01\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dM\x03\x03\x7f\x15\x01\x01\x01\x13\x07\x05\t\x07\x07\x05\x1f+!\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\r\r\t\x1d\x03\x19)\x05\r\r\x0f)\x01\x0f\x1b)\x01\t\x13\x01\x11\x01\x05\x05\x05\x0b)\x03I\t)\x05\t\r\t)\x03\r\x13)\x03\t\x13)\x03\r\x07)\x03\x01\x07)\x05\r\r\x15)\x07\t\r\r\x15)\x03\t\x07\x04F\x02\x05\x01Q\x03\x13\x01\x07\x04\x1e\x02\x03\x01\x05\x0bP\x03\x03\x07\x04\xfb\x03!A\x07B\x03\x05\x03\r\x07B\x03\x07\x03\x11\x03B\x1f\t\x03\x1b\r\x06#\x03\x05\x03\x05\tG+\x0b\x0b\x05\x05\x1d\x03\x07\x0fF/\r\x03\x05\x05\t\x03\tG3\x0b\x0f\x03\x05\x05\r\x0b\x03B\r\t\x03\x0b\x05F\x0f\x11\x03\x0b\x03\x01\x11\x06\x0f\x03\x0b\x05\x11\x13\x03B\r\x13\x03\x0b\x13F9\x15\x03'\x05\x15\x17\x05F\x11\x17\x03)\x03\x19\x05F\x11\x11\x03\x05\x03\x03\x15\x06?\x03\x05\x07\x1b\x1d\t\x17\x04\x03\x05\x0f\x1f\x06\x03\x01\x05\x01\x00J\x0bO''\x0f\x0b\t\t\x03\x11#!CS79Y9=)A\x1b%)9;i\x15\x15\x17\x0f\x0f\x17\x11\x1f\x19)\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00iota_v1\x00broadcast_in_dim_v1\x00constant_v1\x00custom_call_v1\x00func_v1\x00reshape_v1\x00pad_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit()/jit(main)/iota\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/pad\x00jit()/jit(main)/householder_product\x00jit()/jit(main)/add\x00jit()/jit(main)/ge\x00jit()/jit(main)/broadcast_in_dim\x00jit()/jit(main)/select_n\x00jax.result_info\x00mhlo.layout_mode\x00default\x00\x00[0]\x00[1]\x00main\x00public\x00cusolver_geqrf_ffi\x00cusolver_orgqr_ffi\x00\x08_\x19\x05;\x01\x0bC]_ik\x03m\x03o\x03M\x11OQsCSUuY\x07EEE\x11OQ{CSY}U\x03[\x03\x81\x05\x83\x85\x03\x87", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_util.py b/jax/_src/internal_test_util/export_back_compat_test_util.py index 5a975e3c5a61..70826eec8806 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_util.py +++ b/jax/_src/internal_test_util/export_back_compat_test_util.py @@ -20,8 +20,7 @@ The tests in this file refer to the test data in jax/_src/internal_test_util/export_back_compat_test_data. -There is one test for each version of a custom call target, e.g., -`test_ducc_fft` tests the FFT custom calls on CPU. +There is one test for each version of a custom call target. Only custom call targets tested here should be listed in export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE. All other custom call targets will result in an error when encountered during serialization. diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index 2c94907568d9..36879ce5f9db 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -59,7 +59,6 @@ from jax._src import test_util as jtu from jax._src.lax import control_flow as lax_control_flow from jax._src.lax import windowed_reductions as lax_windowed_reductions -from jax._src.lib import xla_client from jax._src import random as jax_random # mypy generates a lot of false positive due to re-assigned variables. @@ -654,7 +653,9 @@ def _make_device_put_harness(name, define( "device_put", f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_{device=}", - lambda x: dispatch.device_put_p.bind(x, devices=[_device_fn()], srcs=[None])[0], + lambda x: dispatch.device_put_p.bind( + x, devices=[_device_fn()], srcs=[None], + copy_semantics=[dispatch.CopySemantics.ALIAS])[0], [RandArg(shape, dtype)], shape=shape, dtype=dtype, @@ -1169,6 +1170,18 @@ def _make_broadcast_in_dim_harness(name, lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3), True), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), + lax.GatherDimensionNumbers( + offset_dims=(), collapsed_slice_dims=(1,), + start_index_map=(1,), operand_batching_dims=(0,), + start_indices_batching_dims=(0,)), + (1, 1), True), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + lax.GatherDimensionNumbers( + offset_dims=(2,), collapsed_slice_dims=(), + start_index_map=(2,), operand_batching_dims=(0, 1), + start_indices_batching_dims=(1, 0)), + (1, 1, 3), True) ]: dtype = np.float32 for enable_xla in ([True] if needs_xla else [True, False]): @@ -1276,15 +1289,16 @@ def _make_scatter_harness(name, update_shape=(2,), mode=lax.GatherScatterMode.FILL_OR_DROP, dtype=np.float32, - dimension_numbers=((), (0,), (0,)), + dimension_numbers=lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(0,)), enable_and_disable_xla=False): - dimension_numbers = lax.ScatterDimensionNumbers(*dimension_numbers) xla_options = [True, False] if enable_and_disable_xla else [True] for enable_xla in xla_options: define( f_lax.__name__, - f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_scatterindices={scatter_indices.tolist()}_updateshape={update_shape}_updatewindowdims={dimension_numbers.update_window_dims}_insertedwindowdims={dimension_numbers.inserted_window_dims}_scatterdimstooperanddims={dimension_numbers.scatter_dims_to_operand_dims}_indicesaresorted={indices_are_sorted}_uniqueindices={unique_indices}_{mode=!s}_enablexla={enable_xla}" + f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_scatterindices={scatter_indices.tolist()}_updateshape={update_shape}_{dimension_numbers=}_indicesaresorted={indices_are_sorted}_uniqueindices={unique_indices}_{mode=!s}_enablexla={enable_xla}" .replace(" ", ""), partial( f_lax, @@ -1328,8 +1342,19 @@ def _make_scatter_harness(name, # Validate shapes, dimension numbers and scatter indices. All are in bounds. for shape, scatter_indices, update_shape, dimension_numbers in [ - ((10,), [[0], [0], [0]], (3, 2), ((1,), (), (0,))), - ((10, 5), [[0], [2], [1]], (3, 3), ((1,), (0,), (0,))) + ((10,), [[0], [0], [0]], (3, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(1,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(0,))), + ((10, 5), [[0], [2], [1]], (3, 3), + lax.ScatterDimensionNumbers( + update_window_dims=(1,), inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(0,))), + ((2, 3, 10), [[[0], [1]], [[2], [3]], [[4], [5]]], (3, 2, 3), + lax.ScatterDimensionNumbers( + update_window_dims=(2,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0))) ]: _make_scatter_harness( "shapes_and_dimension_numbers", @@ -1358,13 +1383,16 @@ def _make_scatter_harness(name, _make_scatter_harness("modes_in_bounds", f_lax=f_lax, mode=mode) - _make_scatter_harness("modes_out_of_bounds", mode=mode, - shape=(1, 5), - f_lax=f_lax, - scatter_indices=np.array([10]), - update_shape=(1,), - dimension_numbers=((0,), (1,), (1,)), - enable_and_disable_xla=True) + _make_scatter_harness( + "modes_out_of_bounds", + mode=mode, + shape=(1, 5), + f_lax=f_lax, + scatter_indices=np.array([10]), + update_shape=(1,), + dimension_numbers=lax.ScatterDimensionNumbers((0,), (1,), (1,)), + enable_and_disable_xla=True, + ) # Validate no XLA scatters for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.complex) - set(jtu.dtypes.boolean): @@ -1372,22 +1400,34 @@ def _make_scatter_harness(name, lax.scatter_add, lax.scatter_mul, lax.scatter_max, lax.scatter_min, lax.scatter ]: for shape, scatter_indices, update_shape, dimension_numbers in [ - ((1,), [0], (), ((), (0,), (0,))), # zero case - ((1, 1), [0], (1,), ((0,), (0,), (0,))), - ((1, 1, 1), [0], (1, 1), ((0, 1), (0,), (0,))), - ((1, 50, 3), [32], (1, 3), ((0, 1), (1,), (1,))), - ((1, 2, 3), [1], (1, 3), ((0, 1), (1,), (1,))), # slice 2nd dim - ((1, 2, 3), [0], (2, 3), ((0, 1), (0,), (0,))), # slice 1st dim - ((1, 2, 3), [1, 2], (1,), ((0,), (1, 2), (1, 2))), # 2nd and 3rd - ((4, 2, 3), [3, 2], (2,), ((0,), (0, 2), (0, 2))), # 1st and 3rd - ((4, 2, 3, 5), [0, 4], (4, 3), ((0, 1), (1, 3), (1, 3))), # 2nd and 4th + ((1,), [0], (), + lax.ScatterDimensionNumbers((), (0,), (0,))), # zero case + ((1, 1), [0], (1,), + lax.ScatterDimensionNumbers((0,), (0,), (0,))), + ((1, 1, 1), [0], (1, 1), + lax.ScatterDimensionNumbers((0, 1), (0,), (0,))), + ((1, 50, 3), [32], (1, 3), + lax.ScatterDimensionNumbers((0, 1), (1,), (1,))), + ((1, 2, 3), [1], (1, 3), + lax.ScatterDimensionNumbers((0, 1), (1,), (1,))), # slice 2nd dim + ((1, 2, 3), [0], (2, 3), + lax.ScatterDimensionNumbers((0, 1), (0,), (0,))), # slice 1st dim + ((1, 2, 3), [1, 2], (1,), + lax.ScatterDimensionNumbers((0,), (1, 2), (1, 2))), # 2nd and 3rd + ((4, 2, 3), [3, 2], (2,), + lax.ScatterDimensionNumbers((0,), (0, 2), (0, 2))), # 1st and 3rd + ((4, 2, 3, 5), [0, 4], (4, 3), + lax.ScatterDimensionNumbers((0, 1), (1, 3), (1, 3))), # 2nd and 4th ((5, 6, 7), [[0, 1], [2, 3]], (2, 7), - ((1,), (0, 1), (0, 1))), # .at[((3,4),(5,5))] shapes + lax.ScatterDimensionNumbers((1,), (0, 1), (0, 1))), + # .at[((3,4),(5,5))] shapes ((5, 6, 7), [[[0], [1]], [[2], [3]]], (5, 2, 2, 7), - ((0, 3), (1,), (1,))), # .at[:, ((3,4),(5,5))] shapes + lax.ScatterDimensionNumbers((0, 3), (1,), (1,))), + # .at[:, ((3,4),(5,5))] shapes ((5, 6, 7), [[[0, 1], [2, 3]], [[4, 0], [1, 2]]], (5, 2, 2), - ((0,), (1, 2), (1, 2))), # .at[:, ((3,4),(5,5)), 3] shapes - ((1, 125), [0], (1,), ((0,), (1,), (1,))), + lax.ScatterDimensionNumbers((0,), (1, 2), (1, 2))), + # .at[:, ((3,4),(5,5)), 3] shapes + ((1, 125), [0], (1,), lax.ScatterDimensionNumbers((0,), (1,), (1,))), ]: for mode in (lax.GatherScatterMode.PROMISE_IN_BOUNDS, lax.GatherScatterMode.FILL_OR_DROP): @@ -1410,11 +1450,16 @@ def _make_scatter_harness(name, lax.scatter_add, lax.scatter_mul, lax.scatter_max, lax.scatter_min ]: for shape, scatter_indices, update_shape, dimension_numbers in [ - ((1,), [[0],[0]], (2,), ((), (0,), (0,))), # .at[((0,0),)] - ((3,), [[1],[0],[1]], (3,), ((), (0,), (0,))), # .at[((1,0,1),)] - ((2, 3), [[[2],[2],[2]]], (2, 1, 3), ((0,), (1,), (1,))), # 2nd dim, .at[:, ((2,2,2),)] - ((3, 5, 40), [[1],[1]], (3, 5, 2), ((0, 1), (2,), (2,))), - ((3, 5, 4), [[1],[1]], (3, 2, 4), ((0, 2), (1,), (1,))), + ((1,), [[0],[0]], (2,), + lax.ScatterDimensionNumbers((), (0,), (0,))), # .at[((0,0),)] + ((3,), [[1],[0],[1]], (3,), + lax.ScatterDimensionNumbers((), (0,), (0,))), # .at[((1,0,1),)] + ((2, 3), [[[2],[2],[2]]], (2, 1, 3), + lax.ScatterDimensionNumbers((0,), (1,), (1,))), # 2nd dim, .at[:, ((2,2,2),)] + ((3, 5, 40), [[1],[1]], (3, 5, 2), + lax.ScatterDimensionNumbers((0, 1), (2,), (2,))), + ((3, 5, 4), [[1],[1]], (3, 2, 4), + lax.ScatterDimensionNumbers((0, 2), (1,), (1,))), ]: for mode in (lax.GatherScatterMode.PROMISE_IN_BOUNDS, lax.GatherScatterMode.FILL_OR_DROP): @@ -1731,7 +1776,7 @@ def _make_fft_harness(name, *, shape=(14, 15, 16, 17), dtype=np.float32, - fft_type=xla_client.FftType.FFT, + fft_type=lax.FftType.FFT, fft_lengths=(17,)): def _fft_rng_factory(dtype): @@ -1759,12 +1804,12 @@ def _fft_rng_factory(dtype): # FFT, IFFT, RFFT, IRFFT -for fft_type in list(map(xla_client.FftType, [0, 1, 2, 3])): +for fft_type in list(map(lax.FftType, [0, 1, 2, 3])): # Validate dtypes per FFT type for dtype in (jtu.dtypes.floating - if fft_type == xla_client.FftType.RFFT else jtu.dtypes.complex): + if fft_type == lax.FftType.RFFT else jtu.dtypes.complex): shape = (14, 15, 16, 17) - if fft_type != xla_client.FftType.IRFFT: + if fft_type != lax.FftType.IRFFT: fft_lengths_list = [ (shape[-1],) ] else: fft_lengths_list = [ ((shape[-1] - 1) * 2,), (shape[-1] * 2 - 1,) ] @@ -1785,11 +1830,11 @@ def _fft_rng_factory(dtype): # Validate dimensions per FFT type for dtype in [ - np.float32 if fft_type == xla_client.FftType.RFFT else np.complex64 + np.float32 if fft_type == lax.FftType.RFFT else np.complex64 ]: for dims in [1, 2, 3]: for fft_lengths in [ - shape[-dims:] if fft_type != xla_client.FftType.IRFFT else + shape[-dims:] if fft_type != lax.FftType.IRFFT else shape[-dims:-1] + ((shape[-1] - 1) * 2,) ]: _make_fft_harness( diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index af773365b12d..b6e9b2f4ef07 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -277,7 +277,12 @@ def ir_constant(val: Any) -> IrValues: raise TypeError(f"No constant handler for type: {type(val)}") def _numpy_array_constant(x: np.ndarray | np.generic) -> IrValues: - attr = _numpy_array_attribute(x) + element_type = dtype_to_ir_type(x.dtype) + shape = x.shape + if x.dtype == np.bool_: + x = np.packbits(x, bitorder='little') # type: ignore + x = np.ascontiguousarray(x) + attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore return hlo.constant(attr) @@ -359,13 +364,26 @@ def _numpy_scalar_attribute(val: Any) -> ir.Attribute: else: raise TypeError(f"Unsupported scalar attribute type: {type(val)}") +_dtype_to_array_attr: dict[Any, AttributeHandler] = { + np.dtype(np.bool_): ir.DenseBoolArrayAttr.get, + np.dtype(np.float32): ir.DenseF32ArrayAttr.get, + np.dtype(np.float64): ir.DenseF64ArrayAttr.get, + np.dtype(np.int32): ir.DenseI32ArrayAttr.get, + np.dtype(np.int64): ir.DenseI64ArrayAttr.get, + np.dtype(np.int8): ir.DenseI8ArrayAttr.get, +} + def _numpy_array_attribute(x: np.ndarray | np.generic) -> ir.Attribute: - element_type = dtype_to_ir_type(x.dtype) shape = x.shape if x.dtype == np.bool_: x = np.packbits(x, bitorder='little') # type: ignore x = np.ascontiguousarray(x) - return ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore + builder = _dtype_to_array_attr.get(x.dtype, None) + if builder: + return builder(x) + else: + element_type = dtype_to_ir_type(x.dtype) + return ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore def _numpy_array_attribute_handler(val: np.ndarray | np.generic) -> ir.Attribute: if 0 in val.strides and val.size > 0: @@ -407,6 +425,8 @@ def _sequence_attribute_handler(val: Sequence[Any]) -> ir.Attribute: register_attribute_handler(list, _sequence_attribute_handler) register_attribute_handler(tuple, _sequence_attribute_handler) +register_attribute_handler(ir.Attribute, lambda x: x) +register_attribute_handler(ir.Type, lambda x: x) def ir_attribute(val: Any) -> ir.Attribute: """Convert a Python value to an MLIR attribute.""" @@ -1200,6 +1220,7 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out, xla_donated_args = None out_donated_args = list(donated_args) + in_out_layout_not_none = in_layouts is not None and out_layouts is not None for i, (aval, rm) in enumerate(zip(avals_out, result_memory_kinds)): # Only donate if memory kinds match. Relax this when the compiler can # donate across memories. @@ -1207,14 +1228,26 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out, if donations.get(key, ()): input_id = donations[key].popleft() out_donated_args[input_id] = False - # We can alias if XLA performs layout assignment because XLA will - # respect the aliases when assigning layouts. Its only for two - # mismatched explicitly assigned layouts that XLA will certainly fail. - if (in_layouts is None or - out_layouts is None or - in_layouts[input_id] == out_layouts[i] or - isinstance(in_layouts[input_id], AutoLayout) or + if (in_out_layout_not_none and + isinstance(in_layouts[input_id], AutoLayout) and + not isinstance(out_layouts[i], AutoLayout)): + raise ValueError( + f"Input layout being donated was {in_layouts[input_id]} while" + f" output layout was {out_layouts[i]}. Did you mean to set the" + " **output layout** to **DeviceLocalLayout.AUTO**?\nThis will" + " allow for the input and output layout to be chosen by XLA and" + " not the layout of the output which might not be optimal.") + if (in_out_layout_not_none and + not isinstance(in_layouts[input_id], AutoLayout) and isinstance(out_layouts[i], AutoLayout)): + raise ValueError( + f"Input layout being donated was {in_layouts[input_id]} while" + f" output layout was {out_layouts[i]}. Did you mean to set the" + " **input layout** to **DeviceLocalLayout.AUTO**?\nThis will allow" + " for the input and output layout to be chosen by XLA and not the" + " layout of the input which might not be optimal.") + if (in_layouts is None or out_layouts is None or + in_layouts[input_id] == out_layouts[i]): input_output_aliases[input_id] = i else: # Fallback to xla donation if layouts don't match. @@ -1488,7 +1521,6 @@ def lower_jaxpr_to_fun( aliases.extend([None] * len_ir_types(itypes)) else: aliases.extend(output_ids[alias]) - for attrs, alias in zip(arg_attrs, aliases): if alias is not None: attrs["tf.aliasing_output"] = i32_attr(alias) @@ -1778,13 +1810,9 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None for p in _platforms_for_eqn_ctx(eqn.ctx) or ctx.platforms: if eqn.primitive in _platform_specific_lowerings[p]: platform_rules[p] = _platform_specific_lowerings[p][eqn.primitive] - elif eqn.primitive in xla._backend_specific_translations[p]: - platform_rules[p] = xla_fallback_lowering(eqn.primitive) # Now the default rule if eqn.primitive in _lowerings: default_rule = _lowerings[eqn.primitive] - elif eqn.primitive in xla._translations: - default_rule = xla_fallback_lowering(eqn.primitive) effects = list(effects_lib.ordered_effects.filter_in(eqn.effects)) tokens_in = tokens.subset(effects) @@ -2579,48 +2607,6 @@ def merge_mlir_modules(dst_module: ir.Module, return renamings["main"] -def xla_fallback_lowering(prim: core.Primitive): - @cache_lowering - def fallback(ctx: LoweringRuleContext, *args, **params): - module_ctx = ctx.module_context - axis_ctx = module_ctx.axis_context - if isinstance(axis_ctx, sharding_impls.SPMDAxisContext): - axis_env = axis_ctx.unsafe_axis_env - else: - axis_env = module_ctx.axis_env - - if any(hasattr(a, "shape") and - not core.is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)): - raise NotImplementedError( - f"Shape polymorphism for xla_fallback_lowering is not implemented ({ctx.primitive}); b/261682623") - - if len(module_ctx.platforms) > 1: - raise NotImplementedError( - "fallback lowering not implemented for multi-platform lowering") - xla_computation = xla.primitive_subcomputation( - module_ctx.platforms[0], axis_env, prim, ctx.avals_in, - ctx.avals_out, **params) - xla_module = xla_computation_to_mlir_module(xla_computation) - callee_name = merge_mlir_modules( - module_ctx.module, f"xla_fallback_{prim.name}", xla_module, - dst_symtab=module_ctx.symbol_table) - output_types = map(aval_to_ir_type, ctx.avals_out) - flat_output_types = flatten_ir_types(output_types) - output_type = (ir.TupleType.get_tuple(flat_output_types) - if prim.multiple_results else flat_output_types[0]) - - call = func_dialect.CallOp([output_type], - ir.FlatSymbolRefAttr.get(callee_name), - flatten_ir_values(args)).result - if not prim.multiple_results: - return [call] - flat_results = [hlo.get_tuple_element(call, i32_attr(i)) - for i in range(len(flat_output_types))] - - return unflatten_ir_values_like_types(flat_results, output_types) - return fallback - - DEVICE_TO_DEVICE_TYPE = 1 SEND_TO_HOST_TYPE = 2 RECV_FROM_HOST_TYPE = 3 diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 6bc3cceb7ab7..abece46e9602 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1162,12 +1162,13 @@ def has_effects(effects) -> bool: map(partial(write, False, False), eqn.outvars) elif isinstance(policy, Offloadable): # TODO(slebedev): This is a legit error which requires a BUILD fix. - from jax._src.dispatch import device_put_p, TransferToMemoryKind # pytype: disable=import-error + from jax._src.dispatch import device_put_p, TransferToMemoryKind, CopySemantics # pytype: disable=import-error resvars = [newvar(v.aval) for v in eqn.outvars] outvars_copy = list[Atom](eqn.outvars) offload_eqn = core.JaxprEqn( outvars_copy, resvars, device_put_p, - dict(devices=[TransferToMemoryKind(policy.dst)], srcs=[None]), + dict(devices=[TransferToMemoryKind(policy.dst)], srcs=[None], + copy_semantics=[CopySemantics.COPY]), set(), source_info_util.new_source_info(), JaxprEqnContext(None, False)) known_eqns.append(offload_eqn) @@ -1176,7 +1177,8 @@ def has_effects(effects) -> bool: residuals.update(resvars) reload_eqn = core.JaxprEqn( resvars, eqn.outvars, device_put_p, - dict(devices=[TransferToMemoryKind(policy.src)], srcs=[None]), + dict(devices=[TransferToMemoryKind(policy.src)], srcs=[None], + copy_semantics=[CopySemantics.COPY]), set(), source_info_util.new_source_info(), JaxprEqnContext(None, False)) staged_eqns.append(reload_eqn) @@ -1668,10 +1670,12 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: sentinel = object() jaxpr_effects = set() all_vars = {v: i for i, v in enumerate(it.chain(constvars, invars))} + mut_arrays = set() for eqn in eqns: if eqn.primitive is core.mutable_array_p: outvar, = eqn.outvars all_vars[outvar] = None # type: ignore + mut_arrays.add(outvar) for eff in eqn.effects: if isinstance(eff, effects.JaxprInputEffect): if eff.input_index >= len(eqn.invars): @@ -1681,6 +1685,8 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: "\n Jaxpr: " f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}") invar = eqn.invars[eff.input_index] + if invar in mut_arrays: + continue if (input_index := all_vars.get(invar, sentinel)) is sentinel: raise ValueError( f"`JaxprInputEffect` {eff} does not have " diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 4c134f266da5..4aa05010cd7a 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -19,7 +19,7 @@ from contextlib import contextmanager import collections from collections import namedtuple -from collections.abc import Callable, Sequence, Iterable, Iterator +from collections.abc import Callable, Sequence, Iterable import dataclasses from functools import partial, lru_cache, cached_property import functools @@ -62,7 +62,6 @@ from jax._src.interpreters import xla from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.partition_spec import PartitionSpec @@ -1985,14 +1984,17 @@ def _create_da_object( # pytype: disable=invalid-annotation return xc.DeviceList(device_assignment) +@weakref_lru_cache def jaxpr_transfer_mem_kinds( - jaxpr: core.Jaxpr) -> Iterator[sharding_impls.TransferToMemoryKind]: + jaxpr: core.Jaxpr) -> Sequence[sharding_impls.TransferToMemoryKind]: + out = [] # type: ignore for eqn in jaxpr.eqns: if eqn.primitive is dispatch.device_put_p: - yield from (d for d in eqn.params['devices'] - if isinstance(d, sharding_impls.TransferToMemoryKind)) + out.extend(d for d in eqn.params['devices'] + if isinstance(d, sharding_impls.TransferToMemoryKind)) for subjaxpr in core.subjaxprs(jaxpr): - yield from jaxpr_transfer_mem_kinds(subjaxpr) + out.extend(jaxpr_transfer_mem_kinds(subjaxpr)) + return out def are_all_shardings_default_mem_kind(da_object, shardings): @@ -2001,7 +2003,9 @@ def are_all_shardings_default_mem_kind(da_object, shardings): except: return True for i in shardings: - if is_unspecified_or_auto(i) or i.memory_kind is None: + if is_unspecified_or_auto(i): + continue + if i.memory_kind is None: # pytype: disable=attribute-error continue if i.memory_kind != default_mem_kind: return False @@ -2173,12 +2177,14 @@ def lower_sharding_computation( else context_mesh._flat_devices_tuple) # Device assignment across all inputs, outputs and shardings inside jaxpr # should be the same. - unique_intermediate_shardings = list(util.stable_unique( - dispatch.get_intermediate_shardings(jaxpr))) + unique_intermediate_shardings = util.stable_unique( + dispatch.get_intermediate_shardings(jaxpr)) + unique_in_shardings = util.stable_unique(in_shardings) + unique_out_shardings = util.stable_unique(out_shardings) backend, device_assignment = _get_and_check_device_assignment( it.chain( - ((i, MismatchType.ARG_SHARDING, None) for i in util.stable_unique(in_shardings)), - ((o, MismatchType.OUT_SHARDING, None) for o in util.stable_unique(out_shardings)), + ((i, MismatchType.ARG_SHARDING, None) for i in unique_in_shardings), + ((o, MismatchType.OUT_SHARDING, None) for o in unique_out_shardings), ((js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info) for js, source_info in unique_intermediate_shardings)), devices_from_context) @@ -2188,16 +2194,16 @@ def lower_sharding_computation( committed = bool( devices_from_context or len(device_assignment) > 1 or - any(not is_unspecified(i) for i in in_shardings) or + any(not is_unspecified(i) for i in unique_in_shardings) or any(not is_unspecified(js) for js, _ in unique_intermediate_shardings) or - any(not is_unspecified(o) for o in out_shardings)) + any(not is_unspecified(o) for o in unique_out_shardings)) da_object = _create_da_object(tuple(device_assignment)) - transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr)) + transfer_mem_kind_in_jaxpr = jaxpr_transfer_mem_kinds(jaxpr) all_default_mem_kind = are_all_shardings_default_mem_kind( da_object, - it.chain(in_shardings, out_shardings, + it.chain(unique_in_shardings, unique_out_shardings, [js for js, _ in unique_intermediate_shardings], transfer_mem_kind_in_jaxpr)) # pytype: disable=wrong-arg-types @@ -2208,16 +2214,11 @@ def lower_sharding_computation( closed_jaxpr, in_shardings) # 2. Build up the HLO - semantic_in_shardings = SemanticallyEqualShardings( - in_shardings, global_in_avals) # type: ignore - semantic_out_shardings = SemanticallyEqualShardings( - out_shardings, global_out_avals) # type: ignore - prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr) mesh_shape_tuple = None if config.use_shardy_partitioner.value or prim_requires_devices: - for sharding in it.chain(in_shardings, out_shardings, + for sharding in it.chain(unique_in_shardings, unique_out_shardings, [js for js, _ in unique_intermediate_shardings]): if isinstance(sharding, (sharding_impls.NamedSharding, sharding_impls.AUTO)): if (mesh_shape_tuple is not None and @@ -2228,6 +2229,11 @@ def lower_sharding_computation( f" {sharding.mesh.shape_tuple} for another") mesh_shape_tuple = sharding.mesh.shape_tuple + semantic_in_shardings = SemanticallyEqualShardings( + in_shardings, global_in_avals) # type: ignore + semantic_out_shardings = SemanticallyEqualShardings( + out_shardings, global_out_avals) # type: ignore + (module, keepalive, host_callbacks, unordered_effects, ordered_effects, nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo( closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings, @@ -3048,14 +3054,9 @@ def aot_cache_miss(*args, **kwargs): fastpath_data = None return outs, fastpath_data, False # Do not remove cache entry - if xla_extension_version >= 286: - return xc._xla.pjit( - self.unsafe_call.name, None, aot_cache_miss, [], [], - JitGlobalCppCacheKeys(), tree_util.dispatch_registry, cc_shard_arg) - else: - return xc._xla.pjit( - self.unsafe_call.name, None, aot_cache_miss, [], [], [], - tree_util.dispatch_registry, cc_shard_arg) + return xc._xla.pjit( + self.unsafe_call.name, None, aot_cache_miss, [], [], + JitGlobalCppCacheKeys(), tree_util.dispatch_registry, cc_shard_arg) def cc_shard_arg(x, sharding, layout): return shard_args([sharding], [layout], [x])[0] diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py index 2db877d3f970..14635a46ea33 100644 --- a/jax/_src/interpreters/xla.py +++ b/jax/_src/interpreters/xla.py @@ -16,35 +16,25 @@ from __future__ import annotations -from collections import defaultdict from collections.abc import Callable, Sequence -import dataclasses -import functools from functools import partial -import itertools as it -from typing import Any, Protocol, Union +from typing import Any, Union import numpy as np from jax._src import core from jax._src import dtypes -from jax._src import source_info_util from jax._src.abstract_arrays import numpy_scalar_types from jax._src.core import ConcreteArray, ShapedArray -from jax._src.sharding_impls import AxisEnv from jax._src.util import safe_zip, safe_map from jax._src.typing import Shape -from jax._src import xla_bridge as xb from jax._src.lib import xla_client as xc map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip -xe = xc._xla -xops = xc._xla.ops - # Types def identity(x): return x @@ -58,18 +48,6 @@ def _make_array_shape(aval: ShapedArray) -> Sequence[xc.Shape]: # Utilities -def parameter(builder, num, shape, name=None, replicated=None): - if name is None: - name = '' - if replicated is None: - replicated = [] - elif isinstance(replicated, bool): - replicated = [replicated] * shape.leaf_count() - - return xops.Parameter(builder, num, - shape.with_major_to_minor_layout_if_absent(), name, - replicated) - # HLO instructions optionally can be annotated to say how the output should be # spatially partitioned (represented in XLA as OpSharding protos, see # sharding_to_proto). For array outputs, the annotation is either an int per @@ -208,126 +186,7 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray: (t, partial(_make_abstract_python_scalar, t)) for t in _scalar_types) -def primitive_subcomputation(platform: str, axis_env: AxisEnv, - prim: core.Primitive, - avals_in: Sequence[core.AbstractValue], - avals_out: Sequence[core.AbstractValue], - **params): - c = xc.XlaBuilder(f"primitive_computation_{prim.name}") - counts = it.count() - xla_args = [parameter(c, next(counts), xla_shape) - for a in avals_in for xla_shape in aval_to_xla_shapes(a)] - if (platform is not None and - prim in _backend_specific_translations[platform]): - rule = _backend_specific_translations[platform][prim] - elif prim in _translations: - rule = _translations[prim] - - ctx = TranslationContext(builder=c, platform=platform, axis_env=axis_env, - name_stack=source_info_util.new_name_stack()) - ans = rule(ctx, avals_in, avals_out, *xla_args, **params) - - if prim.multiple_results: - return c.build(xops.Tuple(c, ans)) - else: - x, = ans - return c.build(x) - - -### compiling jaxprs - -@dataclasses.dataclass -class TranslationContext: - builder: xc.XlaBuilder - # TODO(phawkins): make platform non-optional. We should always be translating - # with a specific platform in mind. - platform: str | None - axis_env: AxisEnv - name_stack: str | source_info_util.NameStack - - def replace(self, **kw): return dataclasses.replace(self, **kw) - -def xla_destructure(c, ans): - num_elements = len(c.get_shape(ans).tuple_shapes()) - return [xops.GetTupleElement(ans, i) for i in range(num_elements)] - - -### translation tables - -MYPY = False -if not MYPY: - class TranslationRule(Protocol): - def __call__(self, ctx: TranslationContext, - avals_in: Sequence[core.AbstractValue], - avals_out: Sequence[core.AbstractValue], - *args: xc.XlaOp, **kw - ) -> Sequence[xc.XlaOp]: - """A translation rule lowers a primitive invocation into an XLA HLO.""" -else: - TranslationRule = Any - -_translations: dict[core.Primitive, TranslationRule] = {} -_backend_specific_translations: dict[str, dict[core.Primitive, TranslationRule]] -_backend_specific_translations = defaultdict(dict) - initial_style_primitives: set[core.Primitive] = set() def register_initial_style_primitive(prim: core.Primitive): initial_style_primitives.add(prim) - -def register_translation(prim: core.Primitive, rule: TranslationRule, *, - platform: str | None = None) -> None: - if platform is None: - _translations[prim] = rule - else: - # For backward compatibility reasons, we allow rules to be registered - # under "gpu" even though the platforms are now called "cuda" and "rocm". - # TODO(phawkins): fix up users to specify either "cuda" or "rocm" and remove - # this expansion. - for p in xb.expand_platform_alias(platform): - _backend_specific_translations[p][prim] = rule - - -# As a temporary backward compatibility measure, we use an adapter class to -# convert from the old styles of translation rules to the newer ones. -# TODO(phawkins): update users of the older translation rule styles and remove -# the adapters. -class _TranslationRuleAdapter: - def __init__(self, translations, - wrap_fn: Callable[[core.Primitive, Callable], TranslationRule]): - self._translations = translations - self._wrap_fn = wrap_fn - - def __setitem__(self, key: core.Primitive, value: Callable): - wrapped = self._wrap_fn(key, value) - for translations in self._translations: - translations[key] = wrapped - - -def _wrap_old_translation(prim: core.Primitive, f: Callable) -> TranslationRule: - @functools.wraps(f) - def wrapped(ctx: TranslationContext, avals_in: Sequence[core.AbstractValue], - avals_out: Sequence[core.AbstractValue], - *args: xc.XlaOp, **kw) -> Sequence[xc.XlaOp]: - ans = f(ctx.builder, *args, **kw) - if (prim.multiple_results or - any(len(aval_to_xla_shapes(aval)) > 1 for aval in avals_out)): - return xla_destructure(ctx.builder, ans) - else: - return [ans] - return wrapped - - -translations : _TranslationRuleAdapter -translations = _TranslationRuleAdapter([_translations], _wrap_old_translation) - -class _BackendSpecificTranslationsAdapter(defaultdict): - def __missing__(self, key): - translation_tables = [_backend_specific_translations[p] - for p in xb.expand_platform_alias(key)] - ret = self[key] = _TranslationRuleAdapter( - translation_tables, _wrap_old_translation) - return ret - -backend_specific_translations: dict[str, _TranslationRuleAdapter] -backend_specific_translations = _BackendSpecificTranslationsAdapter() diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index d3065d0f96d7..e654ce953b51 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -33,7 +33,7 @@ from jax._src import linear_util as lu from jax._src import source_info_util from jax._src import util -from jax._src.state.discharge import register_discharge_rule, discharge_state +from jax._src.state.discharge import register_partial_discharge_rule, discharge_state from jax._src.state.types import AbstractRef, RefEffect from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects from jax._src.interpreters import ad @@ -854,19 +854,48 @@ def _cond_lowering(ctx, index, *args, branches): mlir.register_lowering(cond_p, _cond_lowering) -@register_discharge_rule(cond_p) -def _cond_state_discharge_rule(in_avals, out_avals, *args, branches): +@register_partial_discharge_rule(cond_p) +def _cond_state_discharge_rule(should_discharge, in_avals, out_avals, index, *args, branches): + assert not should_discharge[0], "Can't discharge the index." discharged_branches = tuple( - core.ClosedJaxpr(discharge_state(branch.jaxpr, ())[0], ()) - for branch in branches) - out_vals = cond_p.bind(*args, branches=discharged_branches) - out_vals, out_ref_vals = util.split_list( - out_vals, [len(out_avals)]) + discharge_state(branch.jaxpr, (), should_discharge=should_discharge[1:])[0] + for branch in branches + ) + # Don't thread the ref values through the cond if they never change. + forwarded_outvars = None + for branch in discharged_branches: + invar_pos = {v: i for i, v in enumerate(branch.invars)} + branch_forwarding = [ + invar_pos.get(v, None) if isinstance(v, core.Var) else None + for v in branch.outvars[len(out_avals) :] + ] + if forwarded_outvars is None: + forwarded_outvars = branch_forwarding + else: + forwarded_outvars = [ + i if i == j else None + for i, j in zip(forwarded_outvars, branch_forwarding) + ] + assert forwarded_outvars is not None + all_outvars_fwd = [None] * len(out_avals) + forwarded_outvars + new_branches = tuple( + core.ClosedJaxpr( + branch.replace(outvars=[v for v, fwd in zip(branch.outvars, all_outvars_fwd) + if fwd is None]), ()) + for branch in discharged_branches + ) + out_vals_no_fwd = cond_p.bind(index, *args, branches=new_branches) + out_vals, out_ref_vals_no_fwd = util.split_list(out_vals_no_fwd, [len(out_avals)]) + # Insert forwarded values into reference outputs + ref_val_no_fwd_iter = iter(out_ref_vals_no_fwd) + out_ref_vals = [next(ref_val_no_fwd_iter) if fwd is None else args[fwd] + for fwd in forwarded_outvars] + # Map reference outputs back to their invars ref_val_iter = iter(out_ref_vals) new_invals = [] - for aval in in_avals: - new_invals.append( - next(ref_val_iter) if isinstance(aval, AbstractRef) else None) + for should, aval in zip(should_discharge, in_avals): + discharged_inval = isinstance(aval, AbstractRef) and should + new_invals.append(next(ref_val_iter) if discharged_inval else None) return new_invals, out_vals diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index 0e41fe5bb18f..290d027cc6bc 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -841,8 +841,7 @@ def conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers lhs_shape: tuple of nonnegative integers, shape of the convolution input. rhs_shape: tuple of nonnegative integers, shape of the convolution kernel. dimension_numbers: None or a tuple/list of strings or a ConvDimensionNumbers - object following the convolution dimension number specification format in - xla_client.py. + object. Returns: A `ConvDimensionNumbers` object that represents `dimension_numbers` in the diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py index 36553e512cd7..6ca1a4abd193 100644 --- a/jax/_src/lax/fft.py +++ b/jax/_src/lax/fft.py @@ -15,6 +15,7 @@ from __future__ import annotations from collections.abc import Sequence +import enum from functools import partial import math @@ -30,35 +31,50 @@ from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.lib.mlir.dialects import hlo -from jax._src.lib import xla_client __all__ = [ "fft", "fft_p", ] -def _str_to_fft_type(s: str) -> xla_client.FftType: +class FftType(enum.IntEnum): + "Describes which FFT operation to perform." + + FFT = 0 + "Forward complex-to-complex FFT." + + IFFT = 1 + "Inverse complex-to-complex FFT." + + RFFT = 2 + "Forward real-to-complex FFT." + + IRFFT = 3 + "Inverse real-to-complex FFT." + + +def _str_to_fft_type(s: str) -> FftType: if s in ("fft", "FFT"): - return xla_client.FftType.FFT + return FftType.FFT elif s in ("ifft", "IFFT"): - return xla_client.FftType.IFFT + return FftType.IFFT elif s in ("rfft", "RFFT"): - return xla_client.FftType.RFFT + return FftType.RFFT elif s in ("irfft", "IRFFT"): - return xla_client.FftType.IRFFT + return FftType.IRFFT else: raise ValueError(f"Unknown FFT type '{s}'") @partial(jit, static_argnums=(1, 2)) -def fft(x, fft_type: xla_client.FftType | str, fft_lengths: Sequence[int]): +def fft(x, fft_type: FftType | str, fft_lengths: Sequence[int]): if isinstance(fft_type, str): typ = _str_to_fft_type(fft_type) - elif isinstance(fft_type, xla_client.FftType): + elif isinstance(fft_type, FftType): typ = fft_type else: raise TypeError(f"Unknown FFT type value '{fft_type}'") - if typ == xla_client.FftType.RFFT: + if typ == FftType.RFFT: if np.iscomplexobj(x): raise ValueError("only real valued inputs supported for rfft") x = lax.convert_element_type(x, dtypes.to_inexact_dtype(dtypes.dtype(x))) @@ -80,7 +96,7 @@ def fft_abstract_eval(x, fft_type, fft_lengths): if len(fft_lengths) > x.ndim: raise ValueError(f"FFT input shape {x.shape} must have at least as many " f"input dimensions as fft_lengths {fft_lengths}.") - if fft_type == xla_client.FftType.RFFT: + if fft_type == FftType.RFFT: if x.dtype not in (np.float32, np.float64): raise ValueError(f"RFFT input must be float32 or float64, got {x.dtype}") if x.shape[-len(fft_lengths):] != fft_lengths: @@ -89,7 +105,7 @@ def fft_abstract_eval(x, fft_type, fft_lengths): shape = (x.shape[:-len(fft_lengths)] + fft_lengths[:-1] + (fft_lengths[-1] // 2 + 1,)) dtype = _complex_dtype(x.dtype) - elif fft_type == xla_client.FftType.IRFFT: + elif fft_type == FftType.IRFFT: if not np.issubdtype(x.dtype, np.complexfloating): raise ValueError("IRFFT input must be complex64 or complex128, got " f"{x.dtype}") @@ -121,7 +137,7 @@ def _fft_lowering(ctx, x, *, fft_type, fft_lengths): def _naive_rfft(x, fft_lengths): - y = fft(x, xla_client.FftType.FFT, fft_lengths) + y = fft(x, FftType.FFT, fft_lengths) n = fft_lengths[-1] return y[..., : n//2 + 1] @@ -144,7 +160,7 @@ def _irfft_transpose(t, fft_lengths): # factor and a mask. The mask scales the cotangent for the Hermitian # symmetric components of the RFFT by a factor of two, since these components # are de-duplicated in the RFFT. - x = fft(t, xla_client.FftType.RFFT, fft_lengths) + x = fft(t, FftType.RFFT, fft_lengths) n = x.shape[-1] is_odd = fft_lengths[-1] % 2 full = partial(lax.full_like, t, dtype=x.dtype) @@ -161,9 +177,9 @@ def _irfft_transpose(t, fft_lengths): return lax.conj(out) def _fft_transpose_rule(t, operand, fft_type, fft_lengths): - if fft_type == xla_client.FftType.RFFT: + if fft_type == FftType.RFFT: result = _rfft_transpose(t, fft_lengths) - elif fft_type == xla_client.FftType.IRFFT: + elif fft_type == FftType.IRFFT: result = _irfft_transpose(t, fft_lengths) else: result = fft(t, fft_type, fft_lengths) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index f51f0436b7a9..a740285bf250 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -60,22 +60,16 @@ from jax._src.lax.utils import ( _input_dtype, dtype_to_string, standard_abstract_eval, standard_multi_result_abstract_eval, standard_primitive) -from jax._src import xla_bridge -from jax._src.lib import xla_client from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo -from jax._src.sharding_impls import PmapSharding, NamedSharding, PartitionSpec +from jax._src.sharding_impls import (PmapSharding, NamedSharding, + PartitionSpec as P) from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape from jax._src.util import (cache, safe_zip, safe_map, canonicalize_axis, split_list, NumpyComplexWarning) -xb = xla_bridge -xc = xla_client -xops = xla_client.ops -xe = xla_client._xla - _max = builtins.max _min = builtins.min _reduce = functools.reduce @@ -275,8 +269,17 @@ def ceil(x: ArrayLike) -> Array: return ceil_p.bind(x) class RoundingMethod(enum.IntEnum): + """Rounding strategies for handling halfway values (e.g., 0.5) in + :func:`jax.lax.round`. + """ + AWAY_FROM_ZERO = 0 + """Rounds halfway values away from zero (e.g., 0.5 -> 1, -0.5 -> -1).""" + TO_NEAREST_EVEN = 1 + """Rounds halfway values to the nearest even integer. This is also known + as “banker’s rounding” (e.g., 0.5 -> 0, 1.5 -> 2). + """ def round(x: ArrayLike, rounding_method: RoundingMethod = RoundingMethod.AWAY_FROM_ZERO @@ -288,8 +291,7 @@ def round(x: ArrayLike, Args: x: an array or scalar value to round. rounding_method: the method to use when rounding halfway values - (e.g., `0.5`). See ``lax.RoundingMethod`` for the list of possible - values. + (e.g., `0.5`). See :class:`jax.lax.RoundingMethod` for possible values. Returns: An array containing the elementwise rounding of x. @@ -701,27 +703,19 @@ def __str__(self) -> str: _precision_strings[None] = Precision.DEFAULT -PrecisionLike = Union[ - str, - Precision, - tuple[str, str], - tuple[Precision, Precision], - None, -] - - class DotAlgorithm(NamedTuple): """Specify the algorithm used for computing dot products. - When used as input to :func:`~jax.lax.dot_general`, this data structure is - used for controlling the properties of the algorithm used for computing the - dot product. This API controls the precision used for the computation, and - allows users to access hardware-specific accelerations. + When used to specify the ``precision`` input to :func:`~jax.lax.dot`, + :func:`~jax.lax.dot_general`, and other dot product functions, this data + structure is used for controlling the properties of the algorithm used for + computing the dot product. This API controls the precision used for the + computation, and allows users to access hardware-specific accelerations. Support for these algorithms is platform dependent, and using an unsupported algorithm will raise a Python exception when the computation is compiled. The algorithms that are known to be supported on at least some platforms are - listed in the :class:`~jax.lax.DotAlgorithm.Preset` enum, and these are a + listed in the :class:`~jax.lax.DotAlgorithmPreset` enum, and these are a good starting point for experimenting with this API. A "dot algorithm" is specified by the following parameters: @@ -756,13 +750,24 @@ class DotAlgorithm(NamedTuple): ... ) >>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) >>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) - >>> dot(lhs, rhs, algorithm=algorithm) # doctest: +SKIP - array([ 1., 4., 9., 16.], dtype=float32) + >>> dot(lhs, rhs, precision=algorithm) # doctest: +SKIP + array([ 1., 4., 9., 16.], dtype=float16) Or, equivalently, using a preset: - >>> algorithm = DotAlgorithm.Preset.F16_F16_F32 - >>> dot(lhs, rhs, algorithm=algorithm) # doctest: +SKIP + >>> algorithm = DotAlgorithmPreset.F16_F16_F32 + >>> dot(lhs, rhs, precision=algorithm) # doctest: +SKIP + array([ 1., 4., 9., 16.], dtype=float16) + + Presets can also be specified by name: + + >>> dot(lhs, rhs, precision="F16_F16_F32") # doctest: +SKIP + array([ 1., 4., 9., 16.], dtype=float16) + + The ``preferred_element_type`` parameter can be used to return the output + without downcasting the accumulation type: + + >>> dot(lhs, rhs, precision="F16_F16_F32", preferred_element_type=np.float32) # doctest: +SKIP array([ 1., 4., 9., 16.], dtype=float32) """ @@ -787,50 +792,149 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, self.allow_imprecise_accumulation, ) - # mypy doesn't currently support nested classes in a NamedTuple definition. - class Preset(enum.Enum): # type: ignore[misc] - DEFAULT = 0 - ANY_F8_ANY_F8_F32 = 1 - ANY_F8_ANY_F8_F32_FAST_ACCUM = 2 - F16_F16_F16 = 3 - F16_F16_F32 = 4 - BF16_BF16_BF16 = 5 - BF16_BF16_F32 = 6 - BF16_BF16_F32_X3 = 7 - BF16_BF16_F32_X6 = 8 - TF32_TF32_F32 = 9 - TF32_TF32_F32_X3 = 10 - F32_F32_F32 = 11 - F64_F64_F64 = 12 - - def __repr__(self) -> str: - return f'{self.__class__.__name__}.{self.name}' - - def __str__(self) -> str: - return self.name - - @property - def accumulation_type(self) -> DTypeLike: - match self: - case DotAlgorithm.Preset.DEFAULT: - raise TypeError( - "The default dot algorithm does not have an accumulation type.") - case DotAlgorithm.Preset.F16_F16_F16: - return np.float16 - case DotAlgorithm.Preset.BF16_BF16_BF16: - return dtypes.bfloat16 - case DotAlgorithm.Preset.F64_F64_F64: - return np.float64 - case _: - return np.float32 - - def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, - rhs_dtype: DTypeLike) -> hlo.DotAlgorithm | None: - if self == DotAlgorithm.Preset.DEFAULT: + +class DotAlgorithmPreset(enum.Enum): + """An enum of known algorithms for computing dot products. + + This ``Enum`` provides a named set of :class:`~jax.lax.DotAlgorithm` objects + that are known to be supported on at least platform. See the + :class:`~jax.lax.DotAlgorithm` documentation for more details about the + behavior of these algorithms. + + An algorithm can be selected from this list when calling :func:`~jax.lax.dot`, + :func:`~jax.lax.dot_general`, or most other JAX dot product functions, by + passing either a member of this ``Enum`` or it's name as a string using the + ``precision`` argument. + + For example, users can specify the preset using this ``Enum`` directly: + + >>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) + >>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) + >>> algorithm = DotAlgorithmPreset.F16_F16_F32 + >>> dot(lhs, rhs, precision=algorithm) # doctest: +SKIP + array([ 1., 4., 9., 16.], dtype=float16) + + or, equivalently, they can be specified by name: + + >>> dot(lhs, rhs, precision="F16_F16_F32") # doctest: +SKIP + array([ 1., 4., 9., 16.], dtype=float16) + + The names of the presets are typically ``LHS_RHS_ACCUM`` where ``LHS`` and + ``RHS`` are the element types of the ``lhs`` and ``rhs`` inputs + respectively, and ``ACCUM`` is the element type of the accumulator. Some + presets have an extra suffix, and the meaning of each of these is + documented below. The supported presets are: + """ + DEFAULT = enum.auto() + """An algorithm will be selected based on input and output types.""" + + ANY_F8_ANY_F8_F32 = enum.auto() + """Accepts any float8 input types and accumulates into float32.""" + + ANY_F8_ANY_F8_F32_FAST_ACCUM = enum.auto() + """Like ``ANY_F8_ANY_F8_F32``, but using faster accumulation with the cost + of lower accuracy. + """ + + ANY_F8_ANY_F8_ANY = enum.auto() + """Like ``ANY_F8_ANY_F8_F32``, but the accumulation type is controlled by + ``preferred_element_type``. + """ + + ANY_F8_ANY_F8_ANY_FAST_ACCUM = enum.auto() + """Like ``ANY_F8_ANY_F8_F32_FAST_ACCUM``, but the accumulation type is + controlled by ``preferred_element_type``. + """ + + F16_F16_F16 = enum.auto() + F16_F16_F32 = enum.auto() + BF16_BF16_BF16 = enum.auto() + BF16_BF16_F32 = enum.auto() + BF16_BF16_F32_X3 = enum.auto() + """The ``_X3`` suffix indicates that the algorithm uses 3 operations to + emulate higher precision. + """ + + BF16_BF16_F32_X6 = enum.auto() + """Like ``BF16_BF16_F32_X3``, but using 6 operations instead of 3.""" + + TF32_TF32_F32 = enum.auto() + TF32_TF32_F32_X3 = enum.auto() + """The ``_X3`` suffix indicates that the algorithm uses 3 operations to + emulate higher precision. + """ + + F32_F32_F32 = enum.auto() + F64_F64_F64 = enum.auto() + + def __repr__(self) -> str: + return f'{self.__class__.__name__}.{self.name}' + + def __str__(self) -> str: + return self.name + + @property + def lhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None: + match self: + case ( + DotAlgorithmPreset.DEFAULT | + DotAlgorithmPreset.ANY_F8_ANY_F8_F32 | + DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM | + DotAlgorithmPreset.ANY_F8_ANY_F8_ANY | + DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM + ): return None + case DotAlgorithmPreset.F16_F16_F16 | DotAlgorithmPreset.F16_F16_F32: + return np.float16 + case ( + DotAlgorithmPreset.BF16_BF16_BF16 | + DotAlgorithmPreset.BF16_BF16_F32 + ): + # These algorithms support either f32 or bf32 input storage types. + # If either of those types are provided as input, we use the provided + # type. If not, we explicitly cast to bfloat16. + return (dtypes.bfloat16, np.float32) + case DotAlgorithmPreset.F64_F64_F64: + return np.float64 + case _: + return np.float32 + + @property + def rhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None: + return self.lhs_precision_type + + @property + def accumulation_type(self) -> DTypeLike | None: + match self: + case ( + DotAlgorithmPreset.DEFAULT | + DotAlgorithmPreset.ANY_F8_ANY_F8_ANY | + DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM + ): + return None + case DotAlgorithmPreset.F16_F16_F16: + return np.float16 + case DotAlgorithmPreset.BF16_BF16_BF16: + return dtypes.bfloat16 + case DotAlgorithmPreset.F64_F64_F64: + return np.float64 + case _: + return np.float32 - if self in (DotAlgorithm.Preset.ANY_F8_ANY_F8_F32, - DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM): + def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, + rhs_dtype: DTypeLike) -> hlo.DotAlgorithm | None: + f16 = ir.F16Type.get() + f32 = ir.F32Type.get() + f64 = ir.F64Type.get() + bf16 = ir.BF16Type.get() + tf32 = ir.FloatTF32Type.get() + match self: + case ( + DotAlgorithmPreset.ANY_F8_ANY_F8_F32 | + DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM | + DotAlgorithmPreset.ANY_F8_ANY_F8_ANY | + DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM + ): fp8_dtypes = (np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn), np.dtype(dtypes.float8_e4m3fnuz), @@ -845,65 +949,53 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, acc = ir.F32Type.get() return hlo.DotAlgorithm.get( lhs, rhs, acc, 1, 1, 1, - self == DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM) + self == DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM) + case DotAlgorithmPreset.F16_F16_F16: + return hlo.DotAlgorithm.get(f16, f16, f16, 1, 1, 1, False) + case DotAlgorithmPreset.F16_F16_F32: + return hlo.DotAlgorithm.get(f16, f16, f32, 1, 1, 1, False) + case DotAlgorithmPreset.BF16_BF16_BF16: + return hlo.DotAlgorithm.get(bf16, bf16, bf16, 1, 1, 1, False) + case DotAlgorithmPreset.BF16_BF16_F32: + return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 1, False) + case DotAlgorithmPreset.BF16_BF16_F32_X3: + return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 3, False) + case DotAlgorithmPreset.BF16_BF16_F32_X6: + return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 6, False) + case DotAlgorithmPreset.TF32_TF32_F32: + return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 1, False) + case DotAlgorithmPreset.TF32_TF32_F32_X3: + return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 3, False) + case DotAlgorithmPreset.F32_F32_F32: + return hlo.DotAlgorithm.get(f32, f32, f32, 1, 1, 1, False) + case DotAlgorithmPreset.F64_F64_F64: + return hlo.DotAlgorithm.get(f64, f64, f64, 1, 1, 1, False) + case _: + return None - else: - f16 = ir.F16Type.get() - f32 = ir.F32Type.get() - f64 = ir.F64Type.get() - bf16 = ir.BF16Type.get() - tf32 = ir.FloatTF32Type.get() - match self: - case DotAlgorithm.Preset.F16_F16_F16: - return hlo.DotAlgorithm.get(f16, f16, f16, 1, 1, 1, False) - case DotAlgorithm.Preset.F16_F16_F32: - return hlo.DotAlgorithm.get(f16, f16, f32, 1, 1, 1, False) - case DotAlgorithm.Preset.BF16_BF16_BF16: - return hlo.DotAlgorithm.get(bf16, bf16, bf16, 1, 1, 1, False) - case DotAlgorithm.Preset.BF16_BF16_F32: - return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 1, False) - case DotAlgorithm.Preset.BF16_BF16_F32_X3: - return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 3, False) - case DotAlgorithm.Preset.BF16_BF16_F32_X6: - return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 6, False) - case DotAlgorithm.Preset.TF32_TF32_F32: - return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 1, False) - case DotAlgorithm.Preset.TF32_TF32_F32_X3: - return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 3, False) - case DotAlgorithm.Preset.F32_F32_F32: - return hlo.DotAlgorithm.get(f32, f32, f32, 1, 1, 1, False) - case DotAlgorithm.Preset.F64_F64_F64: - return hlo.DotAlgorithm.get(f64, f64, f64, 1, 1, 1, False) - case _: - raise NotImplementedError("unreachable") - - -DotAlgorithmLike = Union[ - DotAlgorithm, - DotAlgorithm.Preset, - str, + +PrecisionLike = Union[ None, -] -_DotAlgorithmLike = Union[ + str, + Precision, + tuple[str, str], + tuple[Precision, Precision], DotAlgorithm, - DotAlgorithm.Preset, - None, + DotAlgorithmPreset, ] -DotTransposeAlgorithmLike = Union[ - DotAlgorithmLike, - tuple[DotAlgorithmLike, DotAlgorithmLike], +CanonicalPrecision = Union[ + None, + tuple[Precision, Precision], + DotAlgorithm, + DotAlgorithmPreset, ] -DotTransposeAlgorithm = tuple[_DotAlgorithmLike, _DotAlgorithmLike] def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None, - preferred_element_type: DTypeLike | None = None, - algorithm: DotAlgorithmLike = None, - transpose_algorithm: DotTransposeAlgorithmLike = None) -> Array: + preferred_element_type: DTypeLike | None = None) -> Array: """Vector/vector, matrix/vector, and matrix/matrix multiplication. - Wraps XLA's `Dot - `_ + Wraps XLA's `Dot `_ operator. For more general contraction, see the :func:`jax.lax.dot_general` operator. @@ -911,24 +1003,25 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None, Args: lhs: an array of dimension 1 or 2. rhs: an array of dimension 1 or 2. - precision: Optional. Either ``None``, which means the default precision for - the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, - ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two - :class:`~jax.lax.Precision` enums indicating precision of ``lhs``` and ``rhs``. - preferred_element_type: Optional. Either ``None``, which means the default - accumulation type for the input types, or a datatype, indicating to - accumulate results to and return a result with that datatype. - algorithm: Optional. Specify the algorithm used for accumulating the dot - product. See :class:`~jax.lax.DotAlgorithm` for more details. This argument - cannot be used with ``precision`` or ``preferred_element_type``. - transpose_algorithm: Optional. This allows specifying the algorithm used when - this operation is transposed, typically as part of reverse-mode automatic - differentiation. This argument can either be a single - :class:`~jax.lax.DotAlgorithm` or a tuple of two - :class:`~jax.lax.DotAlgorithm`s, in which case the two elements define the - algorithm for transposing the LHS and RHS, respectively. - ``transpose_algorithm`` must be explicitly specified when transposing a - dot product where a specific ``algorithm`` was used on the forward pass. + precision: Optional. This parameter controls the numerics of the + computation, and it can be one of the following: + + - ``None``, which means the default precision for the current backend, + - a :class:`~jax.lax.Precision` enum value or a tuple of two + :class:`~jax.lax.Precision` enums indicating precision of ``lhs``` and + ``rhs``, or + - a :class:`~jax.lax.DotAlgorithm` or a + :class:`~jax.lax.DotAlgorithmPreset` indicating the algorithm that + must be used to accumulate the dot product. + + preferred_element_type: Optional. This parameter controls the data type + output by the dot product. By default, the output element type of this + operation will match the ``lhs`` and ``rhs`` input element types under + the usual type promotion rules. Setting ``preferred_element_type`` to a + specific ``dtype`` will mean that the operation returns that element type. + When ``precision`` is not a :class:`~jax.lax.DotAlgorithm` or + :class:`~jax.lax.DotAlgorithmPreset`, ``preferred_element_type`` provides + a hint to the compiler to accumulate the dot product using this data type. Returns: An array containing the product. @@ -936,9 +1029,7 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None, if 1 <= lhs.ndim <= 2 and 1 <= rhs.ndim <= 2 and core.definitely_equal(lhs.shape[-1], rhs.shape[0]): return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())), precision=precision, - preferred_element_type=preferred_element_type, - algorithm=algorithm, - transpose_algorithm=transpose_algorithm) + preferred_element_type=preferred_element_type) else: raise TypeError("Incompatible shapes for dot: got {} and {}.".format( lhs.shape, rhs.shape)) @@ -949,9 +1040,7 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None, def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers, precision: PrecisionLike = None, - preferred_element_type: DTypeLike | None = None, - algorithm: DotAlgorithmLike = None, - transpose_algorithm: DotTransposeAlgorithmLike = None) -> Array: + preferred_element_type: DTypeLike | None = None) -> Array: """General dot product/contraction operator. Wraps XLA's `DotGeneral @@ -970,29 +1059,31 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN lhs: an array rhs: an array dimension_numbers: a tuple of tuples of sequences of ints of the form - ``((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`` - precision: Optional. Either ``None``, which means the default precision for - the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, - ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two - :class:`~jax.lax.Precision` enums indicating precision of ``lhs``` and ``rhs``. - preferred_element_type: Optional. Either ``None``, which means the default - accumulation type for the input types, or a datatype, indicating to - accumulate results to and return a result with that datatype. - algorithm: Optional. Specify the algorithm used for accumulating the dot - product. See :class:`~jax.lax.DotAlgorithm` for more details. This argument - cannot be used with ``precision`` or ``preferred_element_type``. - transpose_algorithm: Optional. This allows specifying the algorithm used when - this operation is transposed, typically as part of reverse-mode automatic - differentiation. This argument can either be a single - :class:`~jax.lax.DotAlgorithm` or a tuple of two - :class:`~jax.lax.DotAlgorithm`s, in which case the two elements define the - algorithm for transposing the LHS and RHS, respectively. - ``transpose_algorithm`` must be explicitly specified when transposing a - dot product where a specific ``algorithm`` was used on the forward pass. + ``((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, + rhs_batch_dims))`` + precision: Optional. This parameter controls the numerics of the + computation, and it can be one of the following: + + - ``None``, which means the default precision for the current backend, + - a :class:`~jax.lax.Precision` enum value or a tuple of two + :class:`~jax.lax.Precision` enums indicating precision of ``lhs``` and + ``rhs``, or + - a :class:`~jax.lax.DotAlgorithm` or a + :class:`~jax.lax.DotAlgorithmPreset` indicating the algorithm that + must be used to accumulate the dot product. + + preferred_element_type: Optional. This parameter controls the data type + output by the dot product. By default, the output element type of this + operation will match the ``lhs`` and ``rhs`` input element types under + the usual type promotion rules. Setting ``preferred_element_type`` to a + specific ``dtype`` will mean that the operation returns that element type. + When ``precision`` is not a :class:`~jax.lax.DotAlgorithm` or + :class:`~jax.lax.DotAlgorithmPreset`, ``preferred_element_type`` provides + a hint to the compiler to accumulate the dot product using this data type. Returns: - An array whose first dimensions are the (shared) batch dimensions, followed by - the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs`` + An array whose first dimensions are the (shared) batch dimensions, followed + by the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs`` non-contracting/non-batch dimensions. """ (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers @@ -1006,9 +1097,7 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN return dot_general_p.bind(lhs, rhs, dimension_numbers=(cdims, bdims), precision=canonicalize_precision(precision), - preferred_element_type=preferred_element_type, - algorithm=canonicalize_dot_algorithm(algorithm), - transpose_algorithm=canonicalize_dot_transpose_algorithm(transpose_algorithm)) + preferred_element_type=preferred_element_type) def ragged_dot( @@ -1989,23 +2078,29 @@ def broadcasting_sharding_rule(name, *avals): f' another mesh: {a.sharding.mesh}') assert mesh is not None - result_specs = [] - for ss, ds in zip(zip(*specs), zip(*shapes)): + result_specs = [None] * len(shapes[0]) + for i, (ss, ds) in enumerate(zip(zip(*specs), zip(*shapes))): if all(s == ss[0] for s in ss[1:]): # if all dimension shardings are same, the resulting dimension sharding is # the same. - result_specs.append(ss[0]) + result_specs[i] = ss[0] else: non_trivial_s = [s for s, d in zip(ss, ds) if not (core.definitely_equal(d, 1) and s is None)] if not non_trivial_s: - result_specs.append(None) + result_specs[i] = None elif all(non_trivial_s[0] == s for s in non_trivial_s[1:]): - result_specs.append(non_trivial_s[0]) + result_specs[i] = non_trivial_s[0] else: - raise TypeError(f'{name} got incompatible shardings for broadcasting: ' - f'{", ".join(map(str, map(tuple, specs)))}.') - return NamedSharding(mesh, PartitionSpec(*result_specs)) + for s in ss: + if result_specs[i] is None and s is not None: + result_specs[i] = s + elif (result_specs[i] is not None and s is not None and + result_specs[i] != s): + raise TypeError( + f'{name} got incompatible shardings for broadcasting: ' + f'{", ".join(map(str, map(tuple, specs)))}.') + return NamedSharding(mesh, P(*result_specs)) def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False, @@ -2091,12 +2186,8 @@ def multi_sharding_in_dim(ctx, ops, in_avals, out_aval): def _nary_lower_hlo(op: Callable, ctx, - *args: ir.Value, - explicit_type=False, **params) -> Sequence[ir.Value]: + *args: ir.Value, **params) -> Sequence[ir.Value]: """Lowers an elementwise operator to its MLIR equivalent. - - Args: - explicit_type: does the MLIR op require its output type to be provided? """ del params avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out @@ -2104,10 +2195,7 @@ def _nary_lower_hlo(op: Callable, ctx, if config.sharding_in_types.value: args = multi_sharding_in_dim(ctx, args, avals_in, aval_out) - if explicit_type: - out = op(mlir.aval_to_ir_type(aval_out), *args) - else: - out = op(*args) + out = op(*args) if config.sharding_in_types.value: if config.use_shardy_partitioner.value: out_sp = aval_out.sharding._to_sdy_sharding(aval_out.ndim) @@ -3055,9 +3143,7 @@ def _validate_preferred_element_type(input_dtype, preferred_element_type): def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision, - preferred_element_type: DTypeLike | None, - algorithm: _DotAlgorithmLike = None, - transpose_algorithm: DotTransposeAlgorithm | None = None): + preferred_element_type: DTypeLike | None): (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim)) for d in (lhs_contracting, lhs_batch)): @@ -3133,10 +3219,8 @@ def tuple_delete(tup, idx): def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision, - preferred_element_type: DTypeLike | None, - algorithm: _DotAlgorithmLike = None, - transpose_algorithm: DotTransposeAlgorithm | None = None): - del dimension_numbers, precision # unused + preferred_element_type: DTypeLike | None): + del dimension_numbers # unused # We're mostly matching XLA's logic here, namely in shape_inference.cc and # primitive_util.h's HigherPrecisionType, e.g. # https://github.com/openxla/xla/blob/ea3a841768d0dcf192e5820c9b25c34c73f2226a/xla/primitive_util.h#L329 @@ -3157,23 +3241,9 @@ def type_properties(dt): raise TypeError( f"lax.dot_general argument type error: {lhs.dtype}, {rhs.dtype}") result_dtype = lhs.dtype - - if transpose_algorithm is not None and algorithm is None: - raise ValueError( - "When the algorithm argument to dot_general is None, the " - "transpose_algorithm argument is unused and must also be None.") - - if algorithm is not None and algorithm != DotAlgorithm.Preset.DEFAULT: - if preferred_element_type is not None: - raise ValueError( - "The preferred_element_type and algorithm arguments to dot_general " - "cannot both be specified.") - - # This is used to ensure that the output type is equal to the accumulation - # type whenever an algorithm is specified. - preferred_element_type = algorithm.accumulation_type - - return _maybe_upcast(result_dtype, preferred_element_type) + has_algorithm = isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)) + return _maybe_upcast(result_dtype, preferred_element_type, + check_bit_width=not has_algorithm) def _bit_width(d): if dtypes.issubdtype(d, np.inexact): return dtypes.finfo(d).bits @@ -3181,12 +3251,12 @@ def _bit_width(d): elif d == np.dtype('bool'): return 1 else: assert False, d # should be unreachable, open an issue! -def _maybe_upcast(result_dtype, preferred_element_type): +def _maybe_upcast(result_dtype, preferred_element_type, check_bit_width): # replicates the logic in shape_inference.cc's MaybeUpcast if (preferred_element_type is None or result_dtype == preferred_element_type): return result_dtype - if (not dtypes.issubdtype(result_dtype, np.floating) and + if (check_bit_width and not dtypes.issubdtype(result_dtype, np.floating) and _bit_width(preferred_element_type) < _bit_width(result_dtype)): raise TypeError("`preferred_element_type` must not be narrower than the " "original type, got preferred_element_type of " @@ -3196,8 +3266,6 @@ def _maybe_upcast(result_dtype, preferred_element_type): def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, preferred_element_type: DTypeLike | None, - algorithm: _DotAlgorithmLike = None, - transpose_algorithm: DotTransposeAlgorithm | None = None, swap_ans=False): (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers x_ndim = x.aval.ndim @@ -3210,47 +3278,41 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, dims = ((ans_y, y_kept), (ans_batch, y_batch)) x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract))) out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y) - if algorithm is not None: - if transpose_algorithm is None or transpose_algorithm[0] is None: - raise ValueError( - "When a dot_general algorithm is specified on the forward pass, " - "transpose_algorithm must be specified for the backward pass.") - lhs_alg, rhs_alg = transpose_algorithm - transpose_algorithm = (algorithm, rhs_alg) - algorithm = lhs_alg x_bar = transpose(dot_general(g, y, dims, precision=precision, - preferred_element_type=preferred_element_type, - algorithm=algorithm, - transpose_algorithm=transpose_algorithm), + preferred_element_type=preferred_element_type), tuple(out_axes)) if x_bar.dtype != x.aval.dtype: x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type) return x_bar def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision, - preferred_element_type: DTypeLike | None, - algorithm: _DotAlgorithmLike = None, - transpose_algorithm: DotTransposeAlgorithm | None = None): + preferred_element_type: DTypeLike | None): (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch)) - transpose_algorithm = None if transpose_algorithm is None else ( - transpose_algorithm[1], transpose_algorithm[0]) y_bar = _dot_general_transpose_lhs( g, y, x, dimension_numbers=swapped_dimension_numbers, precision=precision, - preferred_element_type=preferred_element_type, algorithm=algorithm, - transpose_algorithm=transpose_algorithm, - swap_ans=True) + preferred_element_type=preferred_element_type, swap_ans=True) if y_bar.dtype != y.aval.dtype: y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type) return y_bar -def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers, - precision, - preferred_element_type: DTypeLike | None, - algorithm: _DotAlgorithmLike = None, - transpose_algorithm: DotTransposeAlgorithm | None = None): - lhs, rhs = batched_args - lbd, rbd = batch_dims + +def _dot_batch_rule( + unpack_args, + unpack_dims, + invoke_prim, + batched_args, + batch_dims, + *, + dimension_numbers, + precision, + preferred_element_type: DTypeLike | None, + **_, +): + + lhs, rhs = unpack_args(batched_args) + lbd, rbd = unpack_dims(batch_dims) + (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers left_stack_dim = lbd.stacked_axis if type(lbd) is RaggedAxis else lbd right_stack_dim = rbd.stacked_axis if type(rbd) is RaggedAxis else rbd @@ -3272,16 +3334,19 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers, rhs_shape = batching.bdim_as_shape(rbd, rhs.shape) else: rhs_shape = np.shape(rhs) - batched_out = dot_general(lhs, rhs, new_dimension_numbers, - precision=precision, - preferred_element_type=preferred_element_type, - algorithm=algorithm, - transpose_algorithm=transpose_algorithm) + batched_out = invoke_prim( + lhs, + rhs, + new_dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type, + ) result_batch_dim = batching.shape_as_bdim( result_stack_dim, _dot_general_shape_computation(lhs_shape, rhs_shape, new_dimension_numbers)) return batched_out, result_batch_dim + def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers): # There are three kinds of dimensions in a dot_general: # - contraction dimensions appear in lhs and rhs but not the result @@ -3356,14 +3421,41 @@ def _dot_general_pp_rule(eqn, context, settings) -> pp.Doc: dot_general_p = standard_primitive(_dot_general_shape_rule, _dot_general_dtype_rule, 'dot_general') + + +def _dot_general_batch_unpack_args(batch_args): + lhs, rhs = batch_args + return (lhs, rhs) + + +def _dot_general_batch_unpack_dims(batch_dims): + lbd, rbd = batch_dims + return (lbd, rbd) + +# DotDimensionNumbers used in the dot_general call for ragged_dot(). +_RAGGED_DOT_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = ( + ([2, 0], [1, 0]), + ([], []), +) +_RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = ( + ([3, 1], [2, 1]), + ([0], [0]), +) + ad.defbilinear(dot_general_p, _dot_general_transpose_lhs, _dot_general_transpose_rhs) +_dot_general_batch_rule = functools.partial( + _dot_batch_rule, + _dot_general_batch_unpack_args, + _dot_general_batch_unpack_dims, + dot_general, +) batching.primitive_batchers[dot_general_p] = _dot_general_batch_rule pe.padding_rules[dot_general_p] = _dot_general_padding_rule core.pp_eqn_rules[dot_general_p] = _dot_general_pp_rule def precision_attr(precision: Precision) -> ir.ArrayAttr: - if precision is None: + if precision is None or isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)): full_precision = (Precision.DEFAULT, Precision.DEFAULT) elif not isinstance(precision, tuple): full_precision = (precision, precision) @@ -3373,19 +3465,16 @@ def precision_attr(precision: Precision) -> ir.ArrayAttr: [hlo.PrecisionAttr.get(str(p)) for p in full_precision]) -def dot_algorithm_attr(algorithm: _DotAlgorithmLike, lhs_dtype: DTypeLike, +def dot_algorithm_attr(precision: CanonicalPrecision, lhs_dtype: DTypeLike, rhs_dtype: DTypeLike) -> hlo.DotAlgorithm | None: - if algorithm is None: + if not isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)): return None - return algorithm._convert_to_hlo_attr(lhs_dtype, rhs_dtype) + return precision._convert_to_hlo_attr(lhs_dtype, rhs_dtype) def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, precision, preferred_element_type: np.dtype | None, - algorithm: _DotAlgorithmLike = None, - transpose_algorithm: DotTransposeAlgorithm | None = None, platform: str = "default"): - del transpose_algorithm # unused def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2, dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz) @@ -3394,63 +3483,87 @@ def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): lhs_aval, rhs_aval = ctx.avals_in lhs_dtype, rhs_dtype = lhs_aval.dtype, rhs_aval.dtype aval_out, = ctx.avals_out + accumulation_aval = aval_out (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers - # TODO(b/...): JAX's dot_general primitive accepts the same input dtype - # combinations that are accepted in XLA's shape_inference.cc (the canonical - # reference for the HLO type system), but actually different XLA platforms - # fail on codegen for different accepted cases. To handle those cases, we - # insert ConvertOps on the input, in a platform-dependent way. - if lhs_dtype != rhs_dtype: - if platform == "tpu": - handled = lambda dt: (dtypes.issubdtype(dt, np.floating) or - dtypes.issubdtype(dt, np.integer)) - if not (handled(lhs_dtype) and handled(rhs_dtype)): - lhs = mlir.convert_hlo(ctx, lhs, lhs_aval, - core.ShapedArray(lhs_aval.shape, aval_out.dtype)) - rhs = mlir.convert_hlo(ctx, rhs, rhs_aval, - core.ShapedArray(rhs_aval.shape, aval_out.dtype)) - lhs_dtype = rhs_dtype = aval_out.dtype - else: # cpu and gpu - # Do not convert mixed fp8 types to output type. - if not _is_fp8_mixed_precision_matmul(lhs_dtype, rhs_dtype): - lhs = mlir.convert_hlo(ctx, lhs, lhs_aval, - core.ShapedArray(lhs_aval.shape, aval_out.dtype)) - rhs = mlir.convert_hlo(ctx, rhs, rhs_aval, - core.ShapedArray(rhs_aval.shape, aval_out.dtype)) - lhs_dtype = rhs_dtype = aval_out.dtype - - dot_dnums = hlo.DotDimensionNumbers.get( lhs_batching_dimensions=list(lhs_batch), rhs_batching_dimensions=list(rhs_batch), lhs_contracting_dimensions=list(lhs_contracting), rhs_contracting_dimensions=list(rhs_contracting)) - if algorithm is not None and precision not in { - None, Precision.DEFAULT, (Precision.DEFAULT, Precision.DEFAULT)}: - raise ValueError( - "The dot_general precision must be None or DEFAULT when an algorithm " - "is specified.") - if jaxlib_version <= (0, 4, 33): - if algorithm is not None: + algorithm_kwarg = {} + if isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)): + # The CPU backend silently ignores the algorithm spec, so we check here to + # make sure that the selected algorithm is supported. We could be a little + # bit more liberal here (any algorithm where the input and output types + # match and all the other parameters have default values should work), but + # it's probably sufficient to just check the presets here. + if platform == "cpu" and precision not in { + DotAlgorithmPreset.DEFAULT, DotAlgorithmPreset.F16_F16_F16, + DotAlgorithmPreset.F32_F32_F32, DotAlgorithmPreset.F64_F64_F64, + }: raise ValueError( - "The dot_general algorithm parameter is only supported for jaxlib " - "versions larger than 0.4.33.") - algorithm_kwargs = {} + f"The precision '{precision}' is not supported by dot_general on CPU") + + # If an explicit algorithm was specified, we always cast the input types to + # the correct types. + def maybe_convert_dtype(operand, operand_aval, target_dtype): + if target_dtype is None: + return operand, operand_aval.dtype + if not isinstance(target_dtype, tuple): + target_dtype = (target_dtype,) + if any(operand_aval.dtype == d for d in target_dtype): + return operand, operand_aval.dtype + aval = core.ShapedArray(operand_aval.shape, target_dtype[0]) + return mlir.convert_hlo(ctx, operand, operand_aval, aval), target_dtype[0] + + lhs, lhs_dtype = maybe_convert_dtype(lhs, lhs_aval, precision.lhs_precision_type) + rhs, rhs_dtype = maybe_convert_dtype(rhs, rhs_aval, precision.rhs_precision_type) + accumulation_type = precision.accumulation_type + if accumulation_type is not None: + accumulation_aval = core.ShapedArray(aval_out.shape, accumulation_type) + + if precision != DotAlgorithmPreset.DEFAULT: + algorithm_kwarg = { + "algorithm": dot_algorithm_attr(precision, lhs_dtype, rhs_dtype) + } else: - algorithm_kwargs = {"algorithm": dot_algorithm_attr(algorithm, lhs_dtype, - rhs_dtype)} - return [ - hlo.dot_general( - mlir.aval_to_ir_type(aval_out), - lhs, - rhs, - dot_dnums, - precision_config=precision_attr(precision), - **algorithm_kwargs, - ) - ] + # TODO(b/...): JAX's dot_general primitive accepts the same input dtype + # combinations that are accepted in XLA's shape_inference.cc (the canonical + # reference for the HLO type system), but actually different XLA platforms + # fail on codegen for different accepted cases. To handle those cases, we + # insert ConvertOps on the input, in a platform-dependent way. + if lhs_dtype != rhs_dtype: + if platform == "tpu": + handled = lambda dt: (dtypes.issubdtype(dt, np.floating) or + dtypes.issubdtype(dt, np.integer)) + if not (handled(lhs_dtype) and handled(rhs_dtype)): + lhs = mlir.convert_hlo(ctx, lhs, lhs_aval, + core.ShapedArray(lhs_aval.shape, aval_out.dtype)) + rhs = mlir.convert_hlo(ctx, rhs, rhs_aval, + core.ShapedArray(rhs_aval.shape, aval_out.dtype)) + lhs_dtype = rhs_dtype = aval_out.dtype + else: # cpu and gpu + # Do not convert mixed fp8 types to output type. + if not _is_fp8_mixed_precision_matmul(lhs_dtype, rhs_dtype): + lhs = mlir.convert_hlo(ctx, lhs, lhs_aval, + core.ShapedArray(lhs_aval.shape, aval_out.dtype)) + rhs = mlir.convert_hlo(ctx, rhs, rhs_aval, + core.ShapedArray(rhs_aval.shape, aval_out.dtype)) + lhs_dtype = rhs_dtype = aval_out.dtype + + result = hlo.dot_general( + mlir.aval_to_ir_type(accumulation_aval), + lhs, + rhs, + dot_dnums, + precision_config=precision_attr(precision), + **algorithm_kwarg, + ) + if accumulation_aval.dtype != aval_out.dtype: + result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out) + return [result] mlir.register_lowering(dot_general_p, _dot_general_lower) @@ -3461,6 +3574,34 @@ def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> Shape: + if len(lhs.shape) == 3: + # Batched case + b, m, k = lhs.shape + b2, group_count, rk, n = rhs.shape + b3 = group_sizes.shape[0] + if b != b2: + raise TypeError( + f'ragged_dot requires that lhs.shape[0] == rhs.shape[0]: got {b} and' + f' {b2}.' + ) + if b3 != b: + raise TypeError( + 'ragged_dot requires that group_sizes.shape[0] == lhs.shape[0]: got' + f' {b3} and {b}.' + ) + if k != rk: + raise TypeError( + f'ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {k} and' + f' {rk}.' + ) + num_groups = group_sizes.shape[1] + if group_count != num_groups: + raise TypeError( + 'ragged_dot requires that rhs.shape[1] == group_sizes.shape[1]: got' + f' {group_count} and {num_groups}.' + ) + return (b, m, n) + m, k = lhs.shape group_count, rk, n = rhs.shape if k != rk: @@ -3470,17 +3611,13 @@ def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> S raise TypeError(f"ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {group_count} and {num_groups}.") return (m, n) -# DotDimensionNumbers used in the dot_general call for ragged_dot(). -_RAGGED_DOT_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (([2, 0], [1, 0]), ([], [])) - def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array, precision, preferred_element_type: DTypeLike | None, **_) -> np.dtype: if not dtypes.issubdtype(group_sizes.dtype, np.integer): raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.") # defer the output dtype to dot_general, which is part of the _ragged_dot_impl. return _dot_general_dtype_rule(lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, - precision=precision, preferred_element_type=preferred_element_type, - algorithm=None, transpose_algorithm=None) + precision=precision, preferred_element_type=preferred_element_type) def _ragged_dot_jvp_rule( @@ -3584,11 +3721,63 @@ def _ragged_dot_transpose_rule( return grad_x, grad_y, None +def _ragged_dot_batch_unpack_args(batched_args): + lhs, rhs, _ = batched_args + return (lhs, rhs) + + +def _ragged_dot_batch_unpack_dims(batch_dims): + if not all(dim == 0 for dim in batch_dims): + raise NotImplementedError('ragged_dot vmap over any dim but 0 - NYI') + lbd, rbd, _ = batch_dims + return (lbd, rbd) + + +def _ragged_dot_invoke_prim( + group_sizes, + lhs, + rhs, + new_dimension_numbers, + precision, + preferred_element_type, +): + return ragged_dot( + lhs, + rhs, + group_sizes, + precision=precision, + preferred_element_type=preferred_element_type, + ) + + +def _ragged_dot_batch_rule( + batched_args, + batch_dims, + *, + precision, + preferred_element_type: DTypeLike | None, + **_, +): + invoke = functools.partial(_ragged_dot_invoke_prim, batched_args[2]) + + return _dot_batch_rule( + _ragged_dot_batch_unpack_args, + _ragged_dot_batch_unpack_dims, + invoke, + batched_args, + batch_dims, + dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, + precision=precision, + preferred_element_type=preferred_element_type, + ) + + ragged_dot_p = standard_primitive(_ragged_dot_shape_rule, _ragged_dot_dtype_rule, 'ragged_dot') ragged_dot_p.def_impl(partial(dispatch.apply_primitive, ragged_dot_p)) ad.primitive_jvps[ragged_dot_p] = _ragged_dot_jvp_rule ad.primitive_transposes[ragged_dot_p] = _ragged_dot_transpose_rule +batching.primitive_batchers[ragged_dot_p] = _ragged_dot_batch_rule def _ragged_dot_impl( lhs: Array, @@ -3600,11 +3789,20 @@ def _ragged_dot_impl( ) -> Array: if group_offset is not None: raise NotImplementedError("Unimplemented group_offset support.") - lhs = _ragged_to_dense(lhs, rhs, group_sizes=group_sizes) + + if len(lhs.shape) == 3: + ragged_dot_dims = _RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS + ragged_to_dense = api.vmap(_ragged_to_dense, in_axes=(0, 0, 0)) + else: + ragged_dot_dims = _RAGGED_DOT_DOT_DIMENSION_NUMBERS + ragged_to_dense = _ragged_to_dense + + lhs = ragged_to_dense(lhs, rhs, group_sizes) + return dot_general( lhs, rhs, - dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, + dimension_numbers=ragged_dot_dims, precision=precision, preferred_element_type=preferred_element_type, ) @@ -4941,18 +5139,15 @@ def _top_k_jvp(primals, tangents, *, k): idx_shape = k_idxs.shape rank = len(idx_shape) gather_index_shape = idx_shape + (1,) - gather_indices = [] - for i in range(rank-1): - _iota = iota(k_idxs.dtype, idx_shape[i]) - _iota = broadcast_in_dim(_iota, gather_index_shape, (i,)) - gather_indices.append(_iota) - gather_indices.append(reshape(k_idxs, gather_index_shape)) - gather_indices = concatenate(gather_indices, dimension=rank) + gather_indices = reshape(k_idxs, gather_index_shape) slice_sizes = (1,) * rank dnums = slicing.GatherDimensionNumbers( - offset_dims=(), - collapsed_slice_dims=tuple(range(rank)), - start_index_map=tuple(range(rank))) + offset_dims=(), + collapsed_slice_dims=(rank - 1,), + operand_batching_dims=tuple(range(rank - 1)), + start_indices_batching_dims=tuple(range(rank - 1)), + start_index_map=(rank - 1,), + ) tangent_out = slicing.gather(tangent, gather_indices, dnums, slice_sizes) return primals_out, (tangent_out, ad_util.Zero.from_primal_value(primals_out[1])) @@ -5197,7 +5392,20 @@ def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm): del shape, dtype, algorithm return (key.weak_type, False) -RandomAlgorithm = xops.RandomAlgorithm + +class RandomAlgorithm(enum.IntEnum): + """Describes which PRNG algorithm to use for rng_bit_generator.""" + + RNG_DEFAULT = 0 + "The platform's default algorithm." + + RNG_THREE_FRY = 1 + "The Threefry-2x32 PRNG algorithm." + + RNG_PHILOX = 2 + "The Philox-4x32 PRNG algorithm." + + RandomAlgorithm.__str__ = lambda algorithm: algorithm.name # type: ignore[method-assign] def _rng_algorithm(algorithm: RandomAlgorithm): @@ -5220,7 +5428,7 @@ def _rng_bit_generator_lowering( # need to convert u32[4] -> u64[2] here in the translation rule. However, we # also polymorphically allow a u64[2] for backward compatibility. # - # Separately, xops.RngBitGenerator doesn't support generating u8 or + # Separately, RngBitGenerator doesn't support generating u8 or # u16, so we request u32 and truncate in that case. u32_type = ir.IntegerType.get_unsigned(32) u64_type = ir.IntegerType.get_unsigned(64) @@ -5639,7 +5847,7 @@ def remaining(original, *removed_lists): return [i for i in original if i not in removed] -def canonicalize_precision(precision: PrecisionLike) -> tuple[Precision, Precision] | None: +def canonicalize_precision(precision: PrecisionLike) -> CanonicalPrecision: """Turns an API precision specification into a pair of enumeration values. The API can take the precision as a string, or int, and either as a single @@ -5649,56 +5857,44 @@ def canonicalize_precision(precision: PrecisionLike) -> tuple[Precision, Precisi if config.default_matmul_precision.value is None: return None try: - return ( - Precision(config.default_matmul_precision.value), - Precision(config.default_matmul_precision.value), - ) - except TypeError: + return canonicalize_precision(config.default_matmul_precision.value) + except ValueError: raise ValueError( - "jax_default_matmul_precision flag must be set to None or a value in " - f"{list(_precision_strings)}, but got {config.default_matmul_precision.value}" + "jax_default_matmul_precision flag must be set to None, a value in " + f"{list(_precision_strings)}, or the name of a lax.DotAlgorithmPreset, " + f"but got {config.default_matmul_precision.value}" ) from None - elif isinstance(precision, str) and precision in _precision_strings: - return Precision(precision), Precision(precision) + elif isinstance(precision, str): + if precision in _precision_strings: + return Precision(precision), Precision(precision) + else: + try: + return DotAlgorithmPreset[precision] + except KeyError: + pass elif isinstance(precision, Precision): return precision, precision + elif isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)): + return precision elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and all(isinstance(p, Precision) for p in precision)): return type_cast(tuple[Precision, Precision], precision) elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and all(isinstance(s, str) for s in precision)): - s1, s2 = precision + s1, s2 = type_cast(tuple[str, str], precision) p1 = type_cast(tuple[Precision, Precision], canonicalize_precision(s1))[0] p2 = type_cast(tuple[Precision, Precision], canonicalize_precision(s2))[0] return (p1, p2) - else: - raise ValueError( - f"Precision argument must be None, a string in {list(_precision_strings)}, " - "a lax.Precision value or a tuple of two lax.Precision values or " - f"strings; got {precision}.") - -def canonicalize_dot_algorithm(algorithm: DotAlgorithmLike) -> _DotAlgorithmLike: - if isinstance(algorithm, str): - algorithm = DotAlgorithm.Preset[algorithm] - if algorithm is None or algorithm == DotAlgorithm.Preset.DEFAULT: - return None - return algorithm + raise ValueError( + "Precision argument must be one of:\n" + "- None,\n" + f"- a string in {list(_precision_strings)},\n" + "- a lax.Precision value,\n" + "- a tuple of two lax.Precision values or strings,\n" + "- a lax.DotAlgorithmPreset or the name of one of these presets, or\n" + "- a lax.DotAlgorithm value;\n" + f"but got {precision}.") -def canonicalize_dot_transpose_algorithm( - algorithm: DotTransposeAlgorithmLike) -> DotTransposeAlgorithm | None: - if algorithm is None: - return None - elif isinstance(algorithm, DotAlgorithm): - return (algorithm, algorithm) - elif isinstance(algorithm, tuple): - if len(algorithm) != 2: - raise ValueError( - "The transpose_algorithm argument must be a single value or a tuple " - f"of two values; got {algorithm}.") - return (canonicalize_dot_algorithm(algorithm[0]), - canonicalize_dot_algorithm(algorithm[1])) - algorithm = canonicalize_dot_algorithm(algorithm) - return (algorithm, algorithm) def _balanced_eq(x, z, y): return div(select(_eq_meet(x, z), _ones(z), _zeros(z)), diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index ef6a5a11a56e..d352a47dbc74 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -47,7 +47,6 @@ from jax._src.lib import gpu_solver from jax._src.lib import gpu_sparse from jax._src.lib import lapack -from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo @@ -709,8 +708,7 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors, out_aval = ctx.avals_out[0] batch_dims = operand_aval.shape[:-2] op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - ctx_args = (ctx,) - w, vl, vr, info = lapack.geev_hlo(*ctx_args, operand_aval.dtype, operand, + w, vl, vr, info = lapack.geev_hlo(ctx, operand_aval.dtype, operand, input_shape_vals=op_shape_vals, jobvl=compute_left_eigenvectors, jobvr=compute_right_eigenvectors) @@ -875,7 +873,7 @@ def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues, subset_by_index): if isinstance(operand, ShapedArray): if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]: raise ValueError( - "Argument to symmetric eigendecomposition must have shape [..., n, n]," + "Argument to symmetric eigendecomposition must have shape [..., n, n], " "got shape {}".format(operand.shape)) batch_dims = operand.shape[:-2] @@ -896,33 +894,39 @@ def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues, subset_by_index): def _eigh_cpu_gpu_lowering( - syevd_impl, ctx, operand, *, lower, sort_eigenvalues, subset_by_index, - platform=None + ctx, operand, *, lower, sort_eigenvalues, subset_by_index, + target_name_prefix: str ): del sort_eigenvalues # The CPU/GPU implementations always sort. operand_aval, = ctx.avals_in v_aval, w_aval = ctx.avals_out n = operand_aval.shape[-1] - batch_dims = operand_aval.shape[:-2] - - # The eigh implementation on CPU and GPU uses lapack helper routines to - # find the size of the workspace based on the non-batch dimensions. - # Therefore, we cannot yet support dynamic non-batch dimensions. - if not is_constant_shape(operand_aval.shape[-2:]): - raise NotImplementedError( - "Shape polymorphism for native lowering for eigh is implemented " - f"only for the batch dimensions: {operand_aval.shape}") - if not (subset_by_index is None or subset_by_index == (0, n)): - raise NotImplementedError("subset_by_index not implemented for CPU and GPU") + raise NotImplementedError("subset_by_index not supported on CPU and GPU") + batch_dims = operand_aval.shape[:-2] + nb = len(batch_dims) + layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1)) + result_layouts = [layout, tuple(range(nb, -1, -1)), + tuple(range(nb - 1, -1, -1))] + if target_name_prefix == "cpu": + dtype = operand_aval.dtype + prefix = "he" if dtypes.issubdtype(dtype, np.complexfloating) else "sy" + target_name = lapack.prepare_lapack_call(f"{prefix}evd_ffi", + operand_aval.dtype) + kwargs = { + "mode": np.uint8(ord("V")), + "uplo": np.uint8(ord("L" if lower else "U")), + } + else: + target_name = f"{target_name_prefix}solver_syevd_ffi" + kwargs = {"lower": lower, "algorithm": np.uint8(0)} - op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - cpu_args = [] - if platform == "cpu": - ctx_args = (ctx,) - cpu_args.extend(ctx_args) - v, w, info = syevd_impl(*cpu_args, operand_aval.dtype, operand, - a_shape_vals=op_shape_vals, lower=lower) + rule = ffi.ffi_lowering(target_name, operand_layouts=[layout], + result_layouts=result_layouts, + operand_output_aliases={0: 0}) + info_aval = ShapedArray(batch_dims, np.dtype(np.int32)) + sub_ctx = ctx.replace(avals_out=[v_aval, w_aval, info_aval]) + v, w, info = rule(sub_ctx, operand, **kwargs) zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))) ok = mlir.compare_hlo(info, zeros, "EQ", "SIGNED") @@ -1056,17 +1060,15 @@ def _eigh_batching_rule( batching.primitive_batchers[eigh_p] = _eigh_batching_rule mlir.register_lowering( - eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_hlo, platform='cpu'), + eigh_p, partial(_eigh_cpu_gpu_lowering, target_name_prefix='cpu'), platform='cpu') if gpu_solver is not None: mlir.register_lowering( - eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.cuda_syevd, - platform='cuda'), + eigh_p, partial(_eigh_cpu_gpu_lowering, target_name_prefix='cu'), platform='cuda') mlir.register_lowering( - eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.rocm_syevd, - platform='rocm'), + eigh_p, partial(_eigh_cpu_gpu_lowering, target_name_prefix='hip'), platform='rocm') mlir.register_lowering( @@ -1303,13 +1305,9 @@ def _lu_pivots_to_permutation_batching_rule(batched_args, batch_dims, *, def _lu_pivots_to_permutation_gpu_lowering(platform, ctx, pivots, *, permutation_size): + del permutation_size # unused rule = ffi.ffi_lowering(f"{platform}_lu_pivots_to_permutation") - # TODO(b/358275922): remove unused once jaxlib v0.4.32 is the minimum version. - if ctx.is_forward_compat() or jaxlib_version < (0, 4, 32): - kwargs = dict(permutation_size=np.int32(permutation_size)) - else: - kwargs = {} - return rule(ctx, pivots, **kwargs) + return rule(ctx, pivots) lu_pivots_to_permutation_p = Primitive('lu_pivots_to_permutation') @@ -1484,47 +1482,29 @@ def _lu_batching_rule(batched_args, batch_dims): x = batching.moveaxis(x, bd, 0) return lu_p.bind(x), (0, 0, 0) -def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand, *, platform: str, - target_name_prefix: str): +def _lu_cpu_gpu_lowering(ctx, operand, *, target_name_prefix: str): operand_aval, = ctx.avals_in out_aval, pivot_aval, perm_aval = ctx.avals_out batch_dims = operand_aval.shape[:-2] info_aval = ShapedArray(batch_dims, np.dtype(np.int32)) m = operand_aval.shape[-2] - # TODO(b/357034884): Remove version gate on the forward compat flag after the - # 3 week compatibility window. - if ctx.is_forward_compat(): - if not is_constant_shape(operand_aval.shape[-2:]): - raise NotImplementedError( - "Shape polymorphism for native lowering for lu on CPU and GPU is " - f"implemented only for the batch dimensions: {operand_aval.shape}") - if platform in ["cuda", "rocm"]: - if not is_constant_shape(operand_aval.shape): - raise NotImplementedError( - "Shape polymorphism for native serialization for lu on GPU is not " - f"implemented; b/261671778; {operand_aval.shape}") - lu, pivot, info = getrf_impl(operand_aval.dtype, operand) - else: - op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - lu, pivot, info = getrf_impl( - operand_aval.dtype, operand, a_shape_vals=op_shape_vals) + if target_name_prefix == "cpu": + target_name = lapack.prepare_lapack_call("getrf_ffi", operand_aval.dtype) else: - if target_name_prefix == "cpu": - target_name = lapack.prepare_lapack_call("getrf_ffi", operand_aval.dtype) - else: - target_name = f"{target_name_prefix}solver_getrf_ffi" - # We manually construct the layouts because the input and output are - # expected to be in Fortran order. - nb = len(batch_dims) - layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1)) - result_layouts = [layout, tuple(range(nb, -1, -1)), - tuple(range(nb - 1, -1, -1))] - rule = ffi.ffi_lowering(target_name, operand_layouts=[layout], - result_layouts=result_layouts, - operand_output_aliases={0: 0}) - sub_ctx = ctx.replace(avals_out=[out_aval, pivot_aval, info_aval]) - lu, pivot, info = rule(sub_ctx, operand) + target_name = f"{target_name_prefix}solver_getrf_ffi" + + # We manually construct the layouts because the input and output are + # expected to be in Fortran order. + nb = len(batch_dims) + layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1)) + result_layouts = [layout, tuple(range(nb, -1, -1)), + tuple(range(nb - 1, -1, -1))] + rule = ffi.ffi_lowering(target_name, operand_layouts=[layout], + result_layouts=result_layouts, + operand_output_aliases={0: 0}) + sub_ctx = ctx.replace(avals_out=[out_aval, pivot_aval, info_aval]) + lu, pivot, info = rule(sub_ctx, operand) # Subtract 1 from the pivot to get 0-based indices. pivot = hlo.subtract(pivot, mlir.full_like_aval(ctx, 1, pivot_aval)) @@ -1571,19 +1551,16 @@ def _lu_tpu_lowering_rule(ctx, operand): ad.primitive_jvps[lu_p] = _lu_jvp_rule batching.primitive_batchers[lu_p] = _lu_batching_rule -mlir.register_lowering(lu_p, - partial(_lu_cpu_gpu_lowering, lapack.getrf_hlo, - platform='cpu', target_name_prefix="cpu"), - platform='cpu') +mlir.register_lowering( + lu_p, partial(_lu_cpu_gpu_lowering, target_name_prefix="cpu"), + platform="cpu") mlir.register_lowering( - lu_p, partial(_lu_cpu_gpu_lowering, gpu_solver.cuda_getrf, - platform='cuda', target_name_prefix="cu"), - platform='cuda') + lu_p, partial(_lu_cpu_gpu_lowering, target_name_prefix="cu"), + platform="cuda") mlir.register_lowering( - lu_p, partial(_lu_cpu_gpu_lowering, gpu_solver.rocm_getrf, - platform='rocm', target_name_prefix="hip"), - platform='rocm') + lu_p, partial(_lu_cpu_gpu_lowering, target_name_prefix="hip"), + platform="rocm") mlir.register_lowering(lu_p, _lu_tpu_lowering_rule, platform='tpu') @@ -1601,7 +1578,7 @@ def _lu_solve_core(lu: Array, permutation: Array, b: Array, trans: int) -> Array conjugate_a=conj) x = triangular_solve(lu, x, left_side=True, lower=True, unit_diagonal=True, transpose_a=True, conjugate_a=conj) - _, ind = lax.sort_key_val(permutation, lax.iota('int32', len(permutation))) + _, ind = lax.sort_key_val(permutation, lax.iota('int32', permutation.shape[0])) x = x[ind, :] else: raise ValueError(f"'trans' value must be 0, 1, or 2, got {trans}") @@ -1678,7 +1655,7 @@ def _geqrf_abstract_eval(operand): if operand.ndim < 2: raise ValueError("Argument to QR decomposition must have ndims >= 2") *batch_dims, m, n = operand.shape - taus = operand.update(shape=(*batch_dims, min(m, n))) + taus = operand.update(shape=(*batch_dims, core.min_dim(m, n))) return operand, taus def _geqrf_batching_rule(batched_args, batch_dims): @@ -1707,60 +1684,20 @@ def _geqrf_lowering_rule(ctx, operand): ) return op.results -def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a, *, - platform: str): - a_aval, taus_aval = ctx.avals_out - *batch_dims, m, n = a_aval.shape - # It should be possible to support fully-dynamic shapes, but since - # the last two dimensions (m, n) are used in more involved ways, we only - # support dynamic dimensions for the batch size for now. - if not is_constant_shape([m, n]): - raise NotImplementedError( - "Shape polymorphism for native serialization for qr on CPU and GPU is " - f"implemented only for the batch dimensions: {a_aval.shape}") - batch = math.prod(batch_dims) - - if batch == 0 or m == 0 or n == 0: - return mlir.full_like_aval(ctx, 0, a_aval), mlir.full_like_aval(ctx, 0, taus_aval) - - if not is_constant_shape(a_aval.shape): - if platform in ["cuda", "rocm"]: - # TODO(necula): remove the platform kwarg when we implement GPU support. - raise NotImplementedError( - "Shape polymorphism for native serialization for QR is not " - f"implemented, try to upgrade jaxlib; b/261671778; {a_aval.shape}") - - if (batched_geqrf_impl is not None and batch > 1 and m // batch <= 128 and - n // batch <= 128): - a_out, taus = batched_geqrf_impl(a_aval.dtype, a) +def _geqrf_cpu_gpu_lowering(ctx, a, *, target_name_prefix: str): + operand_aval, = ctx.avals_in + batch_dims = operand_aval.shape[:-2] + nb = len(batch_dims) + layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1)) + result_layouts = [layout, tuple(range(nb, -1, -1))] + if target_name_prefix == "cpu": + target_name = lapack.prepare_lapack_call("geqrf_ffi", operand_aval.dtype) else: - if platform in ["cuda", "rocm"]: - a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a) - else: - a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape) - ctx_args = ( - (ctx,) if platform == "cpu" else () - ) - a_out, taus, *maybe_info_geqrf = geqrf_impl( - *ctx_args, a_aval.dtype, a, a_shape_vals=a_shape_vals - ) - if not ctx.is_forward_compat(): - # Skip the info parameter verification for the FFI kernel. - return a_out, taus - # TODO(b/344892332): This parameter will no longer be needed after - # the forward compatibility period - info_geqrf = maybe_info_geqrf[0] - zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))) - ok = mlir.compare_hlo(info_geqrf, zeros, "EQ", "SIGNED") - select_ok_a_aval = ShapedArray(batch_dims + [1, 1], np.dtype(np.bool_)) - ok_a = mlir.broadcast_in_dim(ctx, ok, select_ok_a_aval, - broadcast_dimensions=range(len(batch_dims))) - a_out = _broadcasting_select_hlo(ctx, ok_a, select_ok_a_aval, a_out, a_aval, _nan_like_hlo(ctx, a_aval), a_aval) - select_ok_taus_aval = ShapedArray(batch_dims + [1], np.dtype(np.bool_)) - ok_taus = mlir.broadcast_in_dim(ctx, ok, select_ok_taus_aval, - broadcast_dimensions=range(len(batch_dims))) - taus = _broadcasting_select_hlo(ctx, ok_taus, select_ok_taus_aval, taus, taus_aval, _nan_like_hlo(ctx, taus_aval), taus_aval) - return a_out, taus + target_name = f"{target_name_prefix}solver_geqrf_ffi" + rule = ffi.ffi_lowering(target_name, operand_layouts=[layout], + result_layouts=result_layouts, + operand_output_aliases={0: 0}) + return rule(ctx, a) geqrf_p = Primitive('geqrf') geqrf_p.multiple_results = True @@ -1770,20 +1707,15 @@ def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a, *, mlir.register_lowering(geqrf_p, _geqrf_lowering_rule) mlir.register_lowering( - geqrf_p, partial(_geqrf_cpu_gpu_lowering, lapack.geqrf_hlo, None, - platform='cpu'), + geqrf_p, partial(_geqrf_cpu_gpu_lowering, target_name_prefix='cpu'), platform='cpu') mlir.register_lowering( geqrf_p, - partial(_geqrf_cpu_gpu_lowering, gpu_solver.cuda_geqrf, - gpu_solver.cuda_geqrf_batched, - platform='cuda'), + partial(_geqrf_cpu_gpu_lowering, target_name_prefix='cu'), platform='cuda') mlir.register_lowering( geqrf_p, - partial(_geqrf_cpu_gpu_lowering, gpu_solver.rocm_geqrf, - gpu_solver.rocm_geqrf_batched, - platform='rocm'), + partial(_geqrf_cpu_gpu_lowering, target_name_prefix='hip'), platform='rocm') @@ -1813,7 +1745,7 @@ def _householder_product_abstract_eval(a, taus): raise ValueError("Argument to Householder product must have ndims >= 2") *batch_dims, m, n = a.shape *taus_batch_dims, k = taus.shape - if a.dtype != taus.dtype or batch_dims != taus_batch_dims or k > min(m, n): + if a.dtype != taus.dtype or batch_dims != taus_batch_dims or k > core.min_dim(m, n): raise ValueError(f"Type mismatch for Householder product: {a=} {taus=}") if m < n: raise ValueError("Householder product inputs must have at least as many " @@ -1841,48 +1773,23 @@ def _householder_product_lowering_rule(ctx, a, taus): result_shapes=result_shapes) return [op.result] -def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus, *, - platform: str): - a_aval, taus_aval = ctx.avals_in - *batch_dims, m, n = a_aval.shape - if not is_constant_shape([m, n]): - raise NotImplementedError( - "Shape polymorphism for native serialization for householder_product on " - f"CPU and GPU is implemented only for the batch dimensions: {a_aval.shape}") - - if m == 0 or n == 0: - return [mlir.full_like_aval(ctx, 0, a_aval)] - - if platform in ["rocm", "cuda"]: - # TODO(necula): remove the platform kwarg when we implement GPU support. - if not is_constant_shape(a_aval.shape): - raise NotImplementedError( - "Shape polymorphism for native serialization for householder_product " - f"on GPU is not implemented; b/261671778; {a_aval.shape}") - a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus) +def _householder_product_cpu_gpu_lowering(ctx, a, taus, *, + target_name_prefix: str): + a_aval, _ = ctx.avals_in + batch_dims = a_aval.shape[:-2] + nb = len(batch_dims) + layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1)) + tau_layout = tuple(range(nb, -1, -1)) + if target_name_prefix == "cpu": + dtype = a_aval.dtype + prefix = "un" if dtypes.issubdtype(dtype, np.complexfloating) else "or" + target_name = lapack.prepare_lapack_call(f"{prefix}gqr_ffi", dtype) else: - ctx_args = ( - (ctx,) if platform == "cpu" else () - ) - a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape) - tau_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, taus_aval.shape) - a, *maybe_info_orgqr = orgqr_impl(*ctx_args, a_aval.dtype, a, taus, - a_shape_vals=a_shape_vals, - tau_shape_vals=tau_shape_vals) - if not ctx.is_forward_compat(): - # Skip the info parameter verification for the FFI kernel. - return [a] - # TODO(b/344892332): This parameter will no longer be needed after - # the forward compatibility period - info_orgqr = maybe_info_orgqr[0] - zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))) - ok = mlir.compare_hlo(info_orgqr, zeros, "EQ", "SIGNED") - select_a_aval = ShapedArray(batch_dims + [1, 1], np.dtype(np.bool_)) - ok = mlir.broadcast_in_dim(ctx, ok, select_a_aval, - broadcast_dimensions=range(len(batch_dims))) - a = _broadcasting_select_hlo(ctx, ok, select_a_aval, a, a_aval, _nan_like_hlo(ctx, a_aval), a_aval) - return [a] - + target_name = f"{target_name_prefix}solver_orgqr_ffi" + rule = ffi.ffi_lowering(target_name, operand_layouts=[layout, tau_layout], + result_layouts=[layout], + operand_output_aliases={0: 0}) + return rule(ctx, a, taus) householder_product_p = Primitive('householder_product') householder_product_p.def_impl(partial(dispatch.apply_primitive, householder_product_p)) @@ -1892,18 +1799,15 @@ def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus, *, mlir.register_lowering( householder_product_p, - partial(_householder_product_cpu_gpu_lowering, lapack.orgqr_hlo, - platform='cpu'), + partial(_householder_product_cpu_gpu_lowering, target_name_prefix='cpu'), platform='cpu') mlir.register_lowering( householder_product_p, - partial(_householder_product_cpu_gpu_lowering, gpu_solver.cuda_orgqr, - platform='cuda'), + partial(_householder_product_cpu_gpu_lowering, target_name_prefix='cu'), platform='cuda') mlir.register_lowering( householder_product_p, - partial(_householder_product_cpu_gpu_lowering, gpu_solver.rocm_orgqr, - platform='rocm'), + partial(_householder_product_cpu_gpu_lowering, target_name_prefix='hip'), platform='rocm') @@ -1916,7 +1820,7 @@ def _qr_abstract_eval(operand, *, full_matrices): if operand.ndim < 2: raise ValueError("Argument to QR decomposition must have ndims >= 2") *batch_dims, m, n = operand.shape - k = m if full_matrices else min(m, n) + k = m if full_matrices else core.min_dim(m, n) q = operand.update(shape=(*batch_dims, m, k)) r = operand.update(shape=(*batch_dims, k, n)) else: @@ -1953,7 +1857,7 @@ def _qr_batching_rule(batched_args, batch_dims, *, full_matrices): def _qr_lowering(a, *, full_matrices): *batch_dims, m, n = a.shape if m == 0 or n == 0: - k = m if full_matrices else min(m, n) + k = m if full_matrices else core.min_dim(m, n) q = lax.broadcast_in_dim(lax_internal._eye(a.dtype, (m, k)), (*batch_dims, m, k), (len(batch_dims), len(batch_dims) + 1)) @@ -2131,8 +2035,7 @@ def _svd_cpu_gpu_lowering( compute_uv=compute_uv) else: a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - ctx_args = (ctx,) - s, u, vt, info = gesvd_impl(*ctx_args, operand_aval.dtype, operand, + s, u, vt, info = gesvd_impl(ctx, operand_aval.dtype, operand, full_matrices=full_matrices, compute_uv=compute_uv, a_shape_vals=a_shape_vals) @@ -2575,7 +2478,7 @@ def _hessenberg_batching_rule(batched_args, batch_dims): def _hessenberg_cpu_hlo(ctx, a): a_aval, = ctx.avals_in batch_dims = a_aval.shape[:-2] - a, taus, info = lapack.gehrd_hlo(a_aval.dtype, a) + a, taus, info = lapack.gehrd_hlo(ctx, a_aval.dtype, a) ok = mlir.compare_hlo( info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))), "EQ", "SIGNED") diff --git a/jax/_src/lax/other.py b/jax/_src/lax/other.py index 67f274e829ff..00e15ef6a91d 100644 --- a/jax/_src/lax/other.py +++ b/jax/_src/lax/other.py @@ -205,11 +205,13 @@ def conv_general_dilated_local( lhs_array = lax.asarray(lhs) c_precision = lax.canonicalize_precision(precision) - lhs_precision = ( - c_precision[0] - if (isinstance(c_precision, tuple) and len(c_precision) == 2) - else c_precision - ) + if c_precision is None: + lhs_precision = None + elif isinstance(c_precision, tuple) and len(c_precision) == 2: + lhs_precision = c_precision[0] + else: + raise ValueError( + f"Unsupported precision for conv_general_dilated_local: {precision}") patches = conv_general_dilated_patches( lhs=lhs_array, diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index c9a07072ddc7..9d4614f344fb 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -1500,7 +1500,8 @@ def _pgather_impl(src, idx, *, axes): dnums = slicing.GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=(0,), - start_index_map=(0,)) + start_index_map=(0,), + ) return slicing.gather(src_one_axis_front, idx, dimension_numbers=dnums, slice_sizes=tuple(slice_sizes)) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 60dfa0e1b3d2..6e7aab7a1b02 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -233,6 +233,16 @@ class GatherDimensionNumbers(NamedTuple): start_index_map: for each dimension in `start_indices`, gives the corresponding dimension in the `operand` that is to be sliced. Must be a tuple of integers with size equal to `start_indices.shape[-1]`. + operand_batching_dims: the set of batching dimensions `i` in `operand` that + have `slice_sizes[i] == 1` and that should have a corresponding dimension + in both the `start_indices` (at the same index in + `start_indices_batching_dims`) and output of the gather. Must be a tuple + of integers in ascending order. + start_indices_batching_dims: the set of batching dimensions `i` in + `start_indices` that should have a corresponding dimension in both the + `operand` (at the same index in `operand_batching_dims`) and output of the + gather. Must be a tuple of integers (order is fixed based on + correspondence with `operand_batching_dims`). Unlike XLA's `GatherDimensionNumbers` structure, `index_vector_dim` is implicit; there is always an index vector dimension and it must always be the @@ -241,6 +251,8 @@ class GatherDimensionNumbers(NamedTuple): offset_dims: tuple[int, ...] collapsed_slice_dims: tuple[int, ...] start_index_map: tuple[int, ...] + operand_batching_dims: tuple[int, ...] = () + start_indices_batching_dims: tuple[int, ...] = () class GatherScatterMode(enum.Enum): @@ -267,6 +279,7 @@ class GatherScatterMode(enum.Enum): CLIP = enum.auto() FILL_OR_DROP = enum.auto() PROMISE_IN_BOUNDS = enum.auto() + ONE_HOT = enum.auto() @staticmethod def from_any(s: str | GatherScatterMode | None): @@ -278,6 +291,8 @@ def from_any(s: str | GatherScatterMode | None): return GatherScatterMode.FILL_OR_DROP if s == "promise_in_bounds": return GatherScatterMode.PROMISE_IN_BOUNDS + if s == "one_hot": + return GatherScatterMode.ONE_HOT else: raise ValueError(f'Unknown gather mode "{s}"') @@ -370,6 +385,17 @@ class ScatterDimensionNumbers(NamedTuple): scatter_dims_to_operand_dims: for each dimension in `scatter_indices`, gives the corresponding dimension in `operand`. Must be a sequence of integers with size equal to `scatter_indices.shape[-1]`. + operand_batching_dims: the set of batching dimensions `i` in `operand` that + should have a corresponding dimension in both the `scatter_indices` (at + the same index in `scatter_indices_batching_dims`) and `updates`. Must be + a tuple of integers in ascending order. These are the mirror image of + `operand_batching_dims` in the case of `gather`. + scatter_indices_batching_dims: the set of batching dimensions `i` in + `scatter_indices` that should have a corresponding dimension in both the + `operand` (at the same index in `operand_batching_dims`) and output of the + gather. Must be a tuple of integers (order is fixed based on + correspondence with `input_batching_dims`). These are the mirror image of + `start_indices_batching_dims` in the case of `gather`. Unlike XLA's `ScatterDimensionNumbers` structure, `index_vector_dim` is implicit; there is always an index vector dimension and it must always be the @@ -378,6 +404,8 @@ class ScatterDimensionNumbers(NamedTuple): update_window_dims: Sequence[int] inserted_window_dims: Sequence[int] scatter_dims_to_operand_dims: Sequence[int] + operand_batching_dims: Sequence[int] = () + scatter_indices_batching_dims: Sequence[int] = () def scatter_add( operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike, @@ -426,6 +454,67 @@ def scatter_add( indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=GatherScatterMode.from_any(mode)) + +def scatter_sub( + operand: ArrayLike, + scatter_indices: ArrayLike, + updates: ArrayLike, + dimension_numbers: ScatterDimensionNumbers, + *, + indices_are_sorted: bool = False, + unique_indices: bool = False, + mode: str | GatherScatterMode | None = None, +) -> Array: + """Scatter-sub operator. + + Wraps `XLA's Scatter operator + `_, where + subtraction is used to combine updates and values from `operand`. + + The semantics of scatter are complicated, and its API might change in the + future. For most use cases, you should prefer the + :attr:`jax.numpy.ndarray.at` property on JAX arrays which uses + the familiar NumPy indexing syntax. + + Args: + operand: an array to which the scatter should be applied + scatter_indices: an array that gives the indices in `operand` to which each + update in `updates` should be applied. + updates: the updates that should be scattered onto `operand`. + dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes how + dimensions of `operand`, `start_indices`, `updates` and the output relate. + indices_are_sorted: whether `scatter_indices` is known to be sorted. If + true, may improve performance on some backends. + unique_indices: whether the elements to be updated in ``operand`` are + guaranteed to not overlap with each other. If true, may improve + performance on some backends. JAX does not check this promise: if the + updated elements overlap when ``unique_indices`` is ``True`` the behavior + is undefined. + mode: how to handle indices that are out of bounds: when set to 'clip', + indices are clamped so that the slice is within bounds, and when set to + 'fill' or 'drop' out-of-bounds updates are dropped. The behavior for + out-of-bounds indices when set to 'promise_in_bounds' is + implementation-defined. + + Returns: + An array containing the sum of `operand` and the scattered updates. + """ + jaxpr, consts = lax._reduction_jaxpr( + lax.sub, lax._abstractify(lax._const(operand, 0)) + ) + return scatter_sub_p.bind( + operand, + scatter_indices, + updates, + update_jaxpr=jaxpr, + update_consts=consts, + dimension_numbers=dimension_numbers, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, + mode=GatherScatterMode.from_any(mode), + ) + + def scatter_mul( operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike, dimension_numbers: ScatterDimensionNumbers, *, @@ -694,7 +783,8 @@ def index_take(src: Array, idxs: Array, axes: Sequence[int]) -> Array: dnums = GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=tuple(axes), - start_index_map=tuple(axes)) + start_index_map=tuple(axes), + ) return gather(src, indices, dimension_numbers=dnums, slice_sizes=tuple(slice_sizes)) @@ -1256,8 +1346,11 @@ def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes): dims = tuple(range(ndims)) start_indices, dyn_slice_sizes = util.split_list(start_indices_and_dyn, [ndims]) start_idx_bds, dyn_slice_size_bds = util.split_list(start_idx_and_dyn_bds, [ndims]) - dnums = GatherDimensionNumbers(offset_dims=dims, collapsed_slice_dims=(), - start_index_map=dims) + dnums = GatherDimensionNumbers( + offset_dims=dims, + collapsed_slice_dims=(), + start_index_map=dims, + ) index, index_bdim = _batch_dynamic_slice_indices(start_indices, start_idx_bds) return _gather_batching_rule( [operand, index, *dyn_slice_sizes], @@ -1396,9 +1489,11 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims): update_shape = (np.shape(update) if update_bd is batching.not_mapped else tuple(np.delete(np.shape(update), update_bd))) dims = tuple(range(len(update_shape))) - dnums = ScatterDimensionNumbers(update_window_dims=dims, - inserted_window_dims=(), - scatter_dims_to_operand_dims=dims) + dnums = ScatterDimensionNumbers( + update_window_dims=dims, + inserted_window_dims=(), + scatter_dims_to_operand_dims=dims, + ) index, index_bdim = _batch_dynamic_slice_indices(start_idx, start_idx_bd) return api.vmap( partial(scatter, dimension_numbers=dnums, @@ -1437,6 +1532,12 @@ def _is_sorted(dims, op_name, name): if dims[i] < dims[i - 1]: raise TypeError(f"{name} in {op_name} op must be sorted; got {dims}") +def _dims_in_range(dims, rank, op_name, name): + for dim in dims: + if dim < 0 or dim >= rank: + raise TypeError(f"Invalid {name} set in {op_name} op; valid range is " + f"[0, {rank}); got: {dim}.") + def _sorted_dims_in_range(dims, rank, op_name, name): if len(dims) == 0: return @@ -1453,6 +1554,11 @@ def _no_duplicate_dims(dims, op_name, name): if len(set(dims)) != len(dims): raise TypeError(f"{name} in {op_name} op must not repeat; got: {dims}.") +def _disjoint_dims(dims1, dims2, op_name, name1, name2): + if not set(dims1).isdisjoint(set(dims2)): + raise TypeError(f"{name1} and {name2} in {op_name} op must be disjoint; " + f"got: {dims1} and {dims2}.") + def _gather_shape_rule(operand, indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): @@ -1466,6 +1572,8 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers, offset_dims = dimension_numbers.offset_dims collapsed_slice_dims = dimension_numbers.collapsed_slice_dims + operand_batching_dims = dimension_numbers.operand_batching_dims + start_indices_batching_dims = dimension_numbers.start_indices_batching_dims start_index_map = dimension_numbers.start_index_map # Note: in JAX, index_vector_dim is always computed as below, cf. the @@ -1521,6 +1629,50 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers, _sorted_dims_in_range(collapsed_slice_dims, _rank(operand), "gather", "collapsed_slice_dims") _no_duplicate_dims(collapsed_slice_dims, "gather", "collapsed_slice_dims") + + _no_duplicate_dims(operand_batching_dims, "gather", "operand_batching_dims") + _is_sorted(operand_batching_dims, "gather", "operand_batching_dims") + _sorted_dims_in_range( + operand_batching_dims, _rank(operand), "gather", "operand_batching_dims" + ) + + _disjoint_dims(collapsed_slice_dims, operand_batching_dims, "gather", + "collapsed_slice_dims", "operand_batching_dims") + _disjoint_dims(start_index_map, operand_batching_dims, "gather", + "start_index_map", "operand_batching_dims") + + _no_duplicate_dims( + start_indices_batching_dims, "gather", "start_indices_batching_dims" + ) + _dims_in_range( + start_indices_batching_dims, + _rank(indices), + "gather", + "start_indices_batching_dims", + ) + if index_vector_dim in start_indices_batching_dims: + raise TypeError( + "Gather op cannot have the index vector dimension as a batching " + f"dimension; got {start_indices_batching_dims}." + ) + + if len(operand_batching_dims) != len(start_indices_batching_dims): + raise TypeError( + "Gather op requires equal numbers of operand_batching_dims and " + f"start_indices_batching_dims, got {operand_batching_dims} and" + f"{start_indices_batching_dims}." + ) + + operand_batch_shape = tuple(operand.shape[i] for i in operand_batching_dims) + indices_batch_shape = tuple( + indices.shape[i] for i in start_indices_batching_dims + ) + if not core.definitely_equal_shape(operand_batch_shape, indices_batch_shape): + raise TypeError( + "Gather op requires operand batching dimensions and indices batching " + f"dimensions to have the same shape, got {operand_batch_shape} and " + f"{indices_batch_shape}." + ) # End ValidateGatherDimensions if _rank(operand) != len(slice_sizes): @@ -1528,12 +1680,17 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers, f"dimension; got: len(slice_sizes)={len(slice_sizes)}, " f"input_shape.rank={_rank(operand)}") - if len(slice_sizes) != len(offset_dims) + len(collapsed_slice_dims): - raise TypeError(f"All components of the offset index in a gather op must " - f"either be a offset dimension or explicitly collapsed; " - f"got len(slice_sizes)={len(slice_sizes)}, " - f"output_slice_sizes={offset_dims}, collapsed_slice_dims=" - f"{collapsed_slice_dims}.") + if len(slice_sizes) != len(offset_dims) + len(collapsed_slice_dims) + len( + operand_batching_dims + ): + raise TypeError( + "All components of the offset index in a gather op must " + "either be a offset dimension or explicitly collapsed/batching; " + f"got len(slice_sizes)={len(slice_sizes)}, " + f"output_slice_sizes={offset_dims}, collapsed_slice_dims=" + f"{collapsed_slice_dims}, operand_batching_dims=" + f"{operand_batching_dims}." + ) for i in range(len(slice_sizes)): slice_size = slice_sizes[i] @@ -1552,12 +1709,21 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers, f"but bound is {bound} for index " f"{collapsed_slice_dims[i]} at position {i}.") + for i in range(len(operand_batching_dims)): + bound = slice_sizes[operand_batching_dims[i]] + if bound > 1: + raise TypeError(f"Gather op can only have operand batching dims with " + f"bound 0/1, but bound is {bound} for index " + f"{operand_batching_dims[i]} at position {i}." + ) + return _gather_shape_computation(indices, dimension_numbers, slice_sizes) def _gather_shape_computation(indices, dimension_numbers, slice_sizes): offset_dims = dimension_numbers.offset_dims collapsed_slice_dims = dimension_numbers.collapsed_slice_dims + operand_batching_dims = dimension_numbers.operand_batching_dims output_shape_rank = len(offset_dims) + _rank(indices) - 1 index_vector_dim = _rank(indices) - 1 @@ -1572,8 +1738,11 @@ def _gather_shape_computation(indices, dimension_numbers, slice_sizes): indices_shape_gen = iter(expanded_indices_shape) - slice_sizes_gen = (s for i, s in enumerate(slice_sizes) - if i not in collapsed_slice_dims) + slice_sizes_gen = ( + s + for i, s in enumerate(slice_sizes) + if i not in collapsed_slice_dims and i not in operand_batching_dims + ) ans = tuple(next(slice_sizes_gen) if i in offset_dims else next(indices_shape_gen) for i in range(output_shape_rank)) return ans @@ -1631,9 +1800,12 @@ def _gather_transpose_rule(t, operand, indices, *, dimension_numbers, else: zeros = lax.full(operand_shape, lax._zero(t)) scatter_dnums = ScatterDimensionNumbers( - update_window_dims=dimension_numbers.offset_dims, - inserted_window_dims=dimension_numbers.collapsed_slice_dims, - scatter_dims_to_operand_dims=dimension_numbers.start_index_map) + update_window_dims=dimension_numbers.offset_dims, + inserted_window_dims=dimension_numbers.collapsed_slice_dims, + scatter_dims_to_operand_dims=dimension_numbers.start_index_map, + operand_batching_dims=dimension_numbers.operand_batching_dims, + scatter_indices_batching_dims=dimension_numbers.start_indices_batching_dims, + ) out = scatter_add(zeros, indices, t, scatter_dnums, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, @@ -1652,11 +1824,17 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, slice_sizes = (operand.shape[0],) + slice_sizes offset_dims = (0,) + tuple(np.add(1, dimension_numbers.offset_dims)) collapsed_slice_dims = tuple(np.add(1, dimension_numbers.collapsed_slice_dims)) + operand_batching_dims = tuple( + np.add(1, dimension_numbers.operand_batching_dims) + ) start_index_map = tuple(np.add(1, dimension_numbers.start_index_map)) dnums = GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=collapsed_slice_dims, - start_index_map=start_index_map) + start_index_map=start_index_map, + operand_batching_dims=operand_batching_dims, + start_indices_batching_dims=dimension_numbers.start_indices_batching_dims, + ) if isinstance(operand_bdim, batching.RaggedAxis): ragged_slice_sizes = batching.bdim_as_shape(operand_bdim, slice_sizes) for orig, fabricated in zip( @@ -1687,10 +1865,16 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, elif operand_bdim is None and indices_bdim is not None: indices = batching.moveaxis(indices, indices_bdim, 0) offset_dims = tuple(1 + d for d in dimension_numbers.offset_dims) + start_indices_batching_dims = tuple( + np.add(1, dimension_numbers.start_indices_batching_dims) + ) dnums = GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=dimension_numbers.collapsed_slice_dims, - start_index_map=dimension_numbers.start_index_map) + start_index_map=dimension_numbers.start_index_map, + operand_batching_dims=dimension_numbers.operand_batching_dims, + start_indices_batching_dims=start_indices_batching_dims, + ) # If batching indexed accesses into the same array, the batched gather may # no longer have sorted or unique indices. return gather(operand, indices, dimension_numbers=dnums, @@ -1702,61 +1886,34 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, operand = batching.moveaxis(operand, operand_bdim, 0) indices = batching.moveaxis(indices, indices_bdim, 0) - # This slightly awkward special case is needed because the shape rule for - # gather does not allow size-1 slices out of a size-0 dimension, even if - # the number of slices is zero. Likely the best fix would be to change the - # definition of gather() so it can be batched without the construction of - # an explicit iota of size-1 slices. if core.definitely_equal(operand.shape[0], 0): - output_shape = _gather_shape_rule( - core.ShapedArray(operand.shape[1:], operand.dtype), - core.ShapedArray(indices.shape[1:], - dtypes.canonicalize_dtype(indices.dtype)), - dimension_numbers=dimension_numbers, slice_sizes=slice_sizes, - unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, - mode=mode, fill_value=fill_value) - return lax.full((0,) + output_shape, lax._zero(operand)), 0 - - # Example: user code had indices shape (3, 4, 5), and we have to deal with - # indices shape (7, 3, 4, 5). We transform that to indices of shape - # (7, 3, 4, 6) where we concatenated an iota that counts along our batch - # dimension to the front of the ndindex. - index_dtype = _promote_dtype_for_size(indices.dtype, indices.shape[0]) - count_shape = list(indices.shape) - count_shape[-1] = 1 - counts = lax.broadcasted_iota(index_dtype, tuple(count_shape), 0) - indices = lax.concatenate([counts, indices.astype(index_dtype)], - len(count_shape) - 1) - - slice_sizes = (1,) + slice_sizes - collapsed_slice_dims = (0,) + tuple(np.add(1, dimension_numbers.collapsed_slice_dims)) + slice_sizes = (0,) + slice_sizes + else: + slice_sizes = (1,) + slice_sizes + collapsed_slice_dims = tuple( + np.add(1, dimension_numbers.collapsed_slice_dims) + ) + operand_batching_dims = (0,) + tuple( + np.add(1, dimension_numbers.operand_batching_dims) + ) + start_indices_batching_dims = (0,) + tuple( + np.add(1, dimension_numbers.start_indices_batching_dims) + ) offset_dims = tuple(np.add(1, dimension_numbers.offset_dims)) - start_index_map = (0,) + tuple(np.add(1, dimension_numbers.start_index_map)) + start_index_map = tuple(np.add(1, dimension_numbers.start_index_map)) dnums = GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=collapsed_slice_dims, - start_index_map=start_index_map) + start_index_map=start_index_map, + operand_batching_dims=operand_batching_dims, + start_indices_batching_dims=start_indices_batching_dims, + ) return gather(operand, indices, dimension_numbers=dnums, slice_sizes=slice_sizes, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value), 0 -def _promote_dtype_for_size(dtype, size): - if not dtypes.issubdtype(dtype, np.integer): - return dtype - # size may be a dynamic shape, in which case we return at least int32 - try: - size = int(size) - except: - return dtype if np.iinfo(dtype).bits >= 32 else np.dtype('int32') - if size <= np.iinfo(dtype).max: - return dtype - elif size <= np.iinfo(np.int32).max: - return np.dtype('int32') - else: - return dtypes.canonicalize_dtype(np.int64) - def _gather_pad_rule(in_avals, out_avals, operand, indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): @@ -1821,8 +1978,10 @@ def _gather_lower(ctx, operand, indices, *, GatherScatterMode.CLIP), mode dnums = hlo.GatherDimensionNumbers.get( collapsed_slice_dims=list(dimension_numbers.collapsed_slice_dims), - operand_batching_dims=[], - start_indices_batching_dims=[], + operand_batching_dims=list(dimension_numbers.operand_batching_dims), + start_indices_batching_dims=list( + dimension_numbers.start_indices_batching_dims + ), index_vector_dim=len(ctx.avals_in[1].shape) - 1, offset_dims=list(dimension_numbers.offset_dims), start_index_map=list(dimension_numbers.start_index_map), @@ -1872,6 +2031,8 @@ def _scatter_shape_rule(operand, indices, updates, *, update_jaxpr, update_window_dims = dimension_numbers.update_window_dims inserted_window_dims = dimension_numbers.inserted_window_dims + operand_batching_dims = dimension_numbers.operand_batching_dims + scatter_indices_batching_dims = dimension_numbers.scatter_indices_batching_dims scatter_dims_to_operand_dims = dimension_numbers.scatter_dims_to_operand_dims # Note: in JAX, index_vector_dim is always computed as below, cf. the # documentation of the ScatterDimensionNumbers class. @@ -1909,8 +2070,55 @@ def _scatter_shape_rule(operand, indices, updates, *, update_jaxpr, _sorted_dims_in_range(inserted_window_dims, _rank(operand), "scatter", "inserted_window_dims") + # Validate operand_batching_dims and scatter_indices_batching_dims + _is_sorted(operand_batching_dims, "scatter", "operand_batching_dims") + _no_duplicate_dims(operand_batching_dims, "scatter", "operand_batching_dims") + _sorted_dims_in_range( + operand_batching_dims, _rank(operand), "scatter", "operand_batching_dims" + ) + _disjoint_dims(inserted_window_dims, operand_batching_dims, "scatter", + "inserted_window_dims", "operand_batching_dims") + _disjoint_dims(scatter_dims_to_operand_dims, operand_batching_dims, "scatter", + "scatter_dims_to_operand_dims", "operand_batching_dims") + + _no_duplicate_dims( + scatter_indices_batching_dims, "scatter", "scatter_indices_batching_dims" + ) + _dims_in_range( + scatter_indices_batching_dims, + _rank(indices), + "scatter", + "scatter_indices_batching_dims", + ) + if index_vector_dim in scatter_indices_batching_dims: + raise TypeError( + "Scatter op cannot have the index vector dimension as a batching " + f"dimension; got {scatter_indices_batching_dims}.") + + if len(operand_batching_dims) != len(scatter_indices_batching_dims): + raise TypeError( + "Scatter op requires equal numbers of operand_batching_dims and " + f"scatter_indices_batching_dims, got {operand_batching_dims} and " + f"{scatter_indices_batching_dims}." + ) + + operand_batch_shape = tuple(operand.shape[i] for i in operand_batching_dims) + indices_batch_shape = tuple( + indices.shape[i] for i in scatter_indices_batching_dims + ) + if not core.definitely_equal_shape(operand_batch_shape, indices_batch_shape): + raise TypeError( + "Scatter op requires operand batching dimensions and indices batching " + f"dimensions to have the same shape, got {operand_batch_shape} and " + f"{indices_batch_shape}." + ) + # Validate window_size - window_size = len(update_window_dims) + len(inserted_window_dims) + window_size = ( + len(update_window_dims) + + len(inserted_window_dims) + + len(operand_batching_dims) + ) if _rank(operand) != window_size: raise TypeError(f"Scatter op has window of size {window_size}; doesn't " f"match operand of rank {_rank(operand)}.") @@ -1933,8 +2141,14 @@ def _scatter_shape_rule(operand, indices, updates, *, update_jaxpr, _no_duplicate_dims(scatter_dims_to_operand_dims, "scatter", "scatter_dims_to_operand_dims") - max_update_slice_sizes = [operand.shape[i] for i in range(len(operand.shape)) - if not i in set(inserted_window_dims)] + max_update_slice_sizes = [ + operand.shape[i] + for i in range(len(operand.shape)) + if ( + i not in set(inserted_window_dims) + and i not in set(operand_batching_dims) + ) + ] for i in range(len(update_window_dims)): update_window_dim = update_window_dims[i] @@ -1968,7 +2182,7 @@ def _clamp_scatter_indices(operand, indices, updates, *, dnums): slice_sizes = [] pos = 0 for i in range(len(operand.shape)): - if i in dnums.inserted_window_dims: + if i in dnums.inserted_window_dims or i in dnums.operand_batching_dims: slice_sizes.append(1) else: slice_sizes.append(updates.shape[dnums.update_window_dims[pos]]) @@ -1988,32 +2202,66 @@ def _clamp_scatter_indices(operand, indices, updates, *, dnums): return lax.clamp(np.int64(0), lax.convert_element_type(indices, np.int64), upper_bound) -def _scatter_add_jvp(primals, tangents, *, update_jaxpr, update_consts, - dimension_numbers, indices_are_sorted, unique_indices, - mode): + +def _scatter_addsub_jvp( + prim, + primals, + tangents, + *, + update_jaxpr, + update_consts, + dimension_numbers, + indices_are_sorted, + unique_indices, + mode, +): operand, indices, updates = primals g_operand, g_indices, g_updates = tangents del g_indices # ignored - val_out = scatter_add_p.bind( - operand, indices, updates, update_jaxpr=update_jaxpr, - update_consts=update_consts, dimension_numbers=dimension_numbers, - indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, - mode=mode) + val_out = prim.bind( + operand, + indices, + updates, + update_jaxpr=update_jaxpr, + update_consts=update_consts, + dimension_numbers=dimension_numbers, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, + mode=mode, + ) if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero: tangent_out = ad_util.Zero.from_primal_value(val_out) else: g_operand = ad.instantiate_zeros(g_operand) g_updates = ad.instantiate_zeros(g_updates) - tangent_out = scatter_add_p.bind( - g_operand, indices, g_updates, update_jaxpr=update_jaxpr, - update_consts=update_consts, dimension_numbers=dimension_numbers, - indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, - mode=mode) + tangent_out = prim.bind( + g_operand, + indices, + g_updates, + update_jaxpr=update_jaxpr, + update_consts=update_consts, + dimension_numbers=dimension_numbers, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, + mode=mode, + ) return val_out, tangent_out -def _scatter_add_transpose_rule(t, operand, indices, updates, *, - update_jaxpr, update_consts, dimension_numbers, - indices_are_sorted, unique_indices, mode): + +def _scatter_addsub_transpose_rule( + prim, + t, + operand, + indices, + updates, + *, + update_jaxpr, + update_consts, + dimension_numbers, + indices_are_sorted, + unique_indices, + mode, +): assert not ad.is_undefined_primal(indices) if ad.is_undefined_primal(updates): updates_shape = updates.aval.shape @@ -2029,19 +2277,27 @@ def _scatter_add_transpose_rule(t, operand, indices, updates, *, if ad.is_undefined_primal(updates): gather_dnums = GatherDimensionNumbers( - offset_dims=dimension_numbers.update_window_dims, - collapsed_slice_dims=dimension_numbers.inserted_window_dims, - start_index_map=dimension_numbers.scatter_dims_to_operand_dims) + offset_dims=dimension_numbers.update_window_dims, + collapsed_slice_dims=dimension_numbers.inserted_window_dims, + start_index_map=dimension_numbers.scatter_dims_to_operand_dims, + operand_batching_dims=dimension_numbers.operand_batching_dims, + start_indices_batching_dims=dimension_numbers.scatter_indices_batching_dims, + ) slice_sizes = [] pos = 0 for i in range(len(t.shape)): - if i in dimension_numbers.inserted_window_dims: + if ( + i in dimension_numbers.inserted_window_dims + or i in dimension_numbers.operand_batching_dims + ): slice_sizes.append(1) else: slice_sizes.append(updates_shape[dimension_numbers.update_window_dims[pos]]) pos += 1 update_t = gather(t, indices, dimension_numbers=gather_dnums, slice_sizes=slice_sizes, mode=mode, fill_value=0) + if prim is scatter_sub_p: + update_t = lax.neg(update_t) return [operand_t, None, update_t] def _scatter_mul_transpose_rule(t, operand, indices, updates, *, @@ -2067,13 +2323,19 @@ def _scatter_mul_transpose_rule(t, operand, indices, updates, *, raise NotImplementedError( "scatter_mul gradients are only implemented if `unique_indices=True`") gather_dnums = GatherDimensionNumbers( - offset_dims=dimension_numbers.update_window_dims, - collapsed_slice_dims=dimension_numbers.inserted_window_dims, - start_index_map=dimension_numbers.scatter_dims_to_operand_dims) + offset_dims=dimension_numbers.update_window_dims, + collapsed_slice_dims=dimension_numbers.inserted_window_dims, + start_index_map=dimension_numbers.scatter_dims_to_operand_dims, + operand_batching_dims=dimension_numbers.operand_batching_dims, + start_indices_batching_dims=dimension_numbers.scatter_indices_batching_dims, + ) slice_sizes = [] pos = 0 for i in range(len(t.shape)): - if i in dimension_numbers.inserted_window_dims: + if ( + i in dimension_numbers.inserted_window_dims + or i in dimension_numbers.operand_batching_dims + ): slice_sizes.append(1) else: slice_sizes.append(updates_shape[dimension_numbers.update_window_dims[pos]]) @@ -2095,40 +2357,52 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *, size = next(x.shape[ax] for x, ax in zip(batched_args, batch_dims) if ax is not None) operand = batching.bdim_at_front(operand, operand_bdim, size) - operand_bdim = 0 updates = batching.bdim_at_front(updates, updates_bdim, size) if indices_bdim is None: inserted_window_dims = tuple(np.add(1, dimension_numbers.inserted_window_dims)) update_window_dims = (0,) + tuple(np.add(1, dimension_numbers.update_window_dims)) + operand_batching_dims = tuple( + np.add(1, dimension_numbers.operand_batching_dims) + ) scatter_dims_to_operand_dims = tuple(np.add(1, dimension_numbers.scatter_dims_to_operand_dims)) dnums = ScatterDimensionNumbers( update_window_dims=update_window_dims, inserted_window_dims=inserted_window_dims, - scatter_dims_to_operand_dims=scatter_dims_to_operand_dims) + scatter_dims_to_operand_dims=scatter_dims_to_operand_dims, + operand_batching_dims=operand_batching_dims, + scatter_indices_batching_dims=dimension_numbers.scatter_indices_batching_dims, + ) return scatter_op.bind( operand, indices, updates, dimension_numbers=dnums, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode, update_jaxpr=update_jaxpr, update_consts=update_consts), 0 - # see the third case in _gather_batching_rule for comparison and comments indices = batching.bdim_at_front(indices, indices_bdim, size) - count_shape = list(indices.shape) - count_shape[-1] = 1 - counts = lax.broadcasted_iota(indices.dtype, tuple(count_shape), 0) - indices = lax.concatenate([counts, indices], len(count_shape) - 1) - update_window_dims = tuple(np.add(1, dimension_numbers.update_window_dims)) - inserted_window_dims = (0,) + tuple(np.add(1, dimension_numbers.inserted_window_dims)) - scatter_dims_to_operand_dims = (0,) + tuple(np.add(1, dimension_numbers.scatter_dims_to_operand_dims)) + inserted_window_dims = tuple( + np.add(1, dimension_numbers.inserted_window_dims) + ) + operand_batching_dims = (0,) + tuple( + np.add(1, dimension_numbers.operand_batching_dims) + ) + scatter_indices_batching_dims = (0,) + tuple( + np.add(1, dimension_numbers.scatter_indices_batching_dims) + ) + scatter_dims_to_operand_dims = tuple( + np.add(1, dimension_numbers.scatter_dims_to_operand_dims) + ) dnums = ScatterDimensionNumbers( update_window_dims=update_window_dims, inserted_window_dims=inserted_window_dims, - scatter_dims_to_operand_dims=scatter_dims_to_operand_dims) + scatter_dims_to_operand_dims=scatter_dims_to_operand_dims, + operand_batching_dims=operand_batching_dims, + scatter_indices_batching_dims=scatter_indices_batching_dims, + ) return scatter_op.bind( operand, indices, updates, dimension_numbers=dnums, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, @@ -2137,11 +2411,23 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *, scatter_add_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-add', weak_type_rule=_argnum_weak_type(0)) -ad.primitive_jvps[scatter_add_p] = _scatter_add_jvp -ad.primitive_transposes[scatter_add_p] = _scatter_add_transpose_rule +ad.primitive_jvps[scatter_add_p] = partial(_scatter_addsub_jvp, scatter_add_p) +ad.primitive_transposes[scatter_add_p] = partial(_scatter_addsub_transpose_rule, scatter_add_p) batching.primitive_batchers[scatter_add_p] = ( partial(_scatter_batching_rule, scatter_add_p)) +scatter_sub_p = standard_primitive( + _scatter_shape_rule, + _scatter_dtype_rule, + "scatter-sub", + weak_type_rule=_argnum_weak_type(0), +) +ad.primitive_jvps[scatter_sub_p] = partial(_scatter_addsub_jvp, scatter_sub_p) +ad.primitive_transposes[scatter_sub_p] = partial(_scatter_addsub_transpose_rule, scatter_sub_p) +batching.primitive_batchers[scatter_sub_p] = partial( + _scatter_batching_rule, scatter_sub_p +) + scatter_mul_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-mul', weak_type_rule=_argnum_weak_type(0)) @@ -2190,12 +2476,18 @@ def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr, gather_dnums = GatherDimensionNumbers( offset_dims=scatter_dnums.update_window_dims, collapsed_slice_dims=scatter_dnums.inserted_window_dims, - start_index_map=scatter_dnums.scatter_dims_to_operand_dims) + start_index_map=scatter_dnums.scatter_dims_to_operand_dims, + operand_batching_dims=scatter_dnums.operand_batching_dims, + start_indices_batching_dims=scatter_dnums.scatter_indices_batching_dims, + ) slice_sizes = [] pos = 0 for i in range(len(operand.shape)): - if i in scatter_dnums.inserted_window_dims: + if ( + i in scatter_dnums.inserted_window_dims + or i in scatter_dnums.operand_batching_dims + ): slice_sizes.append(1) else: slice_sizes.append(updates_shape[scatter_dnums.update_window_dims[pos]]) @@ -2323,7 +2615,6 @@ def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts, # of using scatter-add here is that we don't need a `scatter` transpose # rule. - # a) attach a positive ID to each update in `updates`, and perform a scatter # on the IDs. ids_shape = list(updates.shape) @@ -2344,13 +2635,16 @@ def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts, # b) compute the inverse gather that "undoes" the scatter on the id values. gather_dnums = GatherDimensionNumbers( - offset_dims=dnums.update_window_dims, - collapsed_slice_dims=dnums.inserted_window_dims, - start_index_map=dnums.scatter_dims_to_operand_dims) + offset_dims=dnums.update_window_dims, + collapsed_slice_dims=dnums.inserted_window_dims, + start_index_map=dnums.scatter_dims_to_operand_dims, + operand_batching_dims=dnums.operand_batching_dims, + start_indices_batching_dims=dnums.scatter_indices_batching_dims, + ) slice_sizes = [] pos = 0 for i in range(len(scattered_ids.shape)): - if i in dnums.inserted_window_dims: + if i in dnums.inserted_window_dims or i in dnums.operand_batching_dims: slice_sizes.append(1) else: slice_sizes.append(updates.shape[dnums.update_window_dims[pos]]) @@ -2405,13 +2699,19 @@ def _scatter_transpose_rule(t, operand, indices, updates, *, if ad.is_undefined_primal(updates): gather_dnums = GatherDimensionNumbers( - offset_dims=dimension_numbers.update_window_dims, - collapsed_slice_dims=dimension_numbers.inserted_window_dims, - start_index_map=dimension_numbers.scatter_dims_to_operand_dims) + offset_dims=dimension_numbers.update_window_dims, + collapsed_slice_dims=dimension_numbers.inserted_window_dims, + start_index_map=dimension_numbers.scatter_dims_to_operand_dims, + operand_batching_dims=dimension_numbers.operand_batching_dims, + start_indices_batching_dims=dimension_numbers.scatter_indices_batching_dims, + ) slice_sizes = [] pos = 0 for i in range(len(t.shape)): - if i in dimension_numbers.inserted_window_dims: + if ( + i in dimension_numbers.inserted_window_dims + or i in dimension_numbers.operand_batching_dims + ): slice_sizes.append(1) else: slice_sizes.append(updates_shape[dimension_numbers.update_window_dims[pos]]) @@ -2479,8 +2779,8 @@ def _scatter_lower(ctx, operand, indices, updates, *, scatter_dnums = hlo.ScatterDimensionNumbers.get( update_window_dims=list(dnums.update_window_dims), inserted_window_dims=list(dnums.inserted_window_dims), - input_batching_dims=[], - scatter_indices_batching_dims=[], + input_batching_dims=list(dnums.operand_batching_dims), + scatter_indices_batching_dims=list(dnums.scatter_indices_batching_dims), scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims), index_vector_dim=len(ctx.avals_in[1].shape) - 1, ) @@ -2510,6 +2810,7 @@ def _scatter_lower(ctx, operand, indices, updates, *, mlir.register_lowering(scatter_p, _scatter_lower) mlir.register_lowering(scatter_add_p, _scatter_lower) +mlir.register_lowering(scatter_sub_p, _scatter_lower) mlir.register_lowering(scatter_mul_p, _scatter_lower) mlir.register_lowering(scatter_min_p, _scatter_lower) mlir.register_lowering(scatter_max_p, _scatter_lower) @@ -2517,9 +2818,21 @@ def _scatter_lower(ctx, operand, indices, updates, *, def _real_dtype(dtype): return np.finfo(dtype).dtype -def _scatter_add_lower_gpu(ctx, operand, indices, updates, - *, update_jaxpr, update_consts, dimension_numbers, - indices_are_sorted, unique_indices, mode): + +def _scatter_addsub_lower_gpu( + ctx, + operand, + indices, + updates, + *, + update_jaxpr, + update_consts, + dimension_numbers, + indices_are_sorted, + unique_indices, + mode, + reduce_op, +): operand_aval_in, _, updates_aval_in = ctx.avals_in if operand_aval_in.dtype != np.complex128: return _scatter_lower(ctx, operand, indices, updates, @@ -2539,8 +2852,8 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates, scatter_dnums = hlo.ScatterDimensionNumbers.get( update_window_dims=list(dnums.update_window_dims), inserted_window_dims=list(dnums.inserted_window_dims), - input_batching_dims=[], - scatter_indices_batching_dims=[], + input_batching_dims=list(dnums.operand_batching_dims), + scatter_indices_batching_dims=list(dnums.scatter_indices_batching_dims), scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims), index_vector_dim=len(ctx.avals_in[1].shape) - 1, ) @@ -2563,15 +2876,24 @@ def _scatter(operand_part, updates_part): scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), real_dtype)) reducer = scatter.regions[0].blocks.append(scalar_type, scalar_type) with ir.InsertionPoint(reducer): - add = hlo.AddOp(*reducer.arguments).result - hlo.return_([add]) + hlo.return_([reduce_op(*reducer.arguments).result]) return scatter.result real = _scatter(hlo.real(operand), hlo.real(updates)) imag = _scatter(hlo.imag(operand), hlo.imag(updates)) return [hlo.complex(real, imag)] -mlir.register_lowering(scatter_add_p, _scatter_add_lower_gpu, platform="gpu") + +mlir.register_lowering( + scatter_add_p, + partial(_scatter_addsub_lower_gpu, reduce_op=hlo.AddOp), + platform="gpu", +) +mlir.register_lowering( + scatter_sub_p, + partial(_scatter_addsub_lower_gpu, reduce_op=hlo.SubtractOp), + platform="gpu", +) def _dynamic_slice_indices( diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index 5e3e9bcd8df2..a6eeb18b5203 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -23,14 +23,11 @@ from jax._src import config from jax._src import dtypes from jax._src.util import safe_zip -from jax._src.lib import xla_client zip, unsafe_zip = safe_zip, zip import numpy as np -xops = xla_client.ops - def _input_dtype(x, *_, **__): return dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True) @@ -96,13 +93,6 @@ def standard_multi_result_abstract_eval( else: raise TypeError(avals, least_specialized) -def standard_translate(prim): - xla_opname = ''.join(term.capitalize() for term in prim.name.split('_')) - op = getattr(xops, xla_opname) - def translation_rule(ctx, avals_in, avals_out, *args, **kwargs): - del ctx, avals_in, avals_out - return [op(*args, **kwargs)] - return translation_rule def _standard_weak_type_rule(*avals, **kwargs): return all(aval.weak_type for aval in avals) diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index e8fcb433438a..b72c6ee46fdd 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -18,6 +18,7 @@ from __future__ import annotations import gc +import os import pathlib import re from typing import Any @@ -128,13 +129,29 @@ def _xla_gc_callback(*args): # TODO(rocm): check if we need the same for rocm. def _cuda_path() -> str | None: - _jaxlib_path = pathlib.Path(jaxlib.__file__).parent - # If the pip package nvidia-cuda-nvcc-cu11 is installed, it should have - # both of the things XLA looks for in the cuda path, namely bin/ptxas and - # nvvm/libdevice/libdevice.10.bc - path = _jaxlib_path.parent / "nvidia" / "cuda_nvcc" - if path.is_dir(): - return str(path) + def _try_cuda_root_environment_variable() -> str | None: + """Use `CUDA_ROOT` environment variable if set.""" + return os.environ.get('CUDA_ROOT', None) + + def _try_cuda_nvcc_import() -> str | None: + """Try to import `cuda_nvcc` and get its path directly. + + If the pip package `nvidia-cuda-nvcc-cu11` is installed, it should have + both of the things XLA looks for in the cuda path, namely `bin/ptxas` and + `nvvm/libdevice/libdevice.10.bc`. + """ + try: + from nvidia import cuda_nvcc # pytype: disable=import-error + except ImportError: + return None + cuda_nvcc_path = pathlib.Path(cuda_nvcc.__file__).parent + return str(cuda_nvcc_path) + + if (path := _try_cuda_root_environment_variable()) is not None: + return path + elif (path := _try_cuda_nvcc_import()) is not None: + return path + return None cuda_path = _cuda_path() diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 08c8bfcb3a29..0b528f817bec 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -407,6 +407,11 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): raise RuntimeError("AbstractMesh is not a context manager") + @staticmethod + def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh): + jax_config.update_thread_local_jit_state(mesh_context_manager=mesh) + return + # Create this indirection because pytype fails to recognize a property if a # property raises an exception unconditionally. Remove this once that is fixed. diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index c81d51ea054b..f5b290af67b9 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -19,6 +19,7 @@ from collections.abc import Sequence from functools import partial import operator +import math import numpy as np from typing import Any, Literal import warnings @@ -34,6 +35,8 @@ from jax._src.core import AxisName from jax._src.cudnn.fused_attention_stablehlo import ( dot_product_attention as cudnn_dot_product_attention, MaskType) +from jax._src.interpreters import batching +from jax._src.interpreters import mlir from jax._src.numpy import util as numpy_util from jax._src.typing import Array, ArrayLike, DType from jax._src.ops.special import logsumexp as _logsumexp @@ -902,6 +905,68 @@ def _reshape_to_grouped(t): encoded = jnp.reshape(encoded, (B, T, N, H)) return encoded +def bias_fwd_rule(a, query_head_num): + return bias_fwd_p.bind(a, query_head_num), a +def bias_bwd_rule(query_head_num, res, g): + a = res + if a.shape[0] > 1 or a.shape[-3] != query_head_num: + raise ValueError("cuDNN only supports bias gradient when the batch size is " + f"1 and the head number matches the query, but got " + f"B={a.shape[0]}, N={a.shape[-3]}.") + return (bias_bwd_p.bind(g, a, query_head_num),) + +# This function uses two custom primitives, `bias_fwd` and `bias_bwd`, to work +# around a cuDNN issue where bias gradients are only supported when the batch +# size is 1 and the number of heads matches the query. +# TODO(kaixih@nvidia): Remove this workaround once cuDNN resolves the issue. +@partial(jax.custom_vjp, nondiff_argnums=(1,)) +def check_valid_bias_batch(x, query_head_num): + output, _ = bias_fwd_rule(x, query_head_num) + return output +check_valid_bias_batch.defvjp(bias_fwd_rule, bias_bwd_rule) + +bias_fwd_p = core.Primitive('bias_fwd') +bias_fwd_p.multiple_results = False +bias_bwd_p = core.Primitive('bias_bwd') +bias_bwd_p.multiple_results = False + +def bias_fwd_impl(a, query_head_num): + return a +def bias_bwd_impl(g, a, query_head_num): + return g +bias_fwd_p.def_impl(bias_fwd_impl) +bias_bwd_p.def_impl(bias_bwd_impl) + +def bias_fwd_abstract_eval(a, query_head_num): + return core.ShapedArray(a.shape, a.dtype) +def bias_bwd_abstract_eval(g, a, query_head_num): + return core.ShapedArray(g.shape, g.dtype) +bias_fwd_p.def_abstract_eval(bias_fwd_abstract_eval) +bias_bwd_p.def_abstract_eval(bias_bwd_abstract_eval) + +def bias_fwd_lowering(ctx, a, query_head_num): + return [a] +def bias_bwd_lowering(ctx, g, a, query_head_num): + return [g] +mlir.register_lowering(bias_fwd_p, bias_fwd_lowering) +mlir.register_lowering(bias_bwd_p, bias_bwd_lowering) + +def bias_fwd_batch_rule(batched_args, batch_dims): + x, query_head_num = batched_args + a = batch_dims[0] + output, _ = bias_fwd_rule(x, query_head_num) + return output, a +def bias_bwd_batch_rule(batched_args, batch_dims): + g, x, query_head_num = batched_args + b = batch_dims[0] + *Bs, _, _, _ = x.shape + B = math.prod(Bs) + x = jnp.reshape(x, (B,) + x.shape[-3:]) + output, = bias_bwd_rule(query_head_num, x, g) + return output, b +batching.primitive_batchers[bias_fwd_p] = bias_fwd_batch_rule +batching.primitive_batchers[bias_bwd_p] = bias_bwd_batch_rule + def dot_product_attention( query: ArrayLike, key: ArrayLike, @@ -1034,6 +1099,9 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int], local_window_size=local_window_size, ) case 'cudnn': + if bias is not None: + bias = check_valid_bias_batch(bias, query_arr.shape[-2]) + bias = jnp.asarray(bias) use_padding = ( query_seq_lengths is not None or key_value_seq_lengths is not None ) diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 95d681cad8e5..24b8d315d8ac 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -661,6 +661,7 @@ class _IndexUpdateHelper: ============================== ================================ ``x = x.at[idx].set(y)`` ``x[idx] = y`` ``x = x.at[idx].add(y)`` ``x[idx] += y`` + ``x = x.at[idx].subtract(y)`` ``x[idx] -= y`` ``x = x.at[idx].multiply(y)`` ``x[idx] *= y`` ``x = x.at[idx].divide(y)`` ``x[idx] /= y`` ``x = x.at[idx].power(y)`` ``x[idx] **= y`` @@ -826,6 +827,20 @@ def add(self, values, *, indices_are_sorted=False, unique_indices=False, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode) + def subtract(self, values, *, indices_are_sorted=False, unique_indices=False, + mode=None): + """Pure equivalent of ``x[idx] -= y``. + + Returns the value of ``x`` that would result from the NumPy-style + :mod:indexed assignment ` ``x[idx] -= y``. + + See :mod:`jax.ops` for details. + """ + return scatter._scatter_update(self.array, self.index, values, + lax.scatter_sub, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, mode=mode) + def multiply(self, values, *, indices_are_sorted=False, unique_indices=False, mode=None): """Pure equivalent of ``x[idx] *= y``. diff --git a/jax/_src/numpy/fft.py b/jax/_src/numpy/fft.py index 8b914680fea3..4a9ec23fd3bd 100644 --- a/jax/_src/numpy/fft.py +++ b/jax/_src/numpy/fft.py @@ -22,7 +22,7 @@ from jax import lax from jax._src.lib import xla_client from jax._src.util import safe_zip -from jax._src.numpy.util import check_arraylike, implements, promote_dtypes_inexact +from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact from jax._src.numpy import lax_numpy as jnp from jax._src.numpy import ufuncs, reductions from jax._src.sharding import Sharding @@ -45,7 +45,7 @@ def _fft_norm(s: Array, func_name: str, norm: str) -> Array: '"ortho" or "forward".') -def _fft_core(func_name: str, fft_type: xla_client.FftType, a: ArrayLike, +def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike, s: Shape | None, axes: Sequence[int] | None, norm: str | None) -> Array: full_name = f"jax.numpy.fft.{func_name}" @@ -87,14 +87,14 @@ def _fft_core(func_name: str, fft_type: xla_client.FftType, a: ArrayLike, in_s = list(arr.shape) for axis, x in safe_zip(axes, s): in_s[axis] = x - if fft_type == xla_client.FftType.IRFFT: + if fft_type == lax.FftType.IRFFT: in_s[-1] = (in_s[-1] // 2 + 1) # Cropping arr = arr[tuple(map(slice, in_s))] # Padding arr = jnp.pad(arr, [(0, x-y) for x, y in zip(in_s, arr.shape)]) else: - if fft_type == xla_client.FftType.IRFFT: + if fft_type == lax.FftType.IRFFT: s = [arr.shape[axis] for axis in axes[:-1]] if axes: s += [max(0, 2 * (arr.shape[axes[-1]] - 1))] @@ -181,7 +181,7 @@ def fftn(a: ArrayLike, s: Shape | None = None, >>> jnp.allclose(x, jnp.fft.ifftn(x_fftn)) Array(True, dtype=bool) """ - return _fft_core('fftn', xla_client.FftType.FFT, a, s, axes, norm) + return _fft_core('fftn', lax.FftType.FFT, a, s, axes, norm) def ifftn(a: ArrayLike, s: Shape | None = None, @@ -249,7 +249,7 @@ def ifftn(a: ArrayLike, s: Shape | None = None, [[ 2.5 +0.j 0. -0.58j 0. +0.58j] [ 0.17+0.j -0.83-0.29j -0.83+0.29j]] """ - return _fft_core('ifftn', xla_client.FftType.IFFT, a, s, axes, norm) + return _fft_core('ifftn', lax.FftType.IFFT, a, s, axes, norm) def rfftn(a: ArrayLike, s: Shape | None = None, @@ -340,7 +340,7 @@ def rfftn(a: ArrayLike, s: Shape | None = None, >>> jnp.fft.rfftn(x1) Array([10.+0.j, -2.+2.j, -2.+0.j], dtype=complex64) """ - return _fft_core('rfftn', xla_client.FftType.RFFT, a, s, axes, norm) + return _fft_core('rfftn', lax.FftType.RFFT, a, s, axes, norm) def irfftn(a: ArrayLike, s: Shape | None = None, @@ -417,7 +417,7 @@ def irfftn(a: ArrayLike, s: Shape | None = None, [[-2., -2., -2.], [-2., -2., -2.]]], dtype=float32) """ - return _fft_core('irfftn', xla_client.FftType.IRFFT, a, s, axes, norm) + return _fft_core('irfftn', lax.FftType.IRFFT, a, s, axes, norm) def _axis_check_1d(func_name: str, axis: int | None): @@ -428,7 +428,7 @@ def _axis_check_1d(func_name: str, axis: int | None): "Got axis = %r." % (full_name, full_name, axis) ) -def _fft_core_1d(func_name: str, fft_type: xla_client.FftType, +def _fft_core_1d(func_name: str, fft_type: lax.FftType, a: ArrayLike, n: int | None, axis: int | None, norm: str | None) -> Array: _axis_check_1d(func_name, axis) @@ -496,7 +496,7 @@ def fft(a: ArrayLike, n: int | None = None, >>> jnp.allclose(x, jnp.fft.ifft(x_fft)) Array(True, dtype=bool) """ - return _fft_core_1d('fft', xla_client.FftType.FFT, a, n=n, axis=axis, + return _fft_core_1d('fft', lax.FftType.FFT, a, n=n, axis=axis, norm=norm) @@ -552,7 +552,7 @@ def ifft(a: ArrayLike, n: int | None = None, [ 0.67+0.58j -0.5 +1.44j 0.17+2.02j 1.83+0.29j] [ 0.67-0.58j -0.5 -1.44j 0.17-2.02j 1.83-0.29j]] """ - return _fft_core_1d('ifft', xla_client.FftType.IFFT, a, n=n, axis=axis, + return _fft_core_1d('ifft', lax.FftType.IFFT, a, n=n, axis=axis, norm=norm) @@ -613,7 +613,7 @@ def rfft(a: ArrayLike, n: int | None = None, [ 1.-2.j, 3.-4.j, 5.-6.j], [-1.+0.j, -1.+0.j, -1.+0.j]], dtype=complex64) """ - return _fft_core_1d('rfft', xla_client.FftType.RFFT, a, n=n, axis=axis, + return _fft_core_1d('rfft', lax.FftType.RFFT, a, n=n, axis=axis, norm=norm) @@ -673,7 +673,7 @@ def irfft(a: ArrayLike, n: int | None = None, [-0.75, -1.25, -1.75], [ 0.25, 0.75, 1.25]], dtype=float32) """ - return _fft_core_1d('irfft', xla_client.FftType.IRFFT, a, n=n, axis=axis, + return _fft_core_1d('irfft', lax.FftType.IRFFT, a, n=n, axis=axis, norm=norm) @@ -763,7 +763,7 @@ def hfft(a: ArrayLike, n: int | None = None, conj_a = ufuncs.conj(a) _axis_check_1d('hfft', axis) nn = (conj_a.shape[axis] - 1) * 2 if n is None else n - return _fft_core_1d('hfft', xla_client.FftType.IRFFT, conj_a, n=n, axis=axis, + return _fft_core_1d('hfft', lax.FftType.IRFFT, conj_a, n=n, axis=axis, norm=norm) * nn @@ -813,12 +813,12 @@ def ihfft(a: ArrayLike, n: int | None = None, _axis_check_1d('ihfft', axis) arr = jnp.asarray(a) nn = arr.shape[axis] if n is None else n - output = _fft_core_1d('ihfft', xla_client.FftType.RFFT, arr, n=n, axis=axis, + output = _fft_core_1d('ihfft', lax.FftType.RFFT, arr, n=n, axis=axis, norm=norm) return ufuncs.conj(output) * (1 / nn) -def _fft_core_2d(func_name: str, fft_type: xla_client.FftType, a: ArrayLike, +def _fft_core_2d(func_name: str, fft_type: lax.FftType, a: ArrayLike, s: Shape | None, axes: Sequence[int], norm: str | None) -> Array: full_name = f"jax.numpy.fft.{func_name}" @@ -905,7 +905,7 @@ def fft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), >>> jnp.allclose(x, jnp.fft.ifft2(x_fft2)) Array(True, dtype=bool) """ - return _fft_core_2d('fft2', xla_client.FftType.FFT, a, s=s, axes=axes, + return _fft_core_2d('fft2', lax.FftType.FFT, a, s=s, axes=axes, norm=norm) @@ -977,7 +977,7 @@ def ifft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), [-0.33-0.58j, -0.33-0.58j], [-0.33+0.58j, -0.33+0.58j]]], dtype=complex64) """ - return _fft_core_2d('ifft2', xla_client.FftType.IFFT, a, s=s, axes=axes, + return _fft_core_2d('ifft2', lax.FftType.IFFT, a, s=s, axes=axes, norm=norm) @@ -1056,7 +1056,7 @@ def rfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), [ 3.47+10.11j, 6.43+11.42j, 9.38+12.74j], [ 3.19 +1.63j, 4.4 +1.38j, 5.61 +1.12j]]], dtype=complex64) """ - return _fft_core_2d('rfft2', xla_client.FftType.RFFT, a, s=s, axes=axes, + return _fft_core_2d('rfft2', lax.FftType.RFFT, a, s=s, axes=axes, norm=norm) @@ -1131,7 +1131,7 @@ def irfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), [ 0. , 0. , 0. ], [ 0. , 0. , 0. ]]], dtype=float32) """ - return _fft_core_2d('irfft2', xla_client.FftType.IRFFT, a, s=s, axes=axes, + return _fft_core_2d('irfft2', lax.FftType.IRFFT, a, s=s, axes=axes, norm=norm) @@ -1206,7 +1206,7 @@ def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None, Array of sample frequencies, length ``n // 2 + 1``. See also: - - :func:`jax.numpy.fft.rfftfreq`: frequencies for use with + - :func:`jax.numpy.fft.fftfreq`: frequencies for use with :func:`~jax.numpy.fft.fft` and :func:`~jax.numpy.fft.ifft`. """ dtype = dtype or dtypes.canonicalize_dtype(jnp.float_) @@ -1233,8 +1233,41 @@ def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None, return result -@implements(np.fft.fftshift) def fftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array: + """Shift zero-frequency fft component to the center of the spectrum. + + JAX implementation of :func:`numpy.fft.fftshift`. + + Args: + x: N-dimensional array array of frequencies. + axes: optional integer or sequence of integers specifying which axes to + shift. If None (default), then shift all axes. + + Returns: + A shifted copy of ``x``. + + See also: + - :func:`jax.numpy.fft.ifftshift`: inverse of ``fftshift``. + - :func:`jax.numpy.fft.fftfreq`: generate FFT frequencies. + + Examples: + Generate FFT frequencies with :func:`~jax.numpy.fft.fftfreq`: + + >>> freq = jnp.fft.fftfreq(5) + >>> freq + Array([ 0. , 0.2, 0.4, -0.4, -0.2], dtype=float32) + + Use ``fftshift`` to shift the zero-frequency entry to the middle of the array: + + >>> shifted_freq = jnp.fft.fftshift(freq) + >>> shifted_freq + Array([-0.4, -0.2, 0. , 0.2, 0.4], dtype=float32) + + Unshift with :func:`~jax.numpy.fft.ifftshift` to recover the original frequencies: + + >>> jnp.fft.ifftshift(shifted_freq) + Array([ 0. , 0.2, 0.4, -0.4, -0.2], dtype=float32) + """ check_arraylike("fftshift", x) x = jnp.asarray(x) shift: int | Sequence[int] @@ -1249,8 +1282,42 @@ def fftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array: return jnp.roll(x, shift, axes) -@implements(np.fft.ifftshift) def ifftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array: + """The inverse of :func:`jax.numpy.fft.fftshift`. + + JAX implementation of :func:`numpy.fft.ifftshift`. + + Args: + x: N-dimensional array array of frequencies. + axes: optional integer or sequence of integers specifying which axes to + shift. If None (default), then shift all axes. + + Returns: + A shifted copy of ``x``. + + See also: + - :func:`jax.numpy.fft.fftshift`: inverse of ``ifftshift``. + - :func:`jax.numpy.fft.fftfreq`: generate FFT frequencies. + + Examples: + Generate FFT frequencies with :func:`~jax.numpy.fft.fftfreq`: + + >>> freq = jnp.fft.fftfreq(5) + >>> freq + Array([ 0. , 0.2, 0.4, -0.4, -0.2], dtype=float32) + + Use :func:`~jax.numpy.fft.fftshift` to shift the zero-frequency entry + to the middle of the array: + + >>> shifted_freq = jnp.fft.fftshift(freq) + >>> shifted_freq + Array([-0.4, -0.2, 0. , 0.2, 0.4], dtype=float32) + + Unshift with ``ifftshift`` to recover the original frequencies: + + >>> jnp.fft.ifftshift(shifted_freq) + Array([ 0. , 0.2, 0.4, -0.4, -0.2], dtype=float32) + """ check_arraylike("ifftshift", x) x = jnp.asarray(x) shift: int | Sequence[int] diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 387b3b2a51a7..a0e218c88cc2 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -32,48 +32,47 @@ import importlib import math import operator +import string import types -from typing import (overload, Any, Literal, NamedTuple, - Protocol, TypeVar, Union) +from typing import ( Any, Literal, NamedTuple, + Protocol, TypeVar, Union,overload) import warnings -import numpy as np -import opt_einsum - import jax -from jax import jit from jax import errors +from jax import jit from jax import lax -from jax.sharding import Sharding, SingleDeviceSharding -from jax.tree_util import tree_leaves, tree_flatten, tree_map - from jax._src import api_util from jax._src import config from jax._src import core -from jax._src.custom_derivatives import custom_jvp from jax._src import deprecations from jax._src import dispatch from jax._src import dtypes from jax._src import xla_bridge from jax._src.api_util import _ensure_index_tuple from jax._src.array import ArrayImpl -from jax._src.core import ShapedArray, ConcreteArray -from jax._src.lax.lax import (_array_copy, _sort_lt_comparator, - _sort_le_comparator, PrecisionLike) +from jax._src.core import ConcreteArray, ShapedArray +from jax._src.custom_derivatives import custom_jvp from jax._src.lax import lax as lax_internal +from jax._src.lax.lax import ( PrecisionLike,_array_copy, + _sort_le_comparator, _sort_lt_comparator) from jax._src.lib import xla_client as xc from jax._src.numpy import reductions from jax._src.numpy import ufuncs from jax._src.numpy import util from jax._src.numpy.vectorize import vectorize from jax._src.typing import ( - Array, ArrayLike, DeprecatedArg, DimSize, DuckTypedArray, - DType, DTypeLike, Shape, StaticScalar, + Array, ArrayLike, + DType, DTypeLike, DeprecatedArg, DimSize, DuckTypedArray, Shape, StaticScalar, ) -from jax._src.util import (unzip2, subvals, safe_zip, - ceil_of_ratio, partition_list, +from jax._src.util import ( + NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, - NumpyComplexWarning) + ceil_of_ratio, partition_list, safe_zip, subvals,unzip2) +from jax.sharding import Sharding, SingleDeviceSharding +from jax.tree_util import tree_flatten, tree_leaves, tree_map +import numpy as np +import opt_einsum for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib']: try: @@ -256,7 +255,7 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: save = np.save savez = np.savez -@util.implements(np.dtype) + def _jnp_dtype(obj: DTypeLike | None, *, align: bool = False, copy: bool = False) -> DType: """Similar to np.dtype, but respects JAX dtype defaults.""" @@ -437,20 +436,191 @@ def fmax(x1: ArrayLike, x2: ArrayLike) -> Array: """ return where(ufuncs.greater(x1, x2) | ufuncs.isnan(x2), x1, x2) -@util.implements(np.issubdtype) + def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> bool: + """Return True if arg1 is equal or lower than arg2 in the type hierarchy. + + JAX implementation of :func:`numpy.issubdtype`. + + The main difference in JAX's implementation is that it properly handles + dtype extensions such as :code:`bfloat16`. + + Args: + arg1: dtype-like object. In typical usage, this will be a dtype specifier, + such as ``"float32"`` (i.e. a string), ``np.dtype('int32')`` (i.e. an + instance of :class:`numpy.dtype`), ``jnp.complex64`` (i.e. a JAX scalar + constructor), or ``np.uint8`` (i.e. a NumPy scalar type). + arg2: dtype-like object. In typical usage, this will be a generic scalar + type, such as ``jnp.integer``, ``jnp.floating``, or ``jnp.complexfloating``. + + Returns: + True if arg1 represents a dtype that is equal or lower in the type + hierarchy than arg2. + + See also: + - :func:`jax.numpy.isdtype`: similar function aligning with the array API standard. + + Examples: + >>> jnp.issubdtype('uint32', jnp.unsignedinteger) + True + >>> jnp.issubdtype(np.int32, jnp.integer) + True + >>> jnp.issubdtype(jnp.bfloat16, jnp.floating) + True + >>> jnp.issubdtype(np.dtype('complex64'), jnp.complexfloating) + True + >>> jnp.issubdtype('complex64', jnp.integer) + False + + Be aware that while this is very similar to :func:`numpy.issubdtype`, the + results of these differ in the case of JAX's custom floating point types: + + >>> np.issubdtype('bfloat16', np.floating) + False + >>> jnp.issubdtype('bfloat16', jnp.floating) + True + """ return dtypes.issubdtype(arg1, arg2) -@util.implements(np.isscalar) + def isscalar(element: Any) -> bool: - if hasattr(element, '__jax_array__'): - element = element.__jax_array__() - return dtypes.is_python_scalar(element) or np.isscalar(element) + """Return True if the input is a scalar. + + JAX implementation of :func:`numpy.isscalar`. JAX's implementation differs + from NumPy's in that it considers zero-dimensional arrays to be scalars; see + the *Note* below for more details. + + Args: + element: input object to check; any type is valid input. + + Returns: + True if ``element`` is a scalar value or an array-like object with zero + dimensions, False otherwise. + + Note: + JAX and NumPy differ in their representation of scalar values. NumPy has + special scalar objects (e.g. ``np.int32(0)``) which are distinct from + zero-dimensional arrays (e.g. ``np.array(0)``), and :func:`numpy.isscalar` + returns ``True`` for the former and ``False`` for the latter. + + JAX does not define special scalar objects, but rather represents scalars as + zero-dimensional arrays. As such, :func:`jax.numpy.isscalar` returns ``True`` + for both scalar objects (e.g. ``0.0`` or ``np.float32(0.0)``) and array-like + objects with zero dimensions (e.g. ``jnp.array(0.0)``, ``np.array(0.0)``). + + One reason for the different conventions in ``isscalar`` is to maintain + JIT-invariance: i.e. the property that the result of a function should not + change when it is JIT-compiled. Because scalar inputs are cast to + zero-dimensional JAX arrays at JIT boundaries, the semantics of + :func:`numpy.isscalar` are such that the result changes under JIT: + + >>> np.isscalar(1.0) + True + >>> jax.jit(np.isscalar)(1.0) + Array(False, dtype=bool) + + By treating zero-dimensional arrays as scalars, :func:`jax.numpy.isscalar` + avoids this issue: + + >>> jnp.isscalar(1.0) + True + >>> jax.jit(jnp.isscalar)(1.0) + Array(True, dtype=bool) + + Examples: + In JAX, both scalars and zero-dimensional array-like objects are considered + scalars: + + >>> jnp.isscalar(1.0) + True + >>> jnp.isscalar(1 + 1j) + True + >>> jnp.isscalar(jnp.array(1)) # zero-dimensional JAX array + True + >>> jnp.isscalar(jnp.int32(1)) # JAX scalar constructor + True + >>> jnp.isscalar(np.array(1.0)) # zero-dimensional NumPy array + True + >>> jnp.isscalar(np.int32(1)) # NumPy scalar type + True + + Arrays with one or more dimension are not considered scalars: + + >>> jnp.isscalar(jnp.array([1])) + False + >>> jnp.isscalar(np.array([1])) + False + + Compare this to :func:`numpy.isscalar`, which returns ``True`` for + scalar-typed objects, and ``False`` for *all* arrays, even those with + zero dimensions: + + >>> np.isscalar(np.int32(1)) # scalar object + True + >>> np.isscalar(np.array(1)) # zero-dimensional array + False + + In JAX, as in NumPy, objects which are not array-like are not considered + scalars: + + >>> jnp.isscalar(None) + False + >>> jnp.isscalar([1]) + False + >>> jnp.isscalar(tuple()) + False + >>> jnp.isscalar(slice(10)) + False + """ + if (isinstance(element, (np.ndarray, jax.Array)) + or hasattr(element, '__jax_array__') + or np.isscalar(element)): + return asarray(element).ndim == 0 + return False iterable = np.iterable -@util.implements(np.result_type) + def result_type(*args: Any) -> DType: + """Return the result of applying JAX promotion rules to the inputs. + + JAX implementation of :func:`numpy.result_type`. + + JAX's dtype promotion behavior is described in :ref:`type-promotion`. + + Args: + args: one or more arrays or dtype-like objects. + + Returns: + A :class:`numpy.dtype` instance representing the result of type + promotion for the inputs. + + Examples: + Inputs can be dtype specifiers: + + >>> jnp.result_type('int32', 'float32') + dtype('float32') + >>> jnp.result_type(np.uint16, np.dtype('int32')) + dtype('int32') + + Inputs may also be scalars or arrays: + + >>> jnp.result_type(1.0, jnp.bfloat16(2)) + dtype(bfloat16) + >>> jnp.result_type(jnp.arange(4), jnp.zeros(4)) + dtype('float32') + + Be aware that the result type will be canonicalized based on the state + of the ``jax_enable_x64`` configuration flag, meaning that 64-bit types + may be downcast to 32-bit: + + >>> jnp.result_type('float64') + dtype('float32') + + For details on 64-bit values, refer to `Sharp bits - double precision`_: + + .. _Sharp bits - double precision: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision + """ return dtypes.result_type(*args) @@ -1295,9 +1465,9 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1, input is returned as is. axis: int, optional, default=-1. Specifies the axis along which the difference is computed. The difference is computed along ``axis -1`` by default. - prepend: scalar or array, optional, defualt=None. Specifies the values to be + prepend: scalar or array, optional, default=None. Specifies the values to be prepended along ``axis`` before computing the difference. - append: scalar or array, optional, defualt=None. Specifies the values to be + append: scalar or array, optional, default=None. Specifies the values to be appended along ``axis`` before computing the difference. Returns: @@ -1451,7 +1621,6 @@ def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = None, result = concatenate((result, ravel(asarray(to_end, dtype=arr.dtype)))) return result -@util.implements(np.gradient, skip_params=['edge_order']) @partial(jit, static_argnames=("axis", "edge_order")) def gradient( f: ArrayLike, @@ -1459,6 +1628,64 @@ def gradient( axis: int | Sequence[int] | None = None, edge_order: int | None = None, ) -> Array | list[Array]: + """Compute the numerical gradient of a sampled function. + + JAX implementation of :func:`numpy.gradient`. + + The gradient in ``jnp.gradient`` is computed using second-order finite + differences across the array of sampled function values. This should not + be confused with :func:`jax.grad`, which computes a precise gradient of + a callable function via :ref:`automatic differentiation `. + + Args: + f: *N*-dimensional array of function values. + varargs: optional list of scalars or arrays specifying spacing of + function evaluations. Options are: + + - not specified: unit spacing in all dimensions. + - a single scalar: constant spacing in all dimensions. + - *N* values: specify different spacing in each dimension: + + - scalar values indicate constant spacing in that dimension. + - array values must match the length of the corresponding dimension, + and specify the coordinates at which ``f`` is evaluated. + + edge_order: not implemented in JAX + axis: integer or tuple of integers specifying the axis along which + to compute the gradient. If None (default) calculates the gradient + along all axes. + + Returns: + an array or tuple of arrays containing the numerical gradient along + each specified axis. + + See also: + - :func:`jax.grad`: automatic differentiation of a function with a single output. + + Examples: + Comparing numerical and automatic differentiation of a simple function: + + >>> def f(x): + ... return jnp.sin(x) * jnp.exp(-x / 4) + ... + >>> def gradf_exact(x): + ... # exact analytical gradient of f(x) + ... return -f(x) / 4 + jnp.cos(x) * jnp.exp(-x / 4) + ... + >>> x = jnp.linspace(0, 5, 10) + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print("numerical gradient:", jnp.gradient(f(x), x)) + ... print("automatic gradient:", jax.vmap(jax.grad(f))(x)) + ... print("exact gradient: ", gradf_exact(x)) + ... + numerical gradient: [ 0.83 0.61 0.18 -0.2 -0.43 -0.49 -0.39 -0.21 -0.02 0.08] + automatic gradient: [ 1. 0.62 0.17 -0.23 -0.46 -0.51 -0.41 -0.21 -0.01 0.15] + exact gradient: [ 1. 0.62 0.17 -0.23 -0.46 -0.51 -0.41 -0.21 -0.01 0.15] + + Notice that, as expected, the numerical gradient has some approximation error + compared to the automatic gradient computed via :func:`jax.grad`. + """ if edge_order is not None: raise NotImplementedError( @@ -3173,11 +3400,53 @@ def fix(x: ArrayLike, out: None = None) -> Array: return where(lax.ge(x, zero), ufuncs.floor(x), ufuncs.ceil(x)) -@util.implements(np.nan_to_num) @jit def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0, posinf: ArrayLike | None = None, neginf: ArrayLike | None = None) -> Array: + """Replace NaN and infinite entries in an array. + + JAX implementation of :func:`numpy.nan_to_num`. + + Args: + x: array of values to be replaced. If it does not have an inexact + dtype it will be returned unmodified. + copy: unused by JAX + nan: value to substitute for NaN entries. Defaults to 0.0. + posinf: value to substitute for positive infinite entries. + Defaults to the maximum representable value. + neginf: value to substitute for positive infinite entries. + Defaults to the minimum representable value. + + Returns: + A copy of ``x`` with the requested substitutions. + + See also: + - :func:`jax.numpy.isnan`: return True where the array contains NaN + - :func:`jax.numpy.isposinf`: return True where the array contains +inf + - :func:`jax.numpy.isneginf`: return True where the array contains -inf + + Examples: + >>> x = jnp.array([0, jnp.nan, 1, jnp.inf, 2, -jnp.inf]) + + Default substitution values: + + >>> jnp.nan_to_num(x) + Array([ 0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 3.4028235e+38, + 2.0000000e+00, -3.4028235e+38], dtype=float32) + + Overriding substitutions for ``-inf`` and ``+inf``: + + >>> jnp.nan_to_num(x, posinf=999, neginf=-999) + Array([ 0., 0., 1., 999., 2., -999.], dtype=float32) + + If you only wish to substitute for NaN values while leaving ``inf`` values + untouched, using :func:`~jax.numpy.where` with :func:`jax.numpy.isnan` is + a better option: + + >>> jnp.where(jnp.isnan(x), 0, x) + Array([ 0., 0., 1., inf, 2., -inf], dtype=float32) + """ del copy util.check_arraylike("nan_to_num", x) dtype = _dtype(x) @@ -3540,7 +3809,7 @@ def build_padding(array, padding, before): f"and larger or equal than the padding length (= {padding}). " f"Error while handling {'left' if before else 'right'} padding on axis {i}.") try: - # We check that we can determine all comparisions. + # We check that we can determine all comparisons. offset = 1 if (mode == "reflect" and axis_size > 1) else 0 has_poly_dim = not core.is_constant_shape((axis_size, padding)) # For shape polymorphism, ensure the loop below ends after 1 iteration @@ -4388,9 +4657,85 @@ def column_stack(tup: np.ndarray | Array | Sequence[ArrayLike]) -> Array: return concatenate(arrs, axis=1) -@util.implements(np.choose, skip_params=['out']) -def choose(a: ArrayLike, choices: Sequence[ArrayLike], +def choose(a: ArrayLike, choices: Array | np.ndarray | Sequence[ArrayLike], out: None = None, mode: str = 'raise') -> Array: + """Construct an array by stacking slices of choice arrays. + + JAX implementation of :func:`numpy.choose`. + + The semantics of this function can be confusing, but in the simplest case where + ``a`` is a one-dimensional array, ``choices`` is a two-dimensional array, and + all entries of ``a`` are in-bounds (i.e. ``0 <= a_i < len(choices)``), then the + function is equivalent to the following:: + + def choose(a, choices): + return jnp.array([choices[a_i, i] for i, a_i in enumerate(a)]) + + In the more general case, ``a`` may have any number of dimensions and ``choices`` + may be an arbitrary sequence of broadcast-compatible arrays. In this case, again + for in-bound indices, the logic is equivalent to:: + + def choose(a, choices): + a, *choices = jnp.broadcast_arrays(a, *choices) + choices = jnp.array(choices) + return jnp.array([choices[a[idx], *idx] for idx in np.ndindex(a.shape)]) + + The only additional complexity comes from the ``mode`` argument, which controls + the behavior for out-of-bound indices in ``a`` as described below. + + Args: + a: an N-dimensional array of integer indices. + choices: an array or sequence of arrays. All arrays in the sequence must be + mutually broadcast compatible with ``a``. + out: unused by JAX + mode: specify the out-of-bounds indexing mode; one of ``'raise'`` (default), + ``'wrap'``, or ``'clip'``. Note that the default mode of ``'raise'`` is + not compatible with JAX transformations. + + Returns: + an array containing stacked slices from ``choices`` at the indices + specified by ``a``. The shape of the result is + ``broadcast_shapes(a.shape, *(c.shape for c in choices))``. + + See also: + - :func:`jax.lax.switch`: choose between N functions based on an index. + + Examples: + Here is the simplest case of a 1D index array with a 2D choice array, + in which case this chooses the indexed value from each column: + + >>> choices = jnp.array([[ 1, 2, 3, 4], + ... [ 5, 6, 7, 8], + ... [ 9, 10, 11, 12]]) + >>> a = jnp.array([2, 0, 1, 0]) + >>> jnp.choose(a, choices) + Array([9, 2, 7, 4], dtype=int32) + + The ``mode`` argument specifies what to do with out-of-bound indices; + options are to either ``wrap`` or ``clip``: + + >>> a2 = jnp.array([2, 0, 1, 4]) # last index out-of-bound + >>> jnp.choose(a2, choices, mode='clip') + Array([ 9, 2, 7, 12], dtype=int32) + >>> jnp.choose(a2, choices, mode='wrap') + Array([9, 2, 7, 8], dtype=int32) + + In the more general case, ``choices`` may be a sequence of array-like + objects with any broadcast-compatible shapes. + + >>> choice_1 = jnp.array([1, 2, 3, 4]) + >>> choice_2 = 99 + >>> choice_3 = jnp.array([[10], + ... [20], + ... [30]]) + >>> a = jnp.array([[0, 1, 2, 0], + ... [1, 2, 0, 1], + ... [2, 0, 1, 2]]) + >>> jnp.choose(a, [choice_1, choice_2, choice_3], mode='wrap') + Array([[ 1, 99, 10, 4], + [99, 20, 3, 99], + [30, 2, 99, 30]], dtype=int32) + """ if out is not None: raise NotImplementedError("The 'out' argument to jnp.choose is not supported.") util.check_arraylike('choose', a, *choices) @@ -5470,7 +5815,7 @@ def frombuffer(buffer: bytes | Any, dtype: DTypeLike = float, a length that is an integer multiple of the dtype element size, or it must be an object exporting the `Python buffer interface`_. dtype: optional. Desired data type for the array. Default is ``float64``. - This specifes the dtype used to parse the buffer, but note that after parsing, + This specifies the dtype used to parse the buffer, but note that after parsing, 64-bit values will be cast to 32-bit JAX arrays if the ``jax_enable_x64`` flag is set to ``False``. count: optional integer specifying the number of items to read from the buffer. @@ -6299,7 +6644,7 @@ def i0(x: ArrayLike) -> Array: are not supported. Returns: - An array containing the corresponding vlaues of the modified Bessel function + An array containing the corresponding values of the modified Bessel function of ``x``. See also: @@ -6559,10 +6904,48 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, return take(a, gather_indices, axis=axis) -@util.implements(getattr(np, "trapezoid", getattr(np, "trapz", None))) @partial(jit, static_argnames=('axis',)) def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0, axis: int = -1) -> Array: + r""" + Integrate along the given axis using the composite trapezoidal rule. + + JAX implementation of :func:`numpy.trapezoid` + + The trapezoidal rule approximates the integral under a curve by summing the + areas of trapezoids formed between adjacent data points. + + Args: + y: array of data to integrate. + x: optional array of sample points corresponding to the ``y`` values. If not + provided, ``x`` defaults to equally spaced with spacing given by ``dx``. + dx: The spacing between sample points when `x` is None (default: 1.0). + axis: The axis along which to integrate (default: -1) + + Returns: + The definite integral approximated by the trapezoidal rule. + + Examples: + Integrate over a regular grid, with spacing 1.0: + + >>> y = jnp.array([1, 2, 3, 2, 3, 2, 1]) + >>> jnp.trapezoid(y, dx=1.0) + Array(13., dtype=float32) + + Integrate over an irregular grid: + + >>> x = jnp.array([0, 2, 5, 7, 10, 15, 20]) + >>> jnp.trapezoid(y, x) + Array(43., dtype=float32) + + Approximate :math:`\int_0^{2\pi} \sin^2(x)dx`, which equals :math:`\pi`: + + >>> x = jnp.linspace(0, 2 * jnp.pi, 1000) + >>> y = jnp.sin(x) ** 2 + >>> result = jnp.trapezoid(y, x) + >>> jnp.allclose(result, jnp.pi) + Array(True, dtype=bool) + """ # TODO(phawkins): remove this annotation after fixing jnp types. dx_array: Array if x is None: @@ -6827,19 +7210,53 @@ def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int return reductions.sum(a, axis=(-2, -1), dtype=dtype) -def _wrap_indices_function(f): - @util.implements(f, update_doc=False) - def wrapper(*args, **kwargs): - args = [core.concrete_or_error( - None, arg, f"argument {i} of jnp.{f.__name__}()") - for i, arg in enumerate(args)] - kwargs = {key: core.concrete_or_error( - None, val, f"argument '{key}' of jnp.{f.__name__}()") - for key, val in kwargs.items()} - return tuple(asarray(x) for x in f(*args, **kwargs)) - return wrapper +def mask_indices(n: int, + mask_func: Callable[[ArrayLike, int], Array], + k: int = 0, *, size: int | None = None) -> tuple[Array, Array]: + """Return indices of a mask of an (n, n) array. -mask_indices = _wrap_indices_function(np.mask_indices) + Args: + n: static integer array dimension. + mask_func: a function that takes a shape ``(n, n)`` array and + an optional offset ``k``, and returns a shape ``(n, n)`` mask. + Examples of functions with this signature are + :func:`~jax.numpy.triu` and :func:`~jax.numpy.tril`. + k: a scalar value passed to ``mask_func``. + size: optional argument specifying the static size of the output arrays. + This is passed to :func:`~jax.numpy.nonzero` when generating the indices + from the mask. + + Returns: + a tuple of indices where ``mask_func`` is nonzero. + + See also: + - :func:`jax.numpy.triu_indices`: compute ``mask_indices`` for :func:`~jax.numpy.triu`. + - :func:`jax.numpy.tril_indices`: compute ``mask_indices`` for :func:`~jax.numpy.tril`. + + Examples: + Calling ``mask_indices`` on built-in masking functions: + + >>> jnp.mask_indices(3, jnp.triu) + (Array([0, 0, 0, 1, 1, 2], dtype=int32), Array([0, 1, 2, 1, 2, 2], dtype=int32)) + + >>> jnp.mask_indices(3, jnp.tril) + (Array([0, 1, 1, 2, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1, 2], dtype=int32)) + + Calling ``mask_indices`` on a custom masking function: + + >>> def mask_func(x, k=0): + ... i = jnp.arange(x.shape[0])[:, None] + ... j = jnp.arange(x.shape[1]) + ... return (i + 1) % (j + 1 + k) == 0 + >>> mask_func(jnp.ones((3, 3))) + Array([[ True, False, False], + [ True, True, False], + [ True, False, True]], dtype=bool) + >>> jnp.mask_indices(3, mask_func) + (Array([0, 1, 1, 2, 2], dtype=int32), Array([0, 0, 1, 0, 2], dtype=int32)) + """ + i, j = nonzero(mask_func(ones((n, n)), k), size=size) + return (i, j) def _triu_size(n, m, k): @@ -7423,7 +7840,7 @@ def trim_zeros(filt: ArrayLike, trim: str ='fb') -> Array: - ``fb`` - trims both leading and trailing zeros. Returns: - An array containig the trimmed input with same dtype as ``filt``. + An array containing the trimmed input with same dtype as ``filt``. Examples: >>> x = jnp.array([0, 0, 2, 0, 1, 4, 3, 0, 0, 0]) @@ -9684,11 +10101,64 @@ def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array: return moveaxis(a, axis, start) -@util.implements(np.packbits) @partial(jit, static_argnames=('axis', 'bitorder')) -def packbits( - a: ArrayLike, axis: int | None = None, bitorder: str = "big" -) -> Array: +def packbits(a: ArrayLike, axis: int | None = None, bitorder: str = "big") -> Array: + """Pack array of bits into a uint8 array. + + JAX implementation of :func:`numpy.packbits` + + Args: + a: N-dimensional array of bits to pack. + axis: optional axis along which to pack bits. If not specified, ``a`` will + be flattened. + bitorder: ``"big"`` (default) or ``"little"``: specify whether the bit order + is big-endian or little-endian. + + Returns: + A uint8 array of packed values. + + See also: + - :func:`jax.numpy.unpackbits`: inverse of ``packbits``. + + Examples: + Packing bits in one dimension: + + >>> bits = jnp.array([0, 0, 0, 0, 0, 1, 1, 1]) + >>> jnp.packbits(bits) + Array([7], dtype=uint8) + >>> 0b00000111 # equivalent bit-wise representation: + 7 + + Optionally specifying little-endian convention: + + >>> jnp.packbits(bits, bitorder="little") + Array([224], dtype=uint8) + >>> 0b11100000 # equivalent bit-wise representation + 224 + + If the number of bits is not a multiple of 8, it will be right-padded + with zeros: + + >>> jnp.packbits(jnp.array([1, 0, 1])) + Array([160], dtype=uint8) + >>> jnp.packbits(jnp.array([1, 0, 1, 0, 0, 0, 0, 0])) + Array([160], dtype=uint8) + + For a multi-dimensional input, bits may be packed along a specified axis: + + >>> a = jnp.array([[1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0], + ... [0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1]]) + >>> vals = jnp.packbits(a, axis=1) + >>> vals + Array([[212, 150], + [ 69, 207]], dtype=uint8) + + The inverse of ``packbits`` is provided by :func:`~jax.numpy.unpackbits`: + + >>> jnp.unpackbits(vals, axis=1) + Array([[1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0], + [0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1]], dtype=uint8) + """ util.check_arraylike("packbits", a) arr = asarray(a) if not (issubdtype(arr.dtype, integer) or issubdtype(arr.dtype, bool_)): @@ -9715,7 +10185,6 @@ def packbits( return swapaxes(packed, axis, -1) -@util.implements(np.unpackbits) @partial(jit, static_argnames=('axis', 'count', 'bitorder')) def unpackbits( a: ArrayLike, @@ -9723,6 +10192,67 @@ def unpackbits( count: int | None = None, bitorder: str = "big", ) -> Array: + """Unpack the bits in a uint8 array. + + JAX implementation of :func:`numpy.unpackbits`. + + Args: + a: N-dimensional array of type ``uint8``. + axis: optional axis along which to unpack. If not specified, ``a`` will + be flattened + count: specify the number of bits to unpack (if positive) or the number + of bits to trim from the end (if negative). + bitorder: ``"big"`` (default) or ``"little"``: specify whether the bit order + is big-endian or little-endian. + + Returns: + a uint8 array of unpacked bits. + + See also: + - :func:`jax.numpy.packbits`: this inverse of ``unpackbits``. + + Examples: + Unpacking bits from a scalar: + + >>> jnp.unpackbits(jnp.uint8(27)) # big-endian by default + Array([0, 0, 0, 1, 1, 0, 1, 1], dtype=uint8) + >>> jnp.unpackbits(jnp.uint8(27), bitorder="little") + Array([1, 1, 0, 1, 1, 0, 0, 0], dtype=uint8) + + Compare this to the Python binary representation: + + >>> 0b00011011 + 27 + + Unpacking bits along an axis: + + >>> vals = jnp.array([[154], + ... [ 49]], dtype='uint8') + >>> bits = jnp.unpackbits(vals, axis=1) + >>> bits + Array([[1, 0, 0, 1, 1, 0, 1, 0], + [0, 0, 1, 1, 0, 0, 0, 1]], dtype=uint8) + + Using :func:`~jax.numpy.packbits` to invert this: + + >>> jnp.packbits(bits, axis=1) + Array([[154], + [ 49]], dtype=uint8) + + The ``count`` keyword lets ``unpackbits`` serve as an inverse of ``packbits`` + in cases where not all bits are present: + + >>> bits = jnp.array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1]) # 11 bits + >>> vals = jnp.packbits(bits) + >>> vals + Array([219, 96], dtype=uint8) + >>> jnp.unpackbits(vals) # 16 zero-padded bits + Array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0], dtype=uint8) + >>> jnp.unpackbits(vals, count=11) # specify 11 output bits + Array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1], dtype=uint8) + >>> jnp.unpackbits(vals, count=-5) # specify 5 bits to be trimmed + Array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1], dtype=uint8) + """ util.check_arraylike("unpackbits", a) arr = asarray(a) if _dtype(a) != uint8: @@ -10015,6 +10545,26 @@ def replace(tup, val): out_shape = lax.broadcast_shapes(idx_shape, arr_shape) if axis_size == 0: return zeros(out_shape, a.dtype) + + if mode == "one_hot": + indices = _normalize_index(indices, axis_size) + hot = jax.nn.one_hot(indices, axis_size, dtype=bool_) + if a.ndim == 1: + return einsum("...b,b->...", hot, a, preferred_element_type=a.dtype) + if axis_int > len(string.ascii_letters) - 2: + raise ValueError( + "One Hot indexing is only supported for up to 50 leading dimensions." + ) + labels = "".join([string.ascii_letters[i] for i in range(axis_int)]) + eq = labels + "y...z," + labels + "z...->" + labels + "y..." + return einsum( + eq, + hot, + a, + precision=lax.Precision.HIGHEST, + preferred_element_type=a.dtype, + ) + index_dims = [i for i, idx in enumerate(idx_shape) if i == axis_int or not core.definitely_equal(idx, 1)] gather_index_shape = tuple(np.array(out_shape)[index_dims]) + (1,) @@ -10023,6 +10573,8 @@ def replace(tup, val): offset_dims = [] start_index_map = [] collapsed_slice_dims = [] + operand_batching_dims = [] + start_indices_batching_dims = [] j = 0 for i in range(rank): if i == axis_int: @@ -10047,21 +10599,23 @@ def replace(tup, val): collapsed_slice_dims.append(i) j += 1 else: - # Otherwise, idx_shape[i] == arr_shape[i]. Use an iota index so - # corresponding elements of array and index are gathered. - # TODO(mattjj): next line needs updating for dynamic shapes - iota = lax.broadcasted_iota(index_dtype, gather_index_shape, j) - gather_indices.append(iota) - slice_sizes.append(1) - start_index_map.append(i) - collapsed_slice_dims.append(i) + # Otherwise, idx_shape[i] == arr_shape[i]. Mark the dimensions in both + # array and index as batching so corresponding elements are gathered. + if core.definitely_equal(arr_shape[i], 0): + slice_sizes.append(0) + else: + slice_sizes.append(1) + operand_batching_dims.append(i) + start_indices_batching_dims.append(j) j += 1 gather_indices_arr = lax.concatenate(gather_indices, dimension=j) dnums = lax.GatherDimensionNumbers( offset_dims=tuple(offset_dims), collapsed_slice_dims=tuple(collapsed_slice_dims), - start_index_map=tuple(start_index_map)) + start_index_map=tuple(start_index_map), + operand_batching_dims=tuple(operand_batching_dims), + start_indices_batching_dims=tuple(start_indices_batching_dims)) return lax.gather(a, gather_indices_arr, dnums, tuple(slice_sizes), mode="fill" if mode is None else mode, fill_value=fill_value) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 3436b00cfce1..1c2a4689cb85 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -32,7 +32,7 @@ from jax._src import dtypes from jax._src.numpy.util import ( _broadcast_to, check_arraylike, _complex_elem_type, - promote_dtypes_inexact, promote_dtypes_numeric, _where, implements) + promote_dtypes_inexact, promote_dtypes_numeric, _where) from jax._src.lax import lax as lax_internal from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg from jax._src.util import ( @@ -700,9 +700,8 @@ def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, An array of the mean along the given axis. See also: - - :func:`jax.numpy.sum`: Compute the sum of array elements over a given axis. - - :func:`jax.numpy.max`: Compute the maximum of array elements over given axis. - - :func:`jax.numpy.min`: Compute the minimum of array elements over given axis. + - :func:`jax.numpy.average`: Compute the weighted average of array elements + - :func:`jax.numpy.sum`: Compute the sum of array elements. Examples: By default, the mean is computed along all the axes. @@ -782,9 +781,59 @@ def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, * @overload def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: ... -@implements(np.average) def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: + """Compute the weighed average. + + JAX Implementation of :func:`numpy.average`. + + Args: + a: array to be averaged + axis: an optional integer or sequence of integers specifying the axis along which + the mean to be computed. If not specified, mean is computed along all the axes. + weights: an optional array of weights for a weighted average. Must be + broadcast-compatible with ``a``. + returned: If False (default) then return only the average. If True then return both + the average and the normalization factor (i.e. the sum of weights). + keepdims: If True, reduced axes are left in the result with size 1. If False (default) + then reduced axes are squeezed out. + + Returns: + An array ``average`` or tuple of arrays ``(average, normalization)`` if + ``returned`` is True. + + See also: + - :func:`jax.numpy.mean`: unweighted mean. + + Examples: + Simple average: + + >>> x = jnp.array([1, 2, 3, 2, 4]) + >>> jnp.average(x) + Array(2.4, dtype=float32) + + Weighted average: + + >>> weights = jnp.array([2, 1, 3, 2, 2]) + >>> jnp.average(x, weights=weights) + Array(2.5, dtype=float32) + + Use ``returned=True`` to optionally return the normalization, i.e. the + sum of weights: + + >>> jnp.average(x, returned=True) + (Array(2.4, dtype=float32), Array(5., dtype=float32)) + >>> jnp.average(x, weights=weights, returned=True) + (Array(2.5, dtype=float32), Array(10., dtype=float32)) + + Weighted average along a specified axis: + + >>> x = jnp.array([[8, 2, 7], + ... [3, 6, 4]]) + >>> weights = jnp.array([1, 2, 3]) + >>> jnp.average(x, weights=weights, axis=1) + Array([5.5, 4.5], dtype=float32) + """ return _average(a, _ensure_optional_axes(axis), weights, returned, keepdims) @partial(api.jit, static_argnames=('axis', 'returned', 'keepdims'), inline=True) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index dc265b8e87e1..432b123064b7 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -604,7 +604,7 @@ def arcsin(x: ArrayLike, /) -> Array: Note: - ``jnp.arcsin`` returns ``nan`` when ``x`` is real-valued and not in the closed interval ``[-1, 1]``. - - ``jnp.arcsin`` follows the branch cut convention of :func:`numpy.arcsin` for + - ``jnp.arcsin`` follows the branch cut convention of :obj:`numpy.arcsin` for complex inputs. See also: @@ -645,7 +645,7 @@ def arccos(x: ArrayLike, /) -> Array: Note: - ``jnp.arccos`` returns ``nan`` when ``x`` is real-valued and not in the closed interval ``[-1, 1]``. - - ``jnp.arccos`` follows the branch cut convention of :func:`numpy.arccos` for + - ``jnp.arccos`` follows the branch cut convention of :obj:`numpy.arccos` for complex inputs. See also: @@ -685,7 +685,7 @@ def arctan(x: ArrayLike, /) -> Array: in radians in the range ``[-pi/2, pi/2]``, promoting to inexact dtype. Note: - ``jnp.arctan`` follows the branch cut convention of :func:`numpy.arctan` for + ``jnp.arctan`` follows the branch cut convention of :obj:`numpy.arctan` for complex inputs. See also: @@ -817,14 +817,103 @@ def cosh(x: ArrayLike, /) -> Array: """ return lax.cosh(*promote_args_inexact('cosh', x)) -@implements(np.arcsinh, module='numpy') + @partial(jit, inline=True) def arcsinh(x: ArrayLike, /) -> Array: + r"""Calculate element-wise inverse of hyperbolic sine of input. + + JAX implementation of :obj:`numpy.arcsinh`. + + The inverse of hyperbolic sine is defined by: + + .. math:: + + arcsinh(x) = \ln(x + \sqrt{1 + x^2}) + + Args: + x: input array or scalar. + + Returns: + An array of same shape as ``x`` containing the inverse of hyperbolic sine of + each element of ``x``, promoting to inexact dtype. + + Note: + - ``jnp.arcsinh`` returns ``nan`` for values outside the range ``(-inf, inf)``. + - ``jnp.arcsinh`` follows the branch cut convention of :obj:`numpy.arcsinh` + for complex inputs. + + See also: + - :func:`jax.numpy.sinh`: Computes the element-wise hyperbolic sine of the input. + - :func:`jax.numpy.arccosh`: Computes the element-wise inverse of hyperbolic + cosine of the input. + - :func:`jax.numpy.arctanh`: Computes the element-wise inverse of hyperbolic + tangent of the input. + + Examples: + >>> x = jnp.array([[-2, 3, 1], + ... [4, 9, -5]]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.arcsinh(x) + Array([[-1.444, 1.818, 0.881], + [ 2.095, 2.893, -2.312]], dtype=float32) + + For complex-valued inputs: + + >>> x1 = jnp.array([4-3j, 2j]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.arcsinh(x1) + Array([2.306-0.634j, 1.317+1.571j], dtype=complex64) + """ return lax.asinh(*promote_args_inexact('arcsinh', x)) -@implements(np.arccosh, module='numpy') + @jit def arccosh(x: ArrayLike, /) -> Array: + r"""Calculate element-wise inverse of hyperbolic cosine of input. + + JAX implementation of :obj:`numpy.arccosh`. + + The inverse of hyperbolic cosine is defined by: + + .. math:: + + arccosh(x) = \ln(x + \sqrt{x^2 - 1}) + + Args: + x: input array or scalar. + + Returns: + An array of same shape as ``x`` containing the inverse of hyperbolic cosine + of each element of ``x``, promoting to inexact dtype. + + Note: + - ``jnp.arccosh`` returns ``nan`` for real-values in the range ``[-inf, 1)``. + - ``jnp.arccosh`` follows the branch cut convention of :obj:`numpy.arccosh` + for complex inputs. + + See also: + - :func:`jax.numpy.cosh`: Computes the element-wise hyperbolic cosine of the + input. + - :func:`jax.numpy.arcsinh`: Computes the element-wise inverse of hyperbolic + sine of the input. + - :func:`jax.numpy.arctanh`: Computes the element-wise inverse of hyperbolic + tangent of the input. + + Examples: + >>> x = jnp.array([[1, 3, -4], + ... [-5, 2, 7]]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.arccosh(x) + Array([[0. , 1.763, nan], + [ nan, 1.317, 2.634]], dtype=float32) + + For complex-valued input: + + >>> x1 = jnp.array([-jnp.inf+0j, 1+2j, -5+0j]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.arccosh(x1) + Array([ inf+3.142j, 1.529+1.144j, 2.292+3.142j], dtype=complex64) + """ # Note: arccosh is multi-valued for complex input, and lax.acosh # uses a different convention than np.arccosh. result = lax.acosh(*promote_args_inexact("arccosh", x)) @@ -885,9 +974,52 @@ def tanh(x: ArrayLike, /) -> Array: """ return lax.tanh(*promote_args_inexact('tanh', x)) -@implements(np.arctanh, module='numpy') + @partial(jit, inline=True) def arctanh(x: ArrayLike, /) -> Array: + r"""Calculate element-wise inverse of hyperbolic tangent of input. + + JAX implementation of :obj:`numpy.arctanh`. + + The inverse of hyperbolic tangent is defined by: + + .. math:: + + arctanh(x) = \frac{1}{2} [\ln(1 + x) - \ln(1 - x)] + + Args: + x: input array or scalar. + + Returns: + An array of same shape as ``x`` containing the inverse of hyperbolic tangent + of each element of ``x``, promoting to inexact dtype. + + Note: + - ``jnp.arctanh`` returns ``nan`` for real-values outside the range ``[-1, 1]``. + - ``jnp.arctanh`` follows the branch cut convention of :obj:`numpy.arctanh` + for complex inputs. + + See also: + - :func:`jax.numpy.tanh`: Computes the element-wise hyperbolic tangent of the + input. + - :func:`jax.numpy.arcsinh`: Computes the element-wise inverse of hyperbolic + sine of the input. + - :func:`jax.numpy.arccosh`: Computes the element-wise inverse of hyperbolic + cosine of the input. + + Examples: + >>> x = jnp.array([-2, -1, -0.5, 0, 0.5, 1, 2]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.arctanh(x) + Array([ nan, -inf, -0.549, 0. , 0.549, inf, nan], dtype=float32) + + For complex-valued input: + + >>> x1 = jnp.array([-2+0j, 3+0j, 4-1j]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.arctanh(x1) + Array([-0.549+1.571j, 0.347+1.571j, 0.239-1.509j], dtype=complex64) + """ return lax.atanh(*promote_args_inexact('arctanh', x)) @@ -922,9 +1054,32 @@ def sqrt(x: ArrayLike, /) -> Array: """ return lax.sqrt(*promote_args_inexact('sqrt', x)) -@implements(np.cbrt, module='numpy') + @partial(jit, inline=True) def cbrt(x: ArrayLike, /) -> Array: + """Calculates element-wise cube root of the input array. + + JAX implementation of :obj:`numpy.cbrt`. + + Args: + x: input array or scalar. ``complex`` dtypes are not supported. + + Returns: + An array containing the cube root of the elements of ``x``. + + See also: + - :func:`jax.numpy.sqrt`: Calculates the element-wise non-negative square root + of the input. + - :func:`jax.numpy.square`: Calculates the element-wise square of the input. + + Examples: + >>> x = jnp.array([[216, 125, 64], + ... [-27, -8, -1]]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.cbrt(x) + Array([[ 6., 5., 4.], + [-3., -2., -1.]], dtype=float32) + """ return lax.cbrt(*promote_args_inexact('cbrt', x)) @partial(jit, inline=True) @@ -1074,14 +1229,65 @@ def _bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: """ return lax.bitwise_xor(*promote_args("bitwise_xor", x, y)) -@implements(np.left_shift, module='numpy') + @partial(jit, inline=True) def left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: + r"""Shift bits of ``x`` to left by the amount specified in ``y``, element-wise. + + JAX implementation of :obj:`numpy.left_shift`. + + Args: + x: Input array, must be integer-typed. + y: The amount of bits to shift each element in ``x`` to the left, only accepts + integer subtypes. ``x`` and ``y`` must either have same shape or be broadcast + compatible. + + Returns: + An array containing the left shifted elements of ``x`` by the amount specified + in ``y``, with the same shape as the broadcasted shape of ``x`` and ``y``. + + Note: + Left shifting ``x`` by ``y`` is equivalent to ``x * (2**y)`` within the + bounds of the dtypes involved. + + See also: + - :func:`jax.numpy.right_shift`: and :func:`jax.numpy.bitwise_right_shift`: + Shifts the bits of ``x1`` to right by the amount specified in ``x2``, + element-wise. + - :func:`jax.numpy.bitwise_left_shift`: Alias of :func:`jax.left_shift`. + + Examples: + >>> def print_binary(x): + ... return [bin(int(val)) for val in x] + + >>> x1 = jnp.arange(5) + >>> x1 + Array([0, 1, 2, 3, 4], dtype=int32) + >>> print_binary(x1) + ['0b0', '0b1', '0b10', '0b11', '0b100'] + >>> x2 = 1 + >>> result = jnp.left_shift(x1, x2) + >>> result + Array([0, 2, 4, 6, 8], dtype=int32) + >>> print_binary(result) + ['0b0', '0b10', '0b100', '0b110', '0b1000'] + + >>> x3 = 4 + >>> print_binary([x3]) + ['0b100'] + >>> x4 = jnp.array([1, 2, 3, 4]) + >>> result1 = jnp.left_shift(x3, x4) + >>> result1 + Array([ 8, 16, 32, 64], dtype=int32) + >>> print_binary(result1) + ['0b1000', '0b10000', '0b100000', '0b1000000'] + """ return lax.shift_left(*promote_args_numeric("left_shift", x, y)) -@implements(getattr(np, "bitwise_left_shift", np.left_shift), module='numpy') + @partial(jit, inline=True) def bitwise_left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.left_shift`.""" return lax.shift_left(*promote_args_numeric("bitwise_left_shift", x, y)) @implements(np.equal, module='numpy') @@ -1296,6 +1502,50 @@ def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: """ return lax.nextafter(*promote_args_inexact("nextafter", x, y)) + +@partial(jit, inline=True) +def spacing(x: ArrayLike, /) -> Array: + """Return the spacing between ``x`` and the next adjacent number. + + JAX implementation of :func:`numpy.spacing`. + + Args: + x: real-valued array. Integer or boolean types will be cast to float. + + Returns: + Array of same shape as ``x`` containing spacing between each entry of + ``x`` and its closest adjacent value. + + See also: + - :func:`jax.numpy.nextafter`: find the next representable value. + + Examples: + >>> x = jnp.array([0.0, 0.25, 0.5, 0.75, 1.0], dtype='float32') + >>> jnp.spacing(x) + Array([1.4012985e-45, 2.9802322e-08, 5.9604645e-08, 5.9604645e-08, + 1.1920929e-07], dtype=float32) + + For ``x = 1``, the spacing is equal to the ``eps`` value given by + :class:`jax.numpy.finfo`: + + >>> x = jnp.float32(1) + >>> jnp.spacing(x) == jnp.finfo(x.dtype).eps + Array(True, dtype=bool) + """ + arr, = promote_args_inexact("spacing", x) + if dtypes.isdtype(arr.dtype, "complex floating"): + raise ValueError("jnp.spacing is not defined for complex inputs.") + inf = _lax_const(arr, np.inf) + smallest_subnormal = dtypes.finfo(arr.dtype).smallest_subnormal + + # Numpy's behavior seems to depend on dtype + if arr.dtype == 'float16': + return lax.nextafter(arr, inf) - arr + else: + result = lax.nextafter(arr, copysign(inf, arr)) - arr + return _where(result == 0, copysign(smallest_subnormal, arr), result) + + # Logical ops @partial(jit, inline=True) def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: @@ -2277,58 +2527,87 @@ def _normalize_float(x): return lax.bitcast_convert_type(x1, int_type), x2 -@implements(np.ldexp, module='numpy') @jit def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Compute x1 * 2 ** x2 + + JAX implementation of :func:`numpy.ldexp`. + + Note that XLA does not provide an ``ldexp`` operation, so this + is implemneted in JAX via a standard multiplication and + exponentiation. + + Args: + x1: real-valued input array. + x2: integer input array. Must be broadcast-compatible with ``x1``. + + Returns: + ``x1 * 2 ** x2`` computed element-wise. + + See also: + - :func:`jax.numpy.frexp`: decompose values into mantissa and exponent. + + Examples: + >>> x1 = jnp.arange(5.0) + >>> x2 = 10 + >>> jnp.ldexp(x1, x2) + Array([ 0., 1024., 2048., 3072., 4096.], dtype=float32) + + ``ldexp`` can be used to reconstruct the input to ``frexp``: + + >>> x = jnp.array([2., 3., 5., 11.]) + >>> m, e = jnp.frexp(x) + >>> m + Array([0.5 , 0.75 , 0.625 , 0.6875], dtype=float32) + >>> e + Array([2, 2, 3, 4], dtype=int32) + >>> jnp.ldexp(m, e) + Array([ 2., 3., 5., 11.], dtype=float32) + """ check_arraylike("ldexp", x1, x2) x1_dtype = dtypes.dtype(x1) x2_dtype = dtypes.dtype(x2) if (dtypes.issubdtype(x1_dtype, np.complexfloating) or dtypes.issubdtype(x2_dtype, np.inexact)): raise ValueError(f"ldexp not supported for input types {(x1_dtype, x2_dtype)}") + x1, = promote_args_inexact("ldexp", x1) + x2 = lax.convert_element_type(x2, dtypes.dtype(x1)) + x = x1 * (2 ** x2) + return _where(isinf(x1) | (x1 == 0), x1, x) - x1, x2 = promote_shapes("ldexp", x1, x2) - dtype = dtypes.canonicalize_dtype(dtypes.to_inexact_dtype(x1_dtype)) - info = dtypes.finfo(dtype) - int_type = _INT_DTYPES[info.bits] - - x1 = lax.convert_element_type(x1, dtype) - x2 = lax.convert_element_type(x2, int_type) - - mask = (1 << info.nexp) - 1 - bias = 1 - info.minexp - x, e = _normalize_float(x1) - x2 += e + ((x >> info.nmant) & mask) - bias +@jit +def frexp(x: ArrayLike, /) -> tuple[Array, Array]: + """Split floating point values into mantissa and twos exponent. - # find underflow/overflow before denormalization - underflow_cond = less(x2, -(bias + info.nmant)) - overflow_cond = greater(x2, bias) + JAX implementation of :func:`numpy.frexp`. - m = lax.full_like(x, 1, dtype=dtype) + Args: + x: real-valued array - # denormals - cond = less(x2, -bias + 1) - x2 = _where(cond, x2 + info.nmant, x2) - m = _where(cond, m / (1 << info.nmant), m) + Returns: + A tuple ``(mantissa, exponent)`` where ``mantissa`` is a floating point + value between -1 and 1, and ``exponent`` is an integer such that + ``x == mantissa * 2 ** exponent``. - x2 = lax.convert_element_type(x2, np.int32) - x &= ~(mask << info.nmant) - x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant) + See also: + - :func:`jax.numpy.ldexp`: compute the inverse of ``frexp``. - x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype) + Examples: + Split values into mantissa and exponent: - # underflow - x = _where(underflow_cond, lax.full_like(x, 0, dtype=dtype), x) - # overflow - x = _where(overflow_cond, lax.sign(x1) * lax.full_like(x, np.inf), x) - # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0 - return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x) + >>> x = jnp.array([1., 2., 3., 4., 5.]) + >>> m, e = jnp.frexp(x) + >>> m + Array([0.5 , 0.5 , 0.75 , 0.5 , 0.625], dtype=float32) + >>> e + Array([1, 2, 2, 3, 3], dtype=int32) + Reconstruct the original array: -@implements(np.frexp, module='numpy') -@jit -def frexp(x: ArrayLike, /) -> tuple[Array, Array]: + >>> m * 2 ** e + Array([1., 2., 3., 4., 5.], dtype=float32) + """ check_arraylike("frexp", x) x, = promote_dtypes_inexact(x) if dtypes.issubdtype(x.dtype, np.complexfloating): @@ -2699,9 +2978,36 @@ def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: return x - whole, whole -@implements(np.isfinite, module='numpy') @partial(jit, inline=True) def isfinite(x: ArrayLike, /) -> Array: + """Return a boolean array indicating whether each element of input is finite. + + JAX implementation of :obj:`numpy.isfinite`. + + Args: + x: input array or scalar. + + Returns: + A boolean array of same shape as ``x`` containing ``True`` where ``x`` is + not ``inf``, ``-inf``, or ``NaN``, and ``False`` otherwise. + + See also: + - :func:`jax.numpy.isinf`: Returns a boolean array indicating whether each + element of input is either positive or negative infinity. + - :func:`jax.numpy.isposinf`: Returns a boolean array indicating whether each + element of input is positive infinity. + - :func:`jax.numpy.isneginf`: Returns a boolean array indicating whether each + element of input is negative infinity. + - :func:`jax.numpy.isnan`: Returns a boolean array indicating whether each + element of input is not a number (``NaN``). + + Examples: + >>> x = jnp.array([-1, 3, jnp.inf, jnp.nan]) + >>> jnp.isfinite(x) + Array([ True, True, False, False], dtype=bool) + >>> jnp.isfinite(3-4j) + Array(True, dtype=bool, weak_type=True) + """ check_arraylike("isfinite", x) dtype = dtypes.dtype(x) if dtypes.issubdtype(dtype, np.floating): @@ -2712,9 +3018,36 @@ def isfinite(x: ArrayLike, /) -> Array: return lax.full_like(x, True, dtype=np.bool_) -@implements(np.isinf, module='numpy') @jit def isinf(x: ArrayLike, /) -> Array: + """Return a boolean array indicating whether each element of input is infinite. + + JAX implementation of :obj:`numpy.isinf`. + + Args: + x: input array or scalar. + + Returns: + A boolean array of same shape as ``x`` containing ``True`` where ``x`` is + ``inf`` or ``-inf``, and ``False`` otherwise. + + See also: + - :func:`jax.numpy.isposinf`: Returns a boolean array indicating whether each + element of input is positive infinity. + - :func:`jax.numpy.isneginf`: Returns a boolean array indicating whether each + element of input is negative infinity. + - :func:`jax.numpy.isfinite`: Returns a boolean array indicating whether each + element of input is finite. + - :func:`jax.numpy.isnan`: Returns a boolean array indicating whether each + element of input is not a number (``NaN``). + + Examples: + >>> jnp.isinf(jnp.inf) + Array(True, dtype=bool) + >>> x = jnp.array([2+3j, -jnp.inf, 6, jnp.inf, jnp.nan]) + >>> jnp.isinf(x) + Array([False, True, False, True, False], dtype=bool) + """ check_arraylike("isinf", x) dtype = dtypes.dtype(x) if dtypes.issubdtype(dtype, np.floating): @@ -2740,26 +3073,148 @@ def _isposneginf(infinity: float, x: ArrayLike, out) -> Array: return lax.full_like(x, False, dtype=np.bool_) -@implements(np.isposinf, module='numpy') def isposinf(x, /, out=None): + """ + Return boolean array indicating whether each element of input is positive infinite. + + JAX implementation of :obj:`numpy.isposinf`. + + Args: + x: input array or scalar. ``complex`` dtype are not supported. + + Returns: + A boolean array of same shape as ``x`` containing ``True`` where ``x`` is + ``inf``, and ``False`` otherwise. + + See also: + - :func:`jax.numpy.isinf`: Returns a boolean array indicating whether each + element of input is either positive or negative infinity. + - :func:`jax.numpy.isneginf`: Returns a boolean array indicating whether each + element of input is negative infinity. + - :func:`jax.numpy.isfinite`: Returns a boolean array indicating whether each + element of input is finite. + - :func:`jax.numpy.isnan`: Returns a boolean array indicating whether each + element of input is not a number (``NaN``). + + Examples: + >>> jnp.isposinf(5) + Array(False, dtype=bool) + >>> x = jnp.array([-jnp.inf, 5, jnp.inf, jnp.nan, 1]) + >>> jnp.isposinf(x) + Array([False, False, True, False, False], dtype=bool) + """ return _isposneginf(np.inf, x, out) -@implements(np.isposinf, module='numpy') def isneginf(x, /, out=None): + """ + Return boolean array indicating whether each element of input is negative infinite. + + JAX implementation of :obj:`numpy.isneginf`. + + Args: + x: input array or scalar. ``complex`` dtype are not supported. + + Returns: + A boolean array of same shape as ``x`` containing ``True`` where ``x`` is + ``-inf``, and ``False`` otherwise. + + See also: + - :func:`jax.numpy.isinf`: Returns a boolean array indicating whether each + element of input is either positive or negative infinity. + - :func:`jax.numpy.isposinf`: Returns a boolean array indicating whether each + element of input is positive infinity. + - :func:`jax.numpy.isfinite`: Returns a boolean array indicating whether each + element of input is finite. + - :func:`jax.numpy.isnan`: Returns a boolean array indicating whether each + element of input is not a number (``NaN``). + + Examples: + >>> jnp.isneginf(jnp.inf) + Array(False, dtype=bool) + >>> x = jnp.array([-jnp.inf, 5, jnp.inf, jnp.nan, 1]) + >>> jnp.isneginf(x) + Array([ True, False, False, False, False], dtype=bool) + """ return _isposneginf(-np.inf, x, out) -@implements(np.isnan, module='numpy') @partial(jit, inline=True) def isnan(x: ArrayLike, /) -> Array: + """Returns a boolean array indicating whether each element of input is ``NaN``. + + JAX implementation of :obj:`numpy.isnan`. + + Args: + x: input array or scalar. + + Returns: + A boolean array of same shape as ``x`` containing ``True`` where ``x`` is + not a number (i.e. ``NaN``) and ``False`` otherwise. + + See also: + - :func:`jax.numpy.isfinite`: Returns a boolean array indicating whether each + element of input is finite. + - :func:`jax.numpy.isinf`: Returns a boolean array indicating whether each + element of input is either positive or negative infinity. + - :func:`jax.numpy.isposinf`: Returns a boolean array indicating whether each + element of input is positive infinity. + - :func:`jax.numpy.isneginf`: Returns a boolean array indicating whether each + element of input is negative infinity. + + Examples: + >>> jnp.isnan(6) + Array(False, dtype=bool, weak_type=True) + >>> x = jnp.array([2, 1+4j, jnp.inf, jnp.nan]) + >>> jnp.isnan(x) + Array([False, False, False, True], dtype=bool) + """ check_arraylike("isnan", x) return lax.ne(x, x) -@implements(np.heaviside, module='numpy') @jit def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array: + r"""Compute the heaviside step function. + + JAX implementation of :obj:`numpy.heaviside`. + + The heaviside step function is defined by: + + .. math:: + + \mathrm{heaviside}(x1, x2) = \begin{cases} + 0., & x < 0\\ + x2, & x = 0\\ + 1., & x > 0. + \end{cases} + + Args: + x1: input array or scalar. ``complex`` dtype are not supported. + x2: scalar or array. Specifies the return values when ``x1`` is ``0``. ``complex`` + dtype are not supported. ``x1`` and ``x2`` must either have same shape or + broadcast compatible. + + Returns: + An array containing the heaviside step function of ``x1``, promoting to + inexact dtype. + + Examples: + >>> x1 = jnp.array([[-2, 0, 3], + ... [5, -1, 0], + ... [0, 7, -3]]) + >>> x2 = jnp.array([2, 0.5, 1]) + >>> jnp.heaviside(x1, x2) + Array([[0. , 0.5, 1. ], + [1. , 0. , 1. ], + [2. , 1. , 0. ]], dtype=float32) + >>> jnp.heaviside(x1, 0.5) + Array([[0. , 0.5, 1. ], + [1. , 0. , 0.5], + [0.5, 1. , 0. ]], dtype=float32) + >>> jnp.heaviside(-3, x2) + Array([0., 0., 0.], dtype=float32) + """ check_arraylike("heaviside", x1, x2) x1, x2 = promote_dtypes_inexact(x1, x2) zero = _lax_const(x1, 0) @@ -2767,9 +3222,39 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array: _where(lax.gt(x1, zero), _lax_const(x1, 1), x2)) -@implements(np.hypot, module='numpy') @jit def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array: + r""" + Return element-wise hypotenuse for the given legs of a right angle triangle. + + JAX implementation of :obj:`numpy.hypot`. + + Args: + x1: scalar or array. Specifies one of the legs of right angle triangle. + ``complex`` dtype are not supported. + x2: scalar or array. Specifies the other leg of right angle triangle. + ``complex`` dtype are not supported. ``x1`` and ``x2`` must either have + same shape or be broadcast compatible. + + Returns: + An array containing the hypotenuse for the given given legs ``x1`` and ``x2`` + of a right angle triangle, promoting to inexact dtype. + + Note: + ``jnp.hypot`` is a more numerically stable way of computing + ``jnp.sqrt(x1 ** 2 + x2 **2)``. + + Examples: + >>> jnp.hypot(3, 4) + Array(5., dtype=float32, weak_type=True) + >>> x1 = jnp.array([[3, -2, 5], + ... [9, 1, -4]]) + >>> x2 = jnp.array([-5, 6, 8]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.hypot(x1, x2) + Array([[ 5.831, 6.325, 9.434], + [10.296, 6.083, 8.944]], dtype=float32) + """ x1, x2 = promote_args_inexact("hypot", x1, x2) # TODO(micky774): Promote to ValueError when deprecation is complete @@ -2785,9 +3270,34 @@ def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array: return _where(idx_inf, _lax_const(x, np.inf), x) -@implements(np.reciprocal, module='numpy') @partial(jit, inline=True) def reciprocal(x: ArrayLike, /) -> Array: + """Calculate element-wise reciprocal of the input. + + JAX implementation of :obj:`numpy.reciprocal`. + + The reciprocal is calculated by ``1/x``. + + Args: + x: input array or scalar. + + Returns: + An array of same shape as ``x`` containing the reciprocal of each element of + ``x``. + + Note: + For integer inputs, ``np.reciprocal`` returns rounded integer output, while + ``jnp.reciprocal`` promotes integer inputs to floating point. + + Examples: + >>> jnp.reciprocal(2) + Array(0.5, dtype=float32, weak_type=True) + >>> jnp.reciprocal(0.) + Array(inf, dtype=float32, weak_type=True) + >>> x = jnp.array([1, 5., 4.]) + >>> jnp.reciprocal(x) + Array([1. , 0.2 , 0.25], dtype=float32) + """ check_arraylike("reciprocal", x) x, = promote_dtypes_inexact(x) return lax.integer_pow(x, -1) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 9c9bc5d389e1..c5b1530ca215 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -115,7 +115,6 @@ def implements( original_fun: Callable[..., Any] | None, update_doc: bool = True, sections: Sequence[str] = ('Parameters', 'Returns', 'References'), - skip_params: Sequence[str] = (), module: str | None = None, ) -> Callable[[_T], _T]: """Decorator for JAX functions which implement a specified NumPy function. @@ -133,8 +132,6 @@ def implements( If False, include the numpy docstring verbatim. sections: a list of sections to include in the docstring. The default is ["Parameters", "Returns", "References"] - skip_params: a list of strings containing names of parameters accepted by the - function that should be skipped in the parameter list. module: an optional string specifying the module from which the original function is imported. This is useful for objects such as ufuncs, where the module cannot be determined from the original function itself. @@ -162,8 +159,7 @@ def decorator(wrapped_fun): # Remove unrecognized parameter descriptions. parameters = _parse_parameters(parsed.sections['Parameters']) parameters = {p: desc for p, desc in parameters.items() - if (code is None or p in code.co_varnames) - and p not in skip_params} + if (code is None or p in code.co_varnames)} if parameters: parsed.sections['Parameters'] = ( "Parameters\n" diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index 2bcfe96ad2f0..809df8195d54 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -122,7 +122,9 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, dnums = lax.ScatterDimensionNumbers( update_window_dims=indexer.dnums.offset_dims, inserted_window_dims=indexer.dnums.collapsed_slice_dims, - scatter_dims_to_operand_dims=indexer.dnums.start_index_map + scatter_dims_to_operand_dims=indexer.dnums.start_index_map, + operand_batching_dims=indexer.dnums.operand_batching_dims, + scatter_indices_batching_dims=indexer.dnums.start_indices_batching_dims, ) out = scatter_op( x, indexer.gather_indices, y, dnums, diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index e817369a50c5..0ff463562355 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -24,13 +24,11 @@ import itertools import threading from typing import Any, ClassVar, Hashable, Protocol, Union, runtime_checkable -import warnings import jax from jax._src import api_util from jax._src import config from jax._src import core as jax_core -from jax._src import deprecations from jax._src import dtypes from jax._src import linear_util as lu from jax._src import mesh as mesh_lib @@ -40,6 +38,7 @@ from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.state import discharge as state_discharge +from jax._src.state.types import TransformedRef import jax.numpy as jnp @@ -377,38 +376,11 @@ class BlockSpec: See :ref:`pallas_blockspec` for more details. """ # An internal canonicalized version is in BlockMapping. - block_shape: tuple[int | None, ...] | None = None + block_shape: Sequence[int | None] | None = None index_map: Callable[..., Any] | None = None memory_space: Any | None = dataclasses.field(kw_only=True, default=None) indexing_mode: IndexingMode = dataclasses.field(kw_only=True, default=blocked) - def __init__( - self, - block_shape: Any | None = None, - index_map: Any | None = None, - *, - memory_space: Any | None = None, - indexing_mode: IndexingMode = blocked, - ) -> None: - if callable(block_shape): - # TODO(slebedev): Remove this code path and update the signature of - # __init__ after October 1, 2024. - message = ( - "BlockSpec now expects ``block_shape`` to be passed before" - " ``index_map``. Update your code by swapping the order of these" - " arguments. For example, ``pl.BlockSpace(lambda i: i, (42,))``" - " should be written as ``pl.BlockSpec((42,), lambda i: i)``." - ) - if deprecations.is_accelerated("pallas-block-spec-order"): - raise TypeError(message) - warnings.warn(message, DeprecationWarning) - index_map, block_shape = block_shape, index_map - - self.block_shape = block_shape - self.index_map = index_map - self.memory_space = memory_space - self.indexing_mode = indexing_mode - def to_block_mapping( self, origin: OriginStr, @@ -496,7 +468,7 @@ def to_block_mapping( mapping = BlockMapping( block_shape=mapped_block_shape, - block_aval=block_aval, + transformed_block_aval=block_aval, # There are no transforms by default index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts), index_map_src_info=index_map_src_info, indexing_mode=self.indexing_mode, @@ -523,7 +495,7 @@ def __repr__(self): class MemoryRefTransform(Protocol): """Transforms a memory reference on load or store.""" - def __call__(self, block_aval: AbstractMemoryRef) -> AbstractMemoryRef: + def undo(self, ref: TransformedRef) -> TransformedRef: raise NotImplementedError("Abstract evaluation not implemented.") @@ -533,8 +505,10 @@ class BlockMapping: See the `check_invariants` method for precise specification. """ + # TODO(apaszke,sharadmv): Replace mapped dims in block_shape with a transform. + # After all, it's just indexing out singleton dimensions. block_shape: tuple[Mapped | int, ...] - block_aval: AbstractMemoryRef # The block ref aval + transformed_block_aval: AbstractMemoryRef index_map_jaxpr: jax_core.ClosedJaxpr index_map_src_info: NameAndSrcInfo indexing_mode: IndexingMode @@ -546,8 +520,8 @@ def check_invariants(self) -> None: if not config.enable_checks.value: return unmapped_block_shape = tuple(s for s in self.block_shape if s is not mapped) - assert unmapped_block_shape == self.block_aval.shape, ( - self.block_shape, self.block_aval) + assert unmapped_block_shape == self.ref_aval.shape, ( + self.block_shape, self.ref_aval.shape) assert len(self.block_shape) == len(self.array_shape_dtype.shape), ( self.block_shape, self.array_shape_dtype ) @@ -568,12 +542,21 @@ def replace(self, **kwargs): return new_self @property - def ref_aval(self) -> AbstractMemoryRef: + def block_aval(self) -> AbstractMemoryRef: + # If you hit this, make sure you take transforms into account and use either + # ref_aval or transformed_block_aval. + assert not self.transforms, "Lowering failed to handle transforms" + return self.transformed_block_aval + + @property + def ref_aval(self) -> AbstractMemoryRef | TransformedRef: """Returns the abstract value of the Ref after transformations.""" - block_aval = self.block_aval - for transform in self.transforms: - block_aval = transform(block_aval) - return block_aval + if not self.transforms: + return self.transformed_block_aval + ref = TransformedRef(self.transformed_block_aval, ()) + for transform in reversed(self.transforms): + ref = transform.undo(ref) + return ref def compute_start_indices_interpret(self, loop_idx, *args): discharged_jaxpr, discharged_consts = state_discharge.discharge_state( diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index ae76a00a6c17..f52ba9ddd6cd 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -57,9 +57,16 @@ py_library( name = "primitives", srcs = ["primitives.py"], deps = [ + ":core", "//jax", "//jax:core", + "//jax:dtypes", "//jax:mlir", + "//jax:pretty_printer", + "//jax:tree_util", + "//jax:typing", + "//jax:util", + "//jax/_src/pallas", ], ) diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 4ff9d894da8f..57f1cad325bb 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -22,6 +22,7 @@ from typing import Any, ClassVar, Literal import jax +from jax._src import config from jax._src import core as jax_core from jax._src import dtypes from jax._src import util @@ -45,6 +46,17 @@ _convert_block_spec_to_block_mapping = pallas_core._convert_block_spec_to_block_mapping split_list = util.split_list +_ENABLE_RUNTIME_ASSERT = config.bool_state( + "jax_pallas_enable_runtime_assert", + default=False, + help=( + "If set, enables runtime assertions in the kernel via checkify.check." + " Otherwise, runtime asserts will be ignored unless functionalized" + " using checkify.checkify." + ), +) + + @dataclasses.dataclass(frozen=True) class TPUCompilerParams(pallas_core.CompilerParams): """Mosaic TPU compiler parameters. @@ -203,3 +215,7 @@ def create_tensorcore_mesh(axis_name: str) -> pallas_core.PallasMesh: np.array([TensorCore(i) for i in range(num_cores)]), [axis_name], ) + +def runtime_assert_enabled() -> bool: + """Returns whether runtime asserts are enabled.""" + return _ENABLE_RUNTIME_ASSERT.value diff --git a/jax/_src/pallas/mosaic/error_handling.py b/jax/_src/pallas/mosaic/error_handling.py index f8231f5b24b6..fcc304055453 100644 --- a/jax/_src/pallas/mosaic/error_handling.py +++ b/jax/_src/pallas/mosaic/error_handling.py @@ -151,8 +151,7 @@ def parse_location_string(location_string: str) -> tuple[str, list[RawFrame]]: def traceback_from_raw_frames(frames: list[RawFrame]) -> types.TracebackType: """Constructs a traceback from a list of RawFrame objects.""" xla_frames = [ - xla_client.Frame(frame.filename, frame.func_name, -1, frame.lineno - ) # type: ignore [call-arg] + xla_client.Frame(frame.filename, frame.func_name, -1, frame.lineno) for frame in frames ] return xla_client.Traceback.traceback_from_frames(xla_frames) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 46cbe8e4758b..de77b71f544a 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -26,6 +26,7 @@ from jax import lax from jax import tree_util from jax._src import ad_util +from jax._src import checkify from jax._src import core as jax_core from jax._src import custom_derivatives from jax._src import debugging @@ -36,6 +37,7 @@ from jax._src import prng from jax._src import source_info_util from jax._src import state +from jax._src import traceback_util from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax as lax_internal @@ -54,6 +56,7 @@ from jax._src.pallas.mosaic import core as tpu_core from jax._src.pallas.mosaic import error_handling from jax._src.pallas.mosaic import primitives as tpu_primitives +from jax._src.pallas.mosaic import random as pl_random from jax._src.state import discharge as state_discharge from jax._src.state import indexing from jax._src.state import primitives as state_primitives @@ -200,10 +203,13 @@ def aval_to_ir_type(aval, return ir.MemRefType.get((), sem_type, memory_space=memspace) if dtypes.issubdtype(aval.dtype, dtypes.prng_key): shape = aval.dtype._impl.key_shape - if memory_space is None: - memory_space = TPUMemorySpace.SMEM - if memory_space != TPUMemorySpace.SMEM: - raise ValueError(f"PRNG keys must be stored in SMEM. Got {memory_space}") + if pl_random.is_pallas_impl(aval.dtype._impl): + if memory_space is None: + memory_space = TPUMemorySpace.SMEM + if memory_space != TPUMemorySpace.SMEM: + raise ValueError( + f"PRNG keys must be stored in SMEM. Got {memory_space}" + ) memspace = _memory_space_to_mosaic_attribute(memory_space) return ir.MemRefType.get(shape, _dtype_to_ir_type(np.dtype(np.uint32)), memory_space=memspace) @@ -359,11 +365,16 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pallas_core.GridMapping, if grid_mapping.get_grid_indices is None: + # Avoid using self.mapped_dims within the function, since doing so will + # introduce a self->_get_grid_indices->self reference cycle that means + # MosaicGridMapping instances can only ever be deleted by GC, rather than + # by their reference counts going to 0. + mapped_dims = self.mapped_dims def _get_grid_indices(indices, maybe_include_mapped_dims: bool): if maybe_include_mapped_dims: return indices return tuple( - idx for i, idx in enumerate(indices) if i not in self.mapped_dims + idx for i, idx in enumerate(indices) if i not in mapped_dims ) self.get_grid_indices = _get_grid_indices @@ -460,7 +471,7 @@ def err_details(): "has block shape " f"{bm.block_shape}, array shape {bm.array_shape_dtype.shape}, " # TODO(necula): add index_map source location info - f"and index_map returning {bm.index_map_jaxpr.jaxpr.outvars}, in " + f"and index_map {bm.index_map_jaxpr.jaxpr}, in " f"memory space {bm.block_aval.memory_space}." "\nSee details at https://jax.readthedocs.io/en/latest/pallas/grid_blockspec.html#pallas-blockspec") if rank < 1: @@ -475,7 +486,8 @@ def err_details(): "only blocks having the same block shape as the array shape " "and a trivial index_map (returning all 0s)." + err_details()) - unmapped_bs = [1 if bs is pallas_core.mapped else bs for bs in bm.block_shape] + unmapped_bs = [ + 1 if bs is pallas_core.mapped else bs for bs in bm.block_shape] bs0, as0 = unmapped_bs[-1], bm.array_shape_dtype.shape[-1] if rank >= 2: bs1, as1 = unmapped_bs[-2], bm.array_shape_dtype.shape[-2] @@ -537,10 +549,13 @@ def lower_jaxpr_to_module( if grid: for i, bm in enumerate(grid_mapping.block_mappings): func_name = f"transform_{i}" - # ANY operands don't support windowing and require empty window_params. + # ANY and SEMAPHORE operands don't support windowing and require empty window_params. tpu_memory_space = _memory_space_to_tpu_memory_space( bm.block_aval.memory_space) - if tpu_memory_space == tpu_core.TPUMemorySpace.ANY: + if ( + tpu_memory_space == tpu_core.TPUMemorySpace.ANY + or tpu_memory_space == tpu_core.TPUMemorySpace.SEMAPHORE + ): # We checked above that the block does not require windowing. window_params.append(ir.DictAttr.get()) continue @@ -820,15 +835,16 @@ def write_env(var: jax_core.Var, val): except LoweringException: raise # We only add the extra info to the innermost exception. except Exception as e: - raise LoweringException( - f"Exception while lowering eqn:\n {eqn}\nWith context:\n " - f" {rule_context}\nWith inval" - f" shapes={map(lambda t: getattr(t, 'shape', None), invals)}\nWith" - " inval" - f" types={map(lambda t: getattr(t, 'type', None), invals)}\nIn" - f" jaxpr:\n{jaxpr}" - f"\nException: {e}" - ) from e + msg = (f"{type(e).__name__}: {e}\n" + + "Additional diagnostics: \n" + + f"Failing jaxpr equation: {eqn}\n") + new_error = LoweringException(msg) + # We insert the traceback here so that the user code shows + # up in the traceback for the post-transform error. + if source_info.traceback is not None: + tb = source_info.traceback.as_python_traceback() + new_error.__traceback__ = traceback_util.filter_traceback(tb) + raise new_error from e else: raise NotImplementedError( "Unimplemented primitive in Pallas TPU lowering: " @@ -1121,7 +1137,9 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): ref_type = ir.MemRefType(ref.type) is_smem_load = str(ref_type.memory_space) == "#tpu.memory_space" (aval_out,) = ctx.avals_out - if isinstance(aval_out.dtype, prng.KeyTy): + if isinstance(aval_out.dtype, prng.KeyTy) and pl_random.is_pallas_impl( + aval_out.dtype._impl + ): if not is_smem_load: raise ValueError("PRNG keys must be loaded from SMEM. Did you set " "the memory space to TPUMemorySpace.SMEM in the " @@ -1221,7 +1239,7 @@ def _maybe_cast_load_to_bool( if out_aval.dtype != jnp.bool_: return val load_scalar_type = _dtype_to_ir_type(BOOL_MEMREF_TYPE) - pred = _cmpi_lowering_types[lax.ne_p] + pred = _cmpsi_lowering_types[lax.ne_p] predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred) const_zero = ir.IntegerAttr.get(load_scalar_type, 0) if out_aval.shape: # Vector case. @@ -1615,7 +1633,12 @@ def _convert_helper(x, *, to_dtype): x = x.astype(jnp.float32) return x.astype(to_dtype) if jnp.issubdtype(from_dtype, jnp.floating): - if jnp.issubdtype(to_dtype, jnp.signedinteger): + if jnp.issubdtype(to_dtype, np.dtype("bool")): + # Cast to float32 rather than int32 because 0 < |x| < 1 rounds to 0, + # leading to false in bool. However, convert_element_type(x, bool) + # returns true. It's handled correctly when x is float32. + x = x.astype(jnp.float32) + elif jnp.issubdtype(to_dtype, jnp.signedinteger): if from_dtype.itemsize < 4: x = x.astype(jnp.float32) if to_dtype.itemsize < 4: @@ -1623,10 +1646,7 @@ def _convert_helper(x, *, to_dtype): minval, maxval = jnp.iinfo(to_dtype).min, jnp.iinfo(to_dtype).max x = jnp.clip(x, minval, maxval) return x.astype(jnp.int32).astype(to_dtype) - return x.astype(to_dtype) - elif jnp.issubdtype(to_dtype, np.dtype("bool")): - x = x.astype(jnp.int32) - return x.astype(jnp.float32) + return x.astype(to_dtype) raise NotImplementedError(f"Unsupported cast: {from_dtype} -> {to_dtype}") def _convert_element_type_lowering_rule( @@ -1641,6 +1661,10 @@ def _convert_element_type_lowering_rule( if old_dtype == new_dtype: return x + + if new_dtype.itemsize == 8: + raise NotImplementedError("64-bit types are not supported") + if jnp.issubdtype(old_dtype, jnp.floating) and jnp.issubdtype( new_dtype, jnp.floating ): @@ -1673,24 +1697,32 @@ def _convert_element_type_lowering_rule( ): return arith.extui(out_type, x) elif ( - jnp.issubdtype(old_dtype, jnp.integer) + ( + (is_float := jnp.issubdtype(old_dtype, jnp.floating)) + or jnp.issubdtype(old_dtype, jnp.integer) + ) and new_dtype == jnp.bool_ and old_dtype.itemsize == 4 ): - pred = _cmpi_lowering_types[lax.ne_p] - predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred) + # Lower float32 or (u)int32 -> bool to cmp neq %in, 0 const_type = _dtype_to_ir_type(old_dtype) - const_zero = ir.IntegerAttr.get(const_type, 0) + if is_float: + pred = _cmpf_lowering_types[lax.ne_p] + const_zero = ir.FloatAttr.get(const_type, 0) + op = arith.CmpFOp + else: + pred = _cmpsi_lowering_types[lax.ne_p] + const_zero = ir.IntegerAttr.get(const_type, 0) + op = arith.CmpIOp + predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred) if in_aval.shape: in_type = aval_to_ir_type(in_aval, is_kernel_boundary=False) vector_zeros = arith.ConstantOp( in_type, ir.DenseElementsAttr.get_splat(in_type, const_zero), ) - return arith.CmpIOp(predicate, x, vector_zeros).result - return arith.CmpIOp( - predicate, x, arith.ConstantOp(const_type, const_zero) - ).result + return op(predicate, x, vector_zeros).result + return op(predicate, x, arith.ConstantOp(const_type, const_zero)).result return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype), multiple_results=False)(ctx, x) @@ -2001,6 +2033,20 @@ def _sin_lowering_rule(ctx: LoweringRuleContext, x): lowering_rules[lax.sin_p] = _sin_lowering_rule +def _cos_lowering_rule(ctx: LoweringRuleContext, x): + return math.CosOp(x).result + + +lowering_rules[lax.cos_p] = _cos_lowering_rule + + +def _tan_lowering_rule(ctx: LoweringRuleContext, x): + return math.TanOp(x).result + + +lowering_rules[lax.tan_p] = _tan_lowering_rule + + def _tanh_lowering_rule(ctx: LoweringRuleContext, x): return math.TanhOp(x).result @@ -2034,55 +2080,97 @@ def _round_lowering_rule(ctx: LoweringRuleContext, x, *, rounding_method): lowering_rules[lax.round_p] = _round_lowering_rule -# See https://mlir.llvm.org/docs/Dialects/ArithOps/#arithcmpi-arithcmpiop for -# the mapping from comparison type to integer predicates for int comparisons. -_cmpi_lowering_types = { - lax.eq_p: 0, - lax.ne_p: 1, - lax.lt_p: 2, - lax.le_p: 3, - lax.gt_p: 4, - lax.ge_p: 5, +def _ceil_lowering_rule(ctx: LoweringRuleContext, x): + return math.CeilOp(x).result + + +lowering_rules[lax.ceil_p] = _ceil_lowering_rule + + +def _floor_lowering_rule(ctx: LoweringRuleContext, x): + return math.FloorOp(x).result + + +lowering_rules[lax.floor_p] = _floor_lowering_rule + + +def _clz_lowering_rule(ctx: LoweringRuleContext, x): + return math.CountLeadingZerosOp(x).result + +lowering_rules[lax.clz_p] = _clz_lowering_rule + + +def _population_count_lowering_rule(ctx: LoweringRuleContext, x): + aval_out = ctx.avals_out[0] + if aval_out.shape == (): + raise ValueError("Population count is not supported on scalars") + return math.CtPopOp(x).result + +lowering_rules[lax.population_count_p] = _population_count_lowering_rule + + +# Mapping for signed integer comparisons. +_cmpsi_lowering_types = { + lax.eq_p: arith.CmpIPredicate.eq, + lax.ne_p: arith.CmpIPredicate.ne, + lax.lt_p: arith.CmpIPredicate.slt, + lax.le_p: arith.CmpIPredicate.sle, + lax.gt_p: arith.CmpIPredicate.sgt, + lax.ge_p: arith.CmpIPredicate.sge, +} + +# Mapping for unsigned integer comparisons. +_cmpui_lowering_types = { + lax.eq_p: arith.CmpIPredicate.eq, + lax.ne_p: arith.CmpIPredicate.ne, + lax.lt_p: arith.CmpIPredicate.ult, + lax.le_p: arith.CmpIPredicate.ule, + lax.gt_p: arith.CmpIPredicate.ugt, + lax.ge_p: arith.CmpIPredicate.uge, } -# See https://mlir.llvm.org/docs/Dialects/ArithOps/#arithcmpf-arithcmpfop for -# the mapping from comparison type to integer predicate for float comparisons. +# Mapping for floating point comparisons. _cmpf_lowering_types = { - lax.eq_p: 1, - lax.ne_p: 6, - lax.lt_p: 4, - lax.le_p: 5, - lax.gt_p: 2, - lax.ge_p: 3, + lax.eq_p: arith.CmpFPredicate.OEQ, + lax.ne_p: arith.CmpFPredicate.ONE, + lax.lt_p: arith.CmpFPredicate.OLT, + lax.le_p: arith.CmpFPredicate.OLE, + lax.gt_p: arith.CmpFPredicate.OGT, + lax.ge_p: arith.CmpFPredicate.OGE, } def _cmp_lowering_rule(prim, ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) x_aval, y_aval = ctx.avals_in - dtypes = x_aval.dtype, y_aval.dtype - if all( - jnp.issubdtype(dtype, jnp.integer) | jnp.issubdtype(dtype, jnp.bool_) - for dtype in dtypes - ): + if x_aval.dtype != y_aval.dtype: + raise ValueError( + f"Mixed dtype operands in cmp: {x_aval.dtype}, {y_aval.dtype}" + ) + dtype = x_aval.dtype - # Handle bool comparisons by casting to int32. + # Handle bool comparisons by casting to int32. + if jnp.issubdtype(dtype, jnp.bool_): bool_cast_to = _dtype_to_ir_type(jnp.dtype("int32")) true_ = ir_constant(1, mlir_type=bool_cast_to) false_ = ir_constant(0, mlir_type=bool_cast_to) - if jnp.issubdtype(dtypes[0], jnp.bool_): - x = arith.SelectOp(x, true_, false_) - if jnp.issubdtype(dtypes[1], jnp.bool_): - y = arith.SelectOp(y, true_, false_) - pred = _cmpi_lowering_types[prim] + x = arith.SelectOp(x, true_, false_) + y = arith.SelectOp(y, true_, false_) + dtype = jnp.dtype("int32") + + if jnp.issubdtype(dtype, jnp.integer): + is_uint = jnp.issubdtype(dtype, jnp.unsignedinteger) + pred = (_cmpui_lowering_types if is_uint else _cmpsi_lowering_types)[prim] predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred) return arith.CmpIOp(predicate, x, y).result - elif all(jnp.issubdtype(dtype, jnp.floating) for dtype in dtypes): + + if jnp.issubdtype(dtype, jnp.floating): pred = _cmpf_lowering_types[prim] predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred) return arith.CmpFOp(predicate, x, y).result - raise NotImplementedError("Mixed dtype operands in cmp") + + raise NotImplementedError(f"Unsupported dtype in cmp: {dtype}") lowering_rules[lax.eq_p] = functools.partial(_cmp_lowering_rule, lax.eq_p) @@ -2736,8 +2824,10 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree, device_id_type: tpu_primitives.DeviceIdType): del device_id_type - sem, sem_transforms, ref, transforms = tree_util.tree_unflatten(tree, args) - sem_aval, _, ref_aval, _ = tree_util.tree_unflatten(tree, ctx.avals_in) + (_, _, ref, transforms, sem, sem_transforms, _, _, _) = tree_util.tree_unflatten( + tree, args) + (_, _, ref_aval, _, sem_aval, _, _, _, _) = tree_util.tree_unflatten( + tree, ctx.avals_in) block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes) ref_block_shape = block_shapes[2] ref, _ = _transform_ref(ref, ref_aval.dtype, ref_block_shape, transforms) @@ -2851,8 +2941,13 @@ def random_bits_lowering(ctx, keys, *, bit_width, shape): assert bit_width == 32, "Only 32-bit PRNG supported." aval, = ctx.avals_in impl = aval.dtype._impl - bits_lowering = lower_fun( - impl.random_bits, multiple_results=False) + _proxy_fn = impl.random_bits + if not pl_random.is_pallas_impl(impl): + def new_lowering(key, bit_width, shape): + key = jax.random.key_data(key).astype(jnp.uint32) + return impl.random_bits(key, bit_width, shape) + _proxy_fn = new_lowering + bits_lowering = lower_fun(_proxy_fn, multiple_results=False) return bits_lowering(ctx, keys, bit_width=bit_width, shape=shape) lowering_rules[prng.random_bits_p] = random_bits_lowering @@ -2867,7 +2962,10 @@ def random_fold_in_lowering(ctx, keys, msgs): def random_unwrap_lowering(ctx, key): - del ctx + keys_aval = ctx.avals_in[0] + impl = keys_aval.dtype._impl + if not pl_random.is_pallas_impl(impl): + return key assert isinstance(key, KeyScalarBundle) # Convert to a vector. if tuple(key.key_shape) != (1, 1): @@ -2884,7 +2982,9 @@ def random_unwrap_lowering(ctx, key): def random_wrap_lowering(ctx, key_data, *, impl): - del ctx, impl + del ctx + if not pl_random.is_pallas_impl(impl): + return key_data if isinstance(key_data.type, ir.VectorType): # If the key data lives in vregs, need to unpack it to sregs. key_data_list = [] @@ -2906,6 +3006,78 @@ def random_wrap_lowering(ctx, key_data, *, impl): lowering_rules[prng.random_wrap_p] = random_wrap_lowering +def _checkify_lowering_rule( + ctx: LoweringRuleContext, *err_args, err_tree, debug): + if not tpu_core.runtime_assert_enabled(): + if debug: + return [] + else: + raise LoweringException("Non-debug check must be functionalized. " + "Enable runtime asserts with " + "--jax_pallas_enable_runtime_assert " + "or functionalize with checkify.check.") + + assert ctx.lowering_context.ir_context.allow_unregistered_dialects, ( + "allow_unregistered_dialects must be set to True for " + "runtime assert check.") + error = jax.tree.unflatten(err_tree, err_args) + assert len(error._pred) == 1 + assert len(error._metadata) == 1 + assert len(error._payload) == 1 + pred = list(error._pred.items())[0][1] + metadata = list(error._metadata.items())[0] + payload = list(error._payload.items())[0][1] + exception_tree = metadata[1] + exception = jax.tree.unflatten(exception_tree, payload) + assert isinstance(exception, checkify.FailedCheckError) + + # check_p has an inverted predicate compared to assert, + # so we need to compute not(pred) here. + out_scalar_type = _dtype_to_ir_type(jnp.dtype('bool')) + minus_one = ir_constant(-1, out_scalar_type) + not_pred = arith.XOrIOp(pred, minus_one).result + attrs = {"msg": ir.StringAttr.get(exception.fmt_string)} + ir.Operation.create("cf.assert", + operands=(not_pred,), + attributes=attrs) + return [] +lowering_rules[checkify.check_p] = _checkify_lowering_rule + +def _threefry2x32_lowering(ctx, k1, k2, m1, m2): + def _lower_fun(k1, k2, m1, m2): + with jax.named_scope("threefry2x32"): + res = prng._threefry2x32_lowering(k1, k2, m1, m2, use_rolled_loops=False) + return res + + threefry_lowering = lower_fun(_lower_fun, multiple_results=True) + return threefry_lowering(ctx, k1, k2, m1, m2) + + +lowering_rules[prng.threefry2x32_p] = _threefry2x32_lowering + + +def _iota_2x32_shape_lowering(ctx, *, shape): + total_elements = np.prod(shape) + if total_elements > np.iinfo(jnp.int32).max: + raise NotImplementedError(f"Iota with >{np.iinfo(jnp.int32).max} items.") + + def _lower_fun(shape): + iota_data = jnp.zeros(shape, dtype=jnp.int32) + multiplier = 1 + for dim in range(len(shape)-1, -1, -1): + counts_lo = lax.broadcasted_iota( + dtype=jnp.int32, shape=shape, dimension=dim + ) + iota_data += counts_lo * multiplier + multiplier *= shape[dim] + counts_hi = jnp.zeros(shape, dtype=jnp.int32) + return counts_hi, iota_data + + iota_lowering = lower_fun(_lower_fun, multiple_results=True) + return iota_lowering(ctx, shape=shape) + + +lowering_rules[prng.iota_2x32_shape_p] = _iota_2x32_shape_lowering # Lowering for shard_map diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index b09d36a9d3b2..2bf96511b64e 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -151,7 +151,7 @@ def pallas_call_tpu_lowering_rule( mlir_ctx.load_all_available_dialects() tpu.register_dialect(mlir_ctx) def lower_module(for_verification: bool): - if for_verification: + if for_verification or tpu_core.runtime_assert_enabled(): mlir_ctx.allow_unregistered_dialects = True with mlir_ctx, ir.Location.unknown(mlir_ctx): dimension_semantics = mosaic_params.get("dimension_semantics", None) diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 1514f67a9e33..0112b3cb4dbb 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -99,11 +99,6 @@ def _make_tiling(shape: tuple[int, ...], dtype: np.dtype) -> tuple[int, ...]: return (*(1,) * len(leading_dims), second_minor_tiling, _TILING[1]) -def _mod(a, n): - """"Calculates a mod n for positive and negative a with |a| <= n.""" - return lax.rem(a + n, n) - - def _round_up_to_nearest_multiple(s: int, multiple: int) -> int: if s % multiple == 0: return s @@ -536,12 +531,46 @@ def accumulate(self): is_leaf=lambda x: isinstance(x, BufferedRef)) +def _filter_indices( + indices: tuple[int | jax.Array, ...], grid: tuple[int | jax.Array, ...] +) -> tuple[int | jax.Array, ...]: + return tuple( + 0 if isinstance(g, int) and g == 1 else i + for i, g in zip(indices, grid, strict=True) + ) + + +def _next_index( + indices: tuple[int | jax.Array, ...], grid: tuple[int | jax.Array, ...] +) -> tuple[int | jax.Array, ...]: + out = [] + carry: bool | jax.Array = True + for i, g in reversed(list(zip(indices, grid, strict=True))): + inc = jax.lax.select(carry, i + 1, i) + carry = inc == g + out.append(jax.lax.select(carry, 0, inc)) + return _filter_indices(tuple(reversed(out)), grid) + + +def _prev_index( + indices: tuple[int | jax.Array, ...], grid: tuple[int | jax.Array, ...] +) -> tuple[int | jax.Array, ...]: + out = [] + borrow: bool | jax.Array = True + for i, g in reversed(list(zip(indices, grid, strict=True))): + dec = jax.lax.select(borrow, i - 1, i) + borrow = dec == -1 + out.append(jax.lax.select(borrow, g - 1, dec)) + return _filter_indices(tuple(reversed(out)), grid) + + class Scheduler: """Sequences input and output copies and waits for a pipeline.""" def __init__( self, step: jax.Array, + indices: tuple[int | jax.Array, ...], grid: tuple[int | jax.Array, ...], grid_offsets: tuple[int | jax.Array, ...], first_cycle=None, @@ -553,6 +582,7 @@ def __init__( Args: step: inner step number. + indices: current grid indices. grid: pallas grid for BufferedRefs. grid_offsets: offsets for grid indices (used for megacore). first_cycle: whether this is the first invocation of the pipeline. @@ -579,17 +609,18 @@ def __init__( self.first_step_ever = first_cycle & self.first_step self.last_step_ever = last_cycle & self.last_step - # Cyclic steps - self.prev_step = _mod(step - 1, self.num_steps) - self.next_step = _mod(step + 1, self.num_steps) - # Derived grid indices for present, previous, and next steps. - self.indices = _get_indices(step, grid, grid_offsets) - self.prev_indices = _get_indices( - self.prev_step, grid, grid_offsets + self.indices = tuple( + i + j for i, j in zip(indices, grid_offsets, strict=True) + ) + + self.prev_indices = tuple( + i + j + for i, j in zip(_prev_index(indices, grid), grid_offsets, strict=True) ) - self.next_indices = _get_indices( - self.next_step, grid, grid_offsets + self.next_indices = tuple( + i + j + for i, j in zip(_next_index(indices, grid), grid_offsets, strict=True) ) @contextmanager @@ -1100,10 +1131,11 @@ def pipeline( schedule = map_brefs( lambda _, x: get_pipeline_schedule(x), allocations, schedule) - def loop_body(step, _): + def loop_body(step, indices): nonlocal allocations scheduler = Scheduler( step, + indices, grid, grid_offsets=grid_offsets, first_cycle=first_cycle, @@ -1147,10 +1179,10 @@ def loop_body(step, _): lambda: None) map_brefs(scheduler.finalize, brefs, refs, schedule) - return () + return _next_index(indices, grid) # run pipeline - lax.fori_loop(0, num_steps, loop_body, ()) + lax.fori_loop(0, num_steps, loop_body, (0,) * len(grid)) return pipeline diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index aab214a2d700..7aab30ffc2ab 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -431,18 +431,34 @@ def __post_init__(self): def is_remote(self): return self.src_sem is not None + def _get_args_and_tree(self, swap_src_and_dst: bool = False): + if swap_src_and_dst: + return tree_util.tree_flatten(( + self.dst_ref, + self.dst_transforms, + self.src_ref, + self.src_transforms, + self.src_sem, + self.src_sem_transforms, + self.dst_sem, + self.dst_sem_transforms, + self.device_id, + )) + else: + return tree_util.tree_flatten(( + self.src_ref, + self.src_transforms, + self.dst_ref, + self.dst_transforms, + self.dst_sem, + self.dst_sem_transforms, + self.src_sem, + self.src_sem_transforms, + self.device_id, + )) + def start(self): - flat_args, tree = tree_util.tree_flatten(( - self.src_ref, - self.src_transforms, - self.dst_ref, - self.dst_transforms, - self.dst_sem, - self.dst_sem_transforms, - self.src_sem, - self.src_sem_transforms, - self.device_id, - )) + flat_args, tree = self._get_args_and_tree() dma_start_p.bind(*flat_args, tree=tree, device_id_type=self.device_id_type) def wait(self): @@ -451,27 +467,20 @@ def wait(self): self.wait_recv() def wait_recv(self): - wait_args, tree = tree_util.tree_flatten(( - self.dst_sem, - self.dst_sem_transforms, - self.dst_ref, - self.dst_transforms, - )) + flat_args, tree = self._get_args_and_tree() dma_wait_p.bind( - *wait_args, tree=tree, device_id_type=self.device_id_type + *flat_args, tree=tree, device_id_type=self.device_id_type ) def wait_send(self): if not self.is_remote: raise ValueError("Cannot `wait_send` on a local copy.") - wait_args, tree = tree_util.tree_flatten(( - self.src_sem, - self.src_sem_transforms, - self.src_ref, - self.src_transforms, - )) + # We swap src and dst since by default dma_wait_p waits on the dst_sem + # As a clean up, maybe we could modify the primitive to have a + # `wait_on_send` bool. + flat_args, tree = self._get_args_and_tree(swap_src_and_dst=True) dma_wait_p.bind( - *wait_args, tree=tree, device_id_type=self.device_id_type + *flat_args, tree=tree, device_id_type=self.device_id_type ) @@ -689,7 +698,17 @@ def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn, del settings invars = eqn.invars tree = eqn.params["tree"] - sem, sem_transforms, ref, transforms = tree_util.tree_unflatten(tree, invars) + ( + _, + _, + ref, + transforms, + sem, + sem_transforms, + _, + _, + _, + ) = tree_util.tree_unflatten(tree, invars) return pp.concat([ pp.text("dma_wait"), pp.text(" "), @@ -702,29 +721,38 @@ def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn, def dma_wait_discharge_rule(in_avals, out_avals, *args, tree, device_id_type): + # TODO(b/370563115): perform ref update in dma_wait discharge rule instead of dma_start del out_avals, device_id_type - (sem, sem_transforms, ref, ref_transforms) = tree_util.tree_unflatten( - tree, args - ) - ( - sem_aval, - sem_transforms_avals, + _, _, dst_ref, dst_ref_transforms, dst_sem, dst_sem_transforms, _, _, _ = ( + tree_util.tree_unflatten(tree, args)) + (_, + src_ref_transforms_avals, _, - ref_transforms_avals, + dst_ref_transforms_avals, + dst_sem_aval, + dst_sem_transforms_avals, + src_sem_aval, + src_sem_transforms_avals, + device_id_aval, ) = tree_util.tree_unflatten(tree, in_avals) - num_sem_transforms = len(tree_util.tree_leaves(sem_transforms_avals)) - num_transforms = len(tree_util.tree_leaves(ref_transforms_avals)) - updates = state_discharge.transform_array(ref, ref_transforms) + num_sem_transforms = len(tree_util.tree_leaves(dst_sem_transforms_avals)) + num_transforms = len(tree_util.tree_leaves(dst_ref_transforms_avals)) + updates = state_discharge.transform_array(dst_ref, dst_ref_transforms) copy_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE) copy_size = jnp.array(copy_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) - sem_value = _transform_semaphore(sem, sem_transforms, sem_aval) + sem_value = _transform_semaphore(dst_sem, dst_sem_transforms, dst_sem_aval) _, new_sem = state_discharge.transform_swap_array( - sem, sem_transforms, sem_value - copy_size + dst_sem, dst_sem_transforms, sem_value - copy_size ) - new_vals = (new_sem,) # sem - new_vals += (None,) * num_sem_transforms + new_vals = (None,) # src_ref + new_vals += (None,) * len(tree_util.tree_leaves(src_ref_transforms_avals)) new_vals += (None,) # ref - new_vals += (None,) * num_transforms + new_vals += (None,) * num_transforms # ref_transforms + new_vals += (new_sem,) # sem + new_vals += (None,) * num_sem_transforms + new_vals += (None,) * len(tree_util.tree_leaves(src_sem_aval)) # src_sem + new_vals += (None,) * len(tree_util.tree_leaves(src_sem_transforms_avals)) + new_vals += (None,) * len(tree_util.tree_leaves(device_id_aval)) # device_id return new_vals, [] state_discharge.register_discharge_rule(dma_wait_p)(dma_wait_discharge_rule) diff --git a/jax/_src/pallas/mosaic/random.py b/jax/_src/pallas/mosaic/random.py index 68a4fe508917..16dc5ee1fe56 100644 --- a/jax/_src/pallas/mosaic/random.py +++ b/jax/_src/pallas/mosaic/random.py @@ -46,6 +46,12 @@ def to_pallas_key(key: jax_prng.PRNGKeyArray) -> jax_prng.PRNGKeyArray: pallas_key_data = (jax.vmap(generate_key))(key) return jax_api_random.wrap_key_data(pallas_key_data, impl="pallas_tpu") + +def is_pallas_impl(impl: jax_prng.PRNGImpl) -> bool: + """Returns True if the PRNGImpl is a Pallas-specific implementation.""" + return impl == tpu_key_impl or impl == tpu_internal_stateful_impl + + def _seed_func(seed: jnp.int32): seed_data = jnp.zeros(tpu_key_impl.key_shape, dtype=jnp.int32) return (seed_data + seed).astype(jnp.uint32) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index fd291b201fa1..91616948be49 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -77,6 +77,7 @@ pytype_strict_library( "//jax:mosaic_gpu", "//jax:tree_util", "//jax/_src/pallas", + "//jaxlib/mlir:ir", ] + py_deps("numpy"), ) @@ -89,8 +90,9 @@ pytype_strict_library( "//jax", "//jax:core", "//jax:effects", - "//jax:mlir", "//jax:mosaic_gpu", + "//jax:tree_util", + "//jax:util", "//jax/_src/lib", "//jax/_src/pallas", ], diff --git a/jax/_src/pallas/mosaic_gpu/__init__.py b/jax/_src/pallas/mosaic_gpu/__init__.py index 187a84478c65..862a661e24b9 100644 --- a/jax/_src/pallas/mosaic_gpu/__init__.py +++ b/jax/_src/pallas/mosaic_gpu/__init__.py @@ -11,22 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# TODO(slebedev): Move these imports to ``jax.experimental.pallas``. - -from jax._src.pallas.mosaic_gpu.core import Barrier -from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec -from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams -from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace -from jax._src.pallas.mosaic_gpu.core import TilingTransform -from jax._src.pallas.mosaic_gpu.core import TransposeTransform -from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC -from jax._src.pallas.mosaic_gpu.primitives import async_copy_gmem_to_smem -from jax._src.pallas.mosaic_gpu.primitives import async_copy_smem_to_gmem -from jax._src.pallas.mosaic_gpu.primitives import wait_barrier -from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem -from jax._src.pallas.mosaic_gpu.primitives import wgmma -from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait - -GMEM = GPUMemorySpace.GMEM -SMEM = GPUMemorySpace.SMEM diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index fe8daf43e995..efd908407e52 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -14,17 +14,20 @@ """Contains GPU-specific Pallas abstractions.""" +import abc from collections.abc import Sequence import dataclasses import enum -from typing import Any, ClassVar, Literal, Protocol +from typing import Any, ClassVar, Literal from jax._src import core as jax_core from jax._src import dtypes from jax._src import tree_util from jax._src.pallas import core as pallas_core +from jax._src.state.types import Transform import jax.experimental.mosaic.gpu as mgpu import jax.numpy as jnp +from jaxlib.mlir import ir AbstractMemoryRef = pallas_core.AbstractMemoryRef @@ -43,16 +46,24 @@ class GPUCompilerParams(pallas_core.CompilerParams): executed sequentially. max_concurrent_steps: The maximum number of sequential stages that are active concurrently. Defaults to 1. + delay_release: The number of steps to wait before reusing the input/output + references. Defaults to 0, and must be strictly smaller than + max_concurrent_steps. Generally, you'll want to set it to 1 if you don't + await the WGMMA in the body. """ PLATFORM: ClassVar[str] = "mosaic_gpu" approx_math: bool = False dimension_semantics: Sequence[Literal["parallel", "sequential"]] | None = None max_concurrent_steps: int = 1 + delay_release: int = 0 class GPUMemorySpace(enum.Enum): + #: Global memory. GMEM = "gmem" + #: Shared memory. SMEM = "smem" + #: Registers. REGS = "regs" def __str__(self) -> str: @@ -63,10 +74,17 @@ def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype): return pallas_core.MemoryRef(shape, dtype, memory_space=self) -class MemoryRefTransform(pallas_core.MemoryRefTransform, Protocol): +class MemoryRefTransform(pallas_core.MemoryRefTransform, abc.ABC): + @abc.abstractmethod def to_gpu_transform(self) -> mgpu.MemRefTransform: - ... + pass + def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: + return aval.update( + shape=self.to_gpu_transform().transform_shape(aval.shape) + ) + +Index = slice | int | ir.Value @dataclasses.dataclass(frozen=True) class TilingTransform(MemoryRefTransform): @@ -79,61 +97,172 @@ class TilingTransform(MemoryRefTransform): tiling: tuple[int, ...] - def __call__( - self, block_aval: pallas_core.AbstractMemoryRef - ) -> pallas_core.AbstractMemoryRef: - block_shape = block_aval.shape - old_tiled_dims = block_shape[-len(self.tiling) :] - num_tiles = tuple( - block_dim // tiling_dim - for block_dim, tiling_dim in zip(old_tiled_dims, self.tiling) - ) - rem = ( - block_dim % tiling_dim - for block_dim, tiling_dim in zip(old_tiled_dims, self.tiling) - ) - if any(rem): - raise ValueError( - f"Block shape {block_shape} is not divisible by tiling {self.tiling}" - ) - new_block_shape = block_shape[: -len(self.tiling)] + num_tiles + self.tiling - return block_aval.update( - inner_aval=block_aval.inner_aval.update(shape=new_block_shape) + def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef: + return dataclasses.replace( + ref, transforms=(*ref.transforms, UntileRef(self.tiling)) ) def to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TileTransform(self.tiling) +@tree_util.register_pytree_node_class +@dataclasses.dataclass(frozen=True) +class UntileRef(Transform): + tiling: tuple[int, ...] + + def transform_shape(self, shape): + if shape is None: + return None + assert shape[-len(self.tiling) :] == self.tiling + shape = shape[: -len(self.tiling)] # Drop tiling + return shape[: -len(self.tiling)] + tuple( + block_dim * tiling_dim + for block_dim, tiling_dim in zip(shape[-len(self.tiling) :], self.tiling) + ) + + def transform_dtype(self, dtype): + return dtype + + def untransform_index(self, idxs: tuple[Index, ...]) -> tuple[Index, ...]: + if not all(isinstance(idx, slice) for idx in idxs): + raise NotImplementedError("Non-slice indices are not supported") + untiled_idxs = idxs[: -len(self.tiling)] + tiled_idxs = idxs[-len(self.tiling) :] + idxs_after_tiling = [] + for idx, tile in zip(tiled_idxs, self.tiling): + assert isinstance(idx, slice) + if idx.step is not None and idx.step != 1: + raise NotImplementedError("Strided slices unsupported") + if (idx.start is not None and idx.start % tile) or (idx.stop is not None and idx.stop % tile): + raise ValueError("Non-empty slices must be tile aligned") + idxs_after_tiling.append(slice(idx.start // tile, idx.stop // tile)) + return (*untiled_idxs, *idxs_after_tiling, *(slice(None) for _ in self.tiling)) + + def tree_flatten(self): + return (), (self.tiling,) + + @classmethod + def tree_unflatten(cls, metadata, arrays): + assert not arrays + return cls(*metadata) + + +def _perm_inverse(permutation: tuple[int, ...]) -> tuple[int, ...]: + inverse = [-1] * len(permutation) + for i, p in enumerate(permutation): + inverse[p] = i + return tuple(inverse) + + @dataclasses.dataclass(frozen=True) class TransposeTransform(MemoryRefTransform): """Transpose a tiled memref.""" - permutation: tuple[int, ...] - def __call__( - self, block_aval: pallas_core.AbstractMemoryRef - ) -> pallas_core.AbstractMemoryRef: - shape = block_aval.shape # pytype: disable=attribute-error - return block_aval.update( - inner_aval=block_aval.inner_aval.update( - shape=self.to_gpu_transform().transform_shape(shape) - ) + def __post_init__(self): + if set(self.permutation) != set(range(len(self.permutation))): + raise ValueError(f"Permutation {self.permutation} is not a permutation.") + + def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef: + return dataclasses.replace( + ref, + transforms=( + *ref.transforms, + TransposeRef(_perm_inverse(self.permutation)), + ), ) def to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TransposeTransform(self.permutation) +@tree_util.register_pytree_node_class +@dataclasses.dataclass(frozen=True) +class TransposeRef(Transform): + permutation: tuple[int, ...] + + def transform_shape(self, shape): + if shape is None: + return None + return tuple(shape[i] for i in self.permutation) + + def transform_dtype(self, dtype): + return dtype + + def untransform_index(self, idxs: tuple[Index, ...]) -> tuple[Index, ...]: + return tuple(idxs[i] for i in _perm_inverse(self.permutation)) + + def tree_flatten(self): + return (), (self.permutation,) + + @classmethod + def tree_unflatten(cls, metadata, arrays): + assert not arrays + return cls(*metadata) + + @dataclasses.dataclass(frozen=True) -class GPUBlockMapping(pallas_core.BlockMapping): - swizzle: int | None = None +class SwizzleTransform(MemoryRefTransform): + swizzle: int + + def __post_init__(self): + if self.swizzle not in {32, 64, 128}: + raise ValueError( + f"Swizzle {self.swizzle} is not supported. Only 32, 64 and 128 are" + " accepted." + ) + + def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef: + return dataclasses.replace( + ref, transforms=(*ref.transforms, UnswizzleRef(self.swizzle)) + ) + + def to_gpu_transform(self) -> mgpu.MemRefTransform: + raise RuntimeError("SwizzleTransform does not have a GPU transform.") + + def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: + swizzle_elems = self.swizzle // aval.dtype.itemsize + if swizzle_elems != aval.shape[-1]: + raise ValueError( + f"Swizzle {self.swizzle} requires the trailing dimension to be of" + f" size {swizzle_elems}, but got shape: {aval.shape}" + ) + return aval + + +@tree_util.register_pytree_node_class +@dataclasses.dataclass(frozen=True) +class UnswizzleRef(Transform): + swizzle: int + + def untransform_index(self, idxs: tuple[Index, ...]) -> tuple[Index, ...]: + if not idxs: + return idxs + if not all(isinstance(idx, slice) for idx in idxs): + raise NotImplementedError("Non-slice indices are not supported") + last_idx = idxs[-1] + assert isinstance(last_idx, slice) + if last_idx.step is not None and last_idx.step != 1: + raise NotImplementedError("Swizzled dims cannot be sliced") + if (last_idx.start is not None and last_idx.start != 0) or ( + last_idx.stop is not None and last_idx.stop != self.swizzle + ): + raise ValueError("Swizzled dims cannot be sliced") + return idxs + + def tree_flatten(self): + return (), (self.swizzle,) + + @classmethod + def tree_unflatten(cls, metadata, arrays): + assert not arrays + return cls(*metadata) @dataclasses.dataclass class GPUBlockSpec(pallas_core.BlockSpec): - transforms: MemoryRefTransform | tuple[MemoryRefTransform, ...] = () - swizzle: int | None = None # TODO: apaszke - Swizzle is also a transform. + transforms: Sequence[MemoryRefTransform] = () def to_block_mapping( self, @@ -144,7 +273,7 @@ def to_block_mapping( index_map_tree: tree_util.PyTreeDef, grid: pallas_core.GridMappingGrid, mapped_dims: tuple[int, ...], - ) -> GPUBlockMapping: + ) -> pallas_core.BlockMapping: bm = super().to_block_mapping( origin, array_aval, @@ -153,19 +282,14 @@ def to_block_mapping( grid=grid, mapped_dims=mapped_dims, ) - transforms = self.transforms - if not isinstance(transforms, tuple): - transforms = (transforms,) - return GPUBlockMapping( - block_shape=bm.block_shape, - block_aval=bm.block_aval, - origin=bm.origin, - index_map_jaxpr=bm.index_map_jaxpr, - index_map_src_info=bm.index_map_src_info, - indexing_mode=bm.indexing_mode, - array_shape_dtype=bm.array_shape_dtype, - transforms=transforms, - swizzle=self.swizzle, + block_inner_aval = bm.block_aval.inner_aval + for t in self.transforms: + block_inner_aval = t(block_inner_aval) + return bm.replace( + transformed_block_aval=bm.block_aval.update( + inner_aval=block_inner_aval + ), + transforms=self.transforms, ) @@ -235,10 +359,14 @@ def at_least_vspace(self): return _as_accum(super().at_least_vspace()) def _getitem(self, tracer, idx): - if not _is_trivial_index(idx): - raise NotImplementedError(f"Can only dereference accumulators, not slice ({idx=}).") from jax._src.pallas.mosaic_gpu.primitives import wgmma_accumulator_deref # pytype: disable=import-error - return wgmma_accumulator_deref(tracer) + arr = wgmma_accumulator_deref(tracer) + + if not _is_trivial_index(idx): + arr = arr[idx] + + return arr + def _as_accum(ref) -> WGMMAAbstractAccumulatorRef: return WGMMAAbstractAccumulatorRef( diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 0d0ac41d11e3..3a5c26e63feb 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -16,12 +16,13 @@ from __future__ import annotations -from collections.abc import Sequence +import collections +from collections.abc import MutableMapping, MutableSequence, Sequence import dataclasses import functools import itertools as it import math -from typing import Any, cast +from typing import Any, Protocol, cast import jax from jax import lax @@ -40,6 +41,7 @@ from jax._src.pallas import utils as pallas_utils from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.state import discharge +from jax._src.state import indexing from jax._src.state import primitives as sp import jax.experimental.mosaic.gpu as mgpu from jax.experimental.mosaic.gpu import core as mgpu_core @@ -58,53 +60,129 @@ partial = functools.partial SMEM = gpu_core.SMEM -_smem_estimators = {} + +@dataclasses.dataclass(kw_only=True, frozen=True) +class Resources: + smem_scratch_bytes: int + barrier_counts: collections.Counter[mgpu.Barrier] = dataclasses.field( + default_factory=collections.Counter + ) + + @property + def barriers(self) -> Sequence[mgpu.Barrier]: + return list(self.barrier_counts.elements()) + + def __add__(self, other: Resources) -> Resources: + # TODO(slebedev): Optimize this. + # + # At the moment, if we have run_scoped(b1) followed by run_scoped(b2) + # we will allocate two barriers, even though one would be enough. + return Resources( + smem_scratch_bytes=self.smem_scratch_bytes + other.smem_scratch_bytes, + barrier_counts=self.barrier_counts + other.barrier_counts, + ) + + def __or__(self, other: Resources) -> Resources: + return Resources( + smem_scratch_bytes=max( + self.smem_scratch_bytes, other.smem_scratch_bytes + ), + barrier_counts=self.barrier_counts | other.barrier_counts, + ) + + +class ResourceEstimator(Protocol): + + def __call__(self, *args: Any, **params: Any) -> Resources: + ... + + +_resource_estimators: dict[jax_core.Primitive, ResourceEstimator] = {} -def _regiter_smem_estimator(primitive: jax_core.Primitive): +def _register_resource_estimator(primitive: jax_core.Primitive): def deco(fn): - _smem_estimators[primitive] = fn + _resource_estimators[primitive] = fn return fn return deco -def _estimate_smem_scratch_bytes(jaxpr: jax_core.Jaxpr) -> int: - """Estimates the amount of SMEM scratch bytes required by the kernel.""" - max_used = 0 +def _estimate_resources(jaxpr: jax_core.Jaxpr) -> Resources: + """Estimates the resources required by the kernel.""" + rs = Resources(smem_scratch_bytes=0) for eqn in jaxpr.eqns: # TODO(slebedev): Add support for other primitives, notably control flow. - rule = _smem_estimators.get(eqn.primitive) + rule = _resource_estimators.get(eqn.primitive) if rule is None: - # Assume that unsupported primitives are neutral wrt SMEM usage. + # Assume that unsupported primitives are neutral wrt resource usage. continue - max_used = max( - max_used, rule(*(invar.aval for invar in eqn.invars), **eqn.params) - ) - return max_used + rs |= rule(*(invar.aval for invar in eqn.invars), **eqn.params) + return rs + + +@_register_resource_estimator(lax.cond_p) +def _cond_resource_estimator(*args, branches) -> int: + del args # Unused. + return functools.reduce( + lambda a, b: a | b, + (_estimate_resources(branch.jaxpr) for branch in branches), + ) + + +@_register_resource_estimator(lax.scan_p) +def _scan_resource_estimator(*args, jaxpr: jax_core.ClosedJaxpr, **params) -> int: + del args, params # Unused. + return _estimate_resources(jaxpr) -@_regiter_smem_estimator(primitives.run_scoped_p) -def _run_scoped_smem_estimator(*consts, jaxpr: jax_core.Jaxpr) -> int: +@_register_resource_estimator(primitives.run_scoped_p) +def _run_scoped_resource_estimator(*consts, jaxpr: jax_core.Jaxpr) -> int: del consts # Unused. - in_avals = (v.aval.inner_aval for v in jaxpr.invars) - return sum(math.prod(aval.shape) * aval.dtype.itemsize for aval in in_avals) + smem_scratch_bytes = 0 + barriers = [] + for v in jaxpr.invars: + aval = v.aval + if isinstance(aval.dtype, gpu_core.BarrierType): + barriers.append(mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape)) + else: + smem_scratch_bytes += math.prod(aval.shape) * aval.dtype.itemsize + rs = Resources( + smem_scratch_bytes=smem_scratch_bytes, + barrier_counts=collections.Counter(barriers), + ) + return rs + _estimate_resources(jaxpr) -@_regiter_smem_estimator(lax.reduce_sum_p) -def _reduce_sum_smem_estimator(x_aval: jax_core.ShapedArray, *, axes) -> int: +@_register_resource_estimator(lax.reduce_sum_p) +def _reduce_sum_resource_estimator(x_aval: jax_core.ShapedArray, *, axes) -> int: if axes != (0,): raise NotImplementedError("No support for axes other than 0 yet") - return 4 * x_aval.dtype.itemsize + return Resources(smem_scratch_bytes=4 * x_aval.dtype.itemsize) @dataclasses.dataclass class ModuleContext: name: str grid_mapping: pallas_core.GridMapping + program_ids: Sequence[ir.Value] | None approx_math: bool runtime_smem: ir.Value # ir.MemRefType - smem_used_bytes: int = 0 + smem_used_bytes: int + runtime_barriers: MutableMapping[ + mgpu.Barrier, MutableSequence[mgpu.BarrierRef] + ] + + def reserve_barrier(self, barrier: mgpu.Barrier) -> mgpu.BarrierRef: + """Reserves a barrier. + + Raises: + RuntimeError: If the barrier is already reserved. + """ + available = self.runtime_barriers.get(barrier, []) + if not available: + raise RuntimeError(f"Barrier {barrier} is already reserved") + return available.pop() # TODO(cperivol): Only return the shapes and figure out the sizes when freeing. def scratch_view( @@ -206,6 +284,41 @@ def _uses_arguments(cjaxpr: jax_core.ClosedJaxpr) -> list[bool]: return pe.dce_jaxpr(jaxpr, used_outputs=[True] * len(jaxpr.outvars))[1] +def _check_block_mappings( + block_mappings: Sequence[pallas_core.BlockMapping], + name_and_src_info: pallas_core.NameAndSrcInfo, +) -> None: + def err_details(bm: pallas_core.BlockMapping) -> str: + return ( + f"Block spec for {bm.origin} in pallas_call {name_and_src_info}" + f" has block shape {bm.block_shape}, array shape" + f" {bm.array_shape_dtype.shape}," + # TODO(necula): add index_map source location info + f" and index_map {bm.index_map_jaxpr.jaxpr} in" + f" memory space {bm.transformed_block_aval.memory_space}." + " See details at" + " https://jax.readthedocs.io/en/latest/pallas/grid_blockspec.html#pallas-blockspec." + ) + + for bm in block_mappings: + if ( + bm.transformed_block_aval.memory_space == gpu_core.GMEM + and not bm.has_trivial_window() + ): + raise NotImplementedError( + "Mosaic GPU lowering currently requires blocks in GMEM memory space " + "to have same block shape as the array shape " + "and a trivial index_map (returning all 0s).\n\n" + + err_details(bm) + ) + + if not isinstance(bm.indexing_mode, pallas_core.Blocked): + raise NotImplementedError( + "Only Blocked indexing mode is supported in Mosaic GPU lowering.\n\n" + + err_details(bm) + ) + + def lower_jaxpr_to_module( grid_mapping: pallas_core.GridMapping, jaxpr: jax_core.Jaxpr, @@ -215,8 +328,6 @@ def lower_jaxpr_to_module( ) -> LoweringResult: del cost_estimate # Unused. - block_mappings = grid_mapping.block_mappings - assert len(jaxpr.outvars) == 0 assert not grid_mapping.vmapped_dims if len(grid_mapping.grid) > 3: @@ -231,22 +342,15 @@ def lower_jaxpr_to_module( raise NotImplementedError( "Scalar prefetch not supported in Mosaic GPU lowering." ) - if not all( - isinstance(bm.indexing_mode, pallas_core.Blocked) for bm in block_mappings - ): - raise NotImplementedError( - "Only Blocked indexing mode is supported in Mosaic GPU lowering." - ) - with grid_mapping.trace_env(): - jaxpr, _ = pe.dce_jaxpr( - jaxpr, [True] * len(jaxpr.outvars), instantiate=True - ) + block_mappings = grid_mapping.block_mappings + _check_block_mappings(block_mappings, name_and_src_info) block = (128, 1, 1) params = compiler_params.get("mosaic_gpu", {}) approx_math = params.get("approx_math", False) max_concurrent_steps = params.get("max_concurrent_steps", 1) + delay_release = params.get("delay_release", 0) dimension_semantics = params.get("dimension_semantics") if dimension_semantics is None: dimension_semantics = ["parallel"] * len(grid_mapping.grid) @@ -258,6 +362,11 @@ def lower_jaxpr_to_module( sequential_axes = tuple( i for i, s in enumerate(dimension_semantics) if s == "sequential" ) + if max_concurrent_steps <= delay_release: + raise ValueError( + "max_concurrent_steps must be greater than delay_release, but" + f" {max_concurrent_steps=}, {delay_release=}" + ) grid = [d for i, d in enumerate(grid_mapping.grid) if i not in sequential_axes] if len(grid) < 3: @@ -266,7 +375,6 @@ def lower_jaxpr_to_module( raise NotImplementedError( "Only <=3D grids are supported in Mosaic GPU lowering." ) - # Compute the number of steps along each sequential axis. if sequential_axes: # TODO(slebedev): Support multiple sequential axes. if len(sequential_axes) > 1: @@ -283,21 +391,30 @@ def lower_jaxpr_to_module( num_steps = 1 out_sequential_invariant = [True] * len(grid_mapping.out_shapes) + # Shrink ``max_concurrent_steps`` if the total number of steps is lower to + # reduce the size of the allocated buffers below. + if max_concurrent_steps > num_steps: + max_concurrent_steps = num_steps + delay_release = 0 # No need to delay anything + in_in_smem, out_in_smem = util.split_list( [ - bm.block_aval.memory_space in (None, gpu_core.SMEM) + bm.transformed_block_aval.memory_space in (None, gpu_core.SMEM) for bm in block_mappings ], [grid_mapping.num_inputs], ) - in_structs_gmem = [*grid_mapping.in_shapes] in_block_mappings, out_block_mappings = util.split_list( block_mappings, [grid_mapping.num_inputs] ) + in_structs_gmem = [*grid_mapping.in_shapes] + # We allocate the fully transformed shapes here. All primitives have seen the + # inverse transformation stack and will understand how to handle it. in_structs_smem = [ jax.ShapeDtypeStruct( - [max_concurrent_steps, *bm.ref_aval.shape], bm.ref_aval.dtype + [max_concurrent_steps, *bm.transformed_block_aval.shape], + bm.transformed_block_aval.dtype, ) if in_smem else None @@ -307,18 +424,13 @@ def lower_jaxpr_to_module( ] in_gmem_transforms = [ cast(gpu_core.MemoryRefTransform, bm.transforms) - for bm in grid_mapping.block_mappings[: grid_mapping.num_inputs] + for bm in in_block_mappings ] - in_swizzles = map( - lambda bm: bm.swizzle - if isinstance(bm, gpu_core.GPUBlockMapping) - else None, - grid_mapping.block_mappings[: grid_mapping.num_inputs], - ) out_structs_gmem = [*grid_mapping.out_shapes] - # TODO(justinfu): Implement output Memref transforms out_structs_smem = [ - jax.ShapeDtypeStruct([max_concurrent_steps, *bm.block_shape], s.dtype) + jax.ShapeDtypeStruct( + [max_concurrent_steps, *bm.transformed_block_aval.shape], s.dtype + ) if in_smem else None for bm, in_smem, s in zip( @@ -327,6 +439,10 @@ def lower_jaxpr_to_module( grid_mapping.out_shapes, ) ] + out_gmem_transforms = [ + cast(gpu_core.MemoryRefTransform, bm.transforms) + for bm in out_block_mappings + ] def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): *buffers_gmem, ( @@ -342,16 +458,34 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): in_buffers_smem, out_buffers_smem = util.split_list( buffers_smem, [grid_mapping.num_inputs] ) - barriers, *extra_barriers = barriers + barriers, runtime_barriers, extra_barriers = barriers parallel_count = it.count() program_ids_template = [ - _program_id(next(parallel_count)) if i not in sequential_axes else None - for i in range(len(grid_mapping.grid)) + _program_id(next(parallel_count)) + if axis not in sequential_axes + else None + for axis in range(len(grid_mapping.grid)) ] + + def make_program_ids(step: ir.Value): + assert ir.IndexType.isinstance(step.type) + step = arith_dialect.index_cast(ir.IntegerType.get_signless(32), step) + return [step if pid is None else pid for pid in program_ids_template] + + grouped_barriers = collections.defaultdict(list) + for barrier, barrier_ref in zip(rs.barriers, runtime_barriers): + grouped_barriers[barrier].append(barrier_ref) module_ctx = ModuleContext( - name_and_src_info.name, grid_mapping, approx_math, runtime_smem + name_and_src_info.name, + grid_mapping, + None, + approx_math, + runtime_smem, + smem_used_bytes=0, + runtime_barriers=grouped_barriers, ) + del runtime_smem, grouped_barriers, runtime_barriers smem_scratch_it = iter(scratch_buffers_smem) scratch_buffers_template = [] @@ -412,7 +546,7 @@ def gmem_slice( block_mapping: pallas_core.BlockMapping, ) -> Sequence[mgpu.DynamicSlice]: assert len(sequential_axes) <= 1 - program_ids = [step if i is None else i for i in program_ids_template] + program_ids = make_program_ids(step) idxs = _eval_index_map(module_ctx, launch_ctx, program_ids, block_mapping) return tuple( mgpu.ds(idx, dim) for idx, dim in zip(idxs, block_mapping.block_shape) @@ -424,15 +558,21 @@ def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None: if not in_in_smem[idx]: return - # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls. - gmem_transforms = (x.to_gpu_transform() for x in in_gmem_transforms[idx]) + swizzle = None + pl_transforms = in_gmem_transforms[idx] + if pl_transforms and isinstance( + pl_transforms[-1], gpu_core.SwizzleTransform + ): + swizzle = pl_transforms[-1].swizzle + pl_transforms = pl_transforms[:-1] + gmem_transforms = tuple(x.to_gpu_transform() for x in pl_transforms) launch_ctx.async_copy( src_ref=in_buffers_gmem[idx], dst_ref=mgpu.memref_slice(in_buffers_smem[idx], slot), gmem_slice=gmem_slice(step, in_block_mappings[idx]), barrier=barriers[slot], - gmem_transform=tuple(gmem_transforms), - swizzle=in_swizzles[idx], + gmem_transform=gmem_transforms, + swizzle=swizzle, arrive=False, # The caller must do ``arrive_expect_tx`` manually! uniform=False, predicate=is_memory_thread, @@ -454,9 +594,6 @@ def store( # We have to do some work to make sure that consecutive stores are not # going to be writing to the same location, or else we'll end up with # multiple concurrent writes and a racy program. - # TODO(apaszke,slebedev): In most cases output index maps depend only on - # parallel grid axes and in that case we can simply move the store to - # happen after the loop. # TODO(apaszke,slebedev): This still diverges significantly from the TPU # semantics in that it will move on to the next SMEM output slice even if # it's not storing the previous one. @@ -475,12 +612,21 @@ def store( do_store = arith_dialect.andi( is_memory_thread, arith_dialect.ori(base_offset_changed, is_last_step) ) - # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls. + + swizzle = None + pl_transforms = out_gmem_transforms[idx] + if pl_transforms and isinstance( + pl_transforms[-1], gpu_core.SwizzleTransform + ): + swizzle = pl_transforms[-1].swizzle + pl_transforms = pl_transforms[:-1] + gmem_transforms = tuple(x.to_gpu_transform() for x in pl_transforms) launch_ctx.async_copy( src_ref=mgpu.memref_slice(out_buffers_smem[idx], slot), dst_ref=out_buffers_gmem[idx], gmem_slice=store_slice, - swizzle=None, + gmem_transform=gmem_transforms, + swizzle=swizzle, uniform=False, predicate=do_store, ) @@ -492,6 +638,7 @@ def store( fetch(idx, _as_index(slot), _as_index(slot)) last_store_offsets = [None if inv else _as_index(-1) for inv in out_sequential_invariant] + @mgpu.fori(_as_index(num_steps), (accs, last_store_offsets)) def _(step, carry): accs, last_store_offsets = carry @@ -501,7 +648,9 @@ def _(step, carry): barriers[slot].wait() # We need to make sure the output copy is complete before the kernel starts # writing to the output window. - launch_ctx.await_async_copy(max_concurrent_steps - 1, await_read_only=True) + launch_ctx.await_async_copy( + max_concurrent_steps - (1 + delay_release), await_read_only=True + ) args = [ mgpu.memref_slice(buffers_smem[idx], slot) @@ -519,11 +668,14 @@ def _(step, carry): # but that's not necessarily true. args.extend(extra_barriers) new_accs = lower_jaxpr_to_mosaic_gpu( - module_ctx, launch_ctx, lowered_jaxpr, args + dataclasses.replace(module_ctx, program_ids=make_program_ids(step)), + launch_ctx, + lowered_jaxpr, + args, ) - # TODO(apaszke): Elide this if we're not going to perform any stores - mgpu.commit_shared() + if not all(out_sequential_invariant): + mgpu.commit_shared() new_store_offsets = [] for idx in range(grid_mapping.num_outputs): last_offset = last_store_offsets[idx] @@ -533,20 +685,28 @@ def _(step, carry): else last_offset # Only store if the output can depend on the step. ) - next_step = arith_dialect.addi(step, _as_index(max_concurrent_steps)) - next_step_in_bounds = arith_dialect.cmpi( - arith_dialect.CmpIPredicate.ult, next_step, _as_index(num_steps) + del slot # Just to make sure we don't accidentally use it. + fetch_step = arith_dialect.addi( + step, _as_index(max_concurrent_steps - delay_release) + ) + fetch_step_in_bounds = arith_dialect.cmpi( + arith_dialect.CmpIPredicate.ult, fetch_step, _as_index(num_steps) ) - next_slot = slot # (x + y) % y == x % y - with mgpu.when(next_step_in_bounds): - barriers[slot].arrive_expect_tx(in_transfer_bytes, predicate=is_memory_thread) + not_initial_step = arith_dialect.cmpi( + arith_dialect.CmpIPredicate.uge, step, _as_index(delay_release) + ) + fetch_slot = arith_dialect.remui(fetch_step, _as_index(max_concurrent_steps)) + with mgpu.when(arith_dialect.andi(fetch_step_in_bounds, not_initial_step)): + barriers[fetch_slot].arrive_expect_tx(in_transfer_bytes, predicate=is_memory_thread) for idx in range(grid_mapping.num_inputs): - fetch(idx, next_step, next_slot) + fetch(idx, fetch_step, fetch_slot) return list(new_accs), new_store_offsets # Outputs invariant to the sequential axis are never written from inside the # loop. This is the only place where we store them. + if all(out_sequential_invariant): + mgpu.commit_shared() last_slot = _as_index((num_steps - 1) % max_concurrent_steps) for idx in range(grid_mapping.num_outputs): if out_sequential_invariant[idx]: @@ -567,6 +727,7 @@ def _(step, carry): "All scratch operands must be SMEM references or accumulators (ACC)," f" but got: {scratch_avals}" ) + rs = _estimate_resources(jaxpr) extra_barriers = [ mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape) for aval in scratch_avals @@ -580,7 +741,7 @@ def _(step, carry): ] smem_scratch_bytes = compiler_params.get("smem_scratch_bytes") if smem_scratch_bytes is None: - smem_scratch_bytes = _estimate_smem_scratch_bytes(jaxpr) + smem_scratch_bytes = rs.smem_scratch_bytes extra_smem_scratch.append( jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8) ) @@ -597,7 +758,8 @@ def _(step, carry): *extra_smem_scratch, ( mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps), - *extra_barriers, + rs.barriers, + extra_barriers, ), ), module_name=name_and_src_info.name, @@ -668,9 +830,9 @@ def write_env(var: jax_core.Var, val): @register_lowering_rule(primitives.program_id_p) def _program_id_lowering_rule(ctx: LoweringRuleContext, axis): - # TODO(apaszke): Sequential axis should be handled specially!! - del ctx # Unused. - return _program_id(axis) + if ctx.module_ctx.program_ids is None: + raise NotImplementedError("pl.program_id() is not supported in this context") + return ctx.module_ctx.program_ids[axis] def _program_id(axis: int) -> ir.Value: @@ -689,34 +851,100 @@ def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis): ) +def _handle_indexing( + ref: ir.Value, transforms: Sequence[gpu_core.Transform] +) -> ir.Value: + if not transforms: + pass + if not any(isinstance(t, indexing.NDIndexer) for t in transforms): + return ref + if any( + isinstance(t, indexing.NDIndexer) for t in transforms[:-1] + ) or not isinstance(transforms[-1], indexing.NDIndexer): + raise NotImplementedError("Only one level of indexing supported.") + + indexer = cast(indexing.NDIndexer, transforms[-1]) + if indexer.int_indexer_shape: + raise NotImplementedError("int_indexer_shape non-empty") + indices = _ndindexer_indices(indexer) + for t in reversed(transforms[:-1]): + indices = t.untransform_index(indices) + return mgpu.memref_slice(ref, indices) + + +def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ...]: + indices = [] + for idx in indexer.indices: + if not isinstance(idx, indexing.Slice): + indices.append(_as_index(idx)) + elif not idx.is_dynamic_start and not idx.is_dynamic_size: + indices.append(slice(idx.start, idx.start + idx.size, idx.stride)) + elif idx.stride == 1: + indices.append( + mgpu.DynamicSlice( + _as_index(idx.start) if idx.is_dynamic_start else idx.start, + _as_index(idx.size) if idx.is_dynamic_size else idx.size, + ) + ) + else: + raise NotImplementedError(f"Unsupported slice: {idx}") + return tuple(indices) + + +def _is_swizzled(transforms: tuple[gpu_core.Transform, ...]) -> int | None: + if not transforms: + return None + if any(isinstance(t, gpu_core.UnswizzleRef) for t in transforms[1:]): + raise NotImplementedError( + "Swizzling must be the last transform applied to a ref" + ) + if isinstance(t := transforms[0], gpu_core.UnswizzleRef): + return t.swizzle + return None + + @register_lowering_rule(sp.get_p) -def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *indexers, tree): - del tree # Unused. - if indexers: - raise NotImplementedError("No support for indexers yet") - [x_aval] = ctx.avals_in - return mgpu.FragmentedArray.load_strided( - x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) - ) +def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *leaves, tree): + if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem): + raise TypeError(f"Can only load from references (got {x_smem}).") + x_aval = ctx.avals_in[0] + transform = jax.tree.unflatten(tree, leaves) + swizzle = _is_swizzled(transform) + x_smem = _handle_indexing(x_smem, transform) + if swizzle is None: + return mgpu.FragmentedArray.load_strided( + x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) + ) + else: + return mgpu.FragmentedArray.load_tiled( + x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle + ) @register_lowering_rule(sp.swap_p) def _swap_lowering_rule( - ctx: LoweringRuleContext, x_smem, value, *indexers, tree + ctx: LoweringRuleContext, x_smem, value, *leaves, tree ): - del tree # Unused. - if indexers: - raise NotImplementedError("No support for indexers yet") if not isinstance(value, mgpu.FragmentedArray): raise TypeError(f"Can only store arrays (got {value}).") if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem): - raise TypeError(f"Can only store to references (got {value}).") - x_aval, _ = ctx.avals_in - old_value = mgpu.FragmentedArray.load_strided( - x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) - ) - value.store_untiled(x_smem) - return old_value + raise TypeError(f"Can only store to references (got {x_smem}).") + transforms = jax.tree.unflatten(tree, leaves) + swizzle = _is_swizzled(transforms) + x_smem = _handle_indexing(x_smem, transforms) + x_aval = ctx.avals_in[0] + if swizzle is None: + old_value = mgpu.FragmentedArray.load_strided( + x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) + ) + value.store_untiled(x_smem) + return old_value + else: + old_value = mgpu.FragmentedArray.load_tiled( + x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle + ) + value.store_tiled(x_smem, swizzle=swizzle) + return old_value @register_lowering_rule(pjit.pjit_p) @@ -728,6 +956,16 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_): ) +@register_lowering_rule(lax.slice_p) +def _slice_lowering_rule( + ctx: LoweringRuleContext, x, limit_indices, start_indices, strides +): + if strides is not None: + raise NotImplementedError("Strides are not supported.") + + return x[tuple(slice(b, e) for b, e in zip(start_indices, limit_indices))] + + @register_lowering_rule(lax.select_n_p) def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): if len(cases) != 2: @@ -806,6 +1044,12 @@ def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math) +@register_lowering_rule(lax.logistic_p) +def _logistic_lowering_rule(ctx: LoweringRuleContext, x): + [x_aval] = ctx.avals_in + a = _ensure_fa(x, x_aval.dtype) + return 1. / (1. + (-a).exp(approx=ctx.module_ctx.approx_math)) + @register_lowering_rule(lax.reduce_sum_p) def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): @@ -860,21 +1104,28 @@ def _run_scoped_lowering_rule( input_refs = [] bytes_allocated = 0 should_discharge = [] - for a in jaxpr.invars: - a = a.aval - if isinstance(a, gpu_core.WGMMAAbstractAccumulatorRef): - mlir_dtype = mlir.dtype_to_ir_type(a.dtype) - input_refs.append(mgpu.WGMMAAccumulator.zero(*a.shape, mlir_dtype)) + for v in jaxpr.invars: + aval = v.aval + if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): + mlir_dtype = mlir.dtype_to_ir_type(aval.dtype) + input_refs.append(mgpu.WGMMAAccumulator.zero(*aval.shape, mlir_dtype)) should_discharge.append(True) - elif a.memory_space == gpu_core.SMEM: + elif isinstance(aval.dtype, gpu_core.BarrierType): + input_refs.append( + ctx.module_ctx.reserve_barrier( + mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape) + ) + ) + should_discharge.append(False) + elif aval.memory_space == gpu_core.SMEM: ref_bytes, [input_ref] = ctx.module_ctx.scratch_view( - [jax.ShapeDtypeStruct(shape=a.shape, dtype=a.dtype)] + [jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype)] ) bytes_allocated += ref_bytes input_refs.append(input_ref) should_discharge.append(False) else: - raise ValueError(f"Can't convert to ref: {a}") + raise ValueError(f"Can't convert to ref: {aval}") if any(should_discharge): # We convert consts to args, because we only have ir.Values and @@ -1086,9 +1337,15 @@ def _i64_constant(v: int) -> ir.Value: return arith_dialect.constant(ir.IntegerType.get_signless(64), v) -def _as_index(v: int | ir.Value) -> ir.Value: - if isinstance(v, int): - return arith_dialect.constant(ir.IndexType.get(), v) - if ir.IndexType.isinstance(v.type): - return v - return arith_dialect.index_cast(ir.IndexType.get(), v) +def _as_index(v: object) -> ir.Value: + match v: + case int(): + return arith_dialect.constant(ir.IndexType.get(), v) + case ir.Value() if ir.IndexType.isinstance(v.type): + return v + case ir.Value() if ir.IntegerType.isinstance(v.type): + return arith_dialect.index_cast(ir.IndexType.get(), v) + case mgpu.FragmentedArray(layout=mgpu.WGSplatFragLayout()): + return _as_index(v.registers.item()) + case _: + raise ValueError(f"Unsupported index: {v}") diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index dcec631e389b..dd689427b304 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -19,93 +19,288 @@ from jax._src import core as jax_core from jax._src import effects from jax._src import state -from jax._src.state import discharge +from jax._src import tree_util +from jax._src import util from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.pallas.mosaic_gpu import lowering +from jax._src.state import discharge +from jax._src.state import indexing +from jax._src.state import primitives as state_primitives import jax.experimental.mosaic.gpu as mgpu -async_copy_p = jax_core.Primitive("async_copy") -async_copy_p.multiple_results = True + +copy_smem_to_gmem_p = jax_core.Primitive("copy_smem_to_gmem") +copy_smem_to_gmem_p.multiple_results = True -@async_copy_p.def_effectful_abstract_eval -def _async_copy_abstract_eval(*avals): - del avals # Unused. +@copy_smem_to_gmem_p.def_effectful_abstract_eval +def _copy_smem_to_gmem_abstract_eval(*avals, **params): + del avals, params # Unused. return (), {state.ReadEffect(0), state.WriteEffect(1)} -@lowering.register_lowering_rule(async_copy_p) -def _async_copy_lowering_rule( - ctx: lowering.LoweringRuleContext, src, dst, barrier=None +@lowering.register_lowering_rule(copy_smem_to_gmem_p) +def _copy_smem_to_gmem_lowering( + ctx: lowering.LoweringRuleContext, + src, + dst, + *flat_transforms, + src_transforms_treedef, + dst_transforms_treedef, ): - ctx.launch_ctx.async_copy(src_ref=src, dst_ref=dst, barrier=barrier) + flat_src_transforms, flat_dst_transforms = util.split_list( + flat_transforms, + [src_transforms_treedef.num_leaves], + ) + src = lowering._handle_indexing( + src, src_transforms_treedef.unflatten(flat_src_transforms) + ) + copy_params = _extract_copy_params( + dst_transforms_treedef.unflatten(flat_dst_transforms) + ) + mgpu.commit_shared() + ctx.launch_ctx.async_copy(src_ref=src, dst_ref=dst, **copy_params) return () -def async_copy_smem_to_gmem( +def _extract_copy_params(transforms): + if not transforms: + return {} + if any( + isinstance(t, indexing.NDIndexer) for t in transforms[:-1] + ) or not isinstance(transforms[-1], indexing.NDIndexer): + raise NotImplementedError("Only one level of indexing supported") + *transforms, indexer = transforms + swizzle = lowering._is_swizzled(transforms) + if swizzle is not None: + transforms = transforms[1:] + gpu_transforms = [t.to_gpu_transform() for t in transforms] + return dict( + gmem_slice=lowering._ndindexer_indices(indexer), + gmem_transform=tuple(gpu_transforms), + swizzle=swizzle, + ) + + +def copy_smem_to_gmem( src: pallas_core.AbstractMemoryRef, dst: pallas_core.AbstractMemoryRef ) -> None: + """Asynchronously copies a SMEM reference to a GMEM reference. + + See also: + :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem` + """ if src.memory_space is not gpu_core.SMEM: raise TypeError(f"src must be a SMEM reference, got {src.memory_space}") if dst.memory_space is not gpu_core.GMEM: raise ValueError(f"dst must be a GMEM reference, got {dst.memory_space}") - async_copy_p.bind(src, dst) + src, src_transforms = state_primitives.get_ref_and_transforms( + src, None, "copy_smem_to_gmem" + ) + dst, dst_transforms = state_primitives.get_ref_and_transforms( + dst, None, "copy_smem_to_gmem" + ) + flat_src_transforms, src_transforms_treedef = tree_util.tree_flatten( + src_transforms + ) + flat_dst_transforms, dst_transforms_treedef = tree_util.tree_flatten( + dst_transforms + ) + copy_smem_to_gmem_p.bind( + src, + dst, + *flat_src_transforms, + *flat_dst_transforms, + src_transforms_treedef=src_transforms_treedef, + dst_transforms_treedef=dst_transforms_treedef, + ) return None -def async_copy_gmem_to_smem( +copy_gmem_to_smem_p = jax_core.Primitive("copy_gmem_to_smem") +copy_gmem_to_smem_p.multiple_results = True + + +@copy_gmem_to_smem_p.def_effectful_abstract_eval +def _copy_gmem_to_smem_abstract_eval(*avals, **params): + del avals, params # Unused. + return (), {state.ReadEffect(0), state.WriteEffect(1)} + + +@lowering.register_lowering_rule(copy_gmem_to_smem_p) +def _copy_gmem_to_smem_lowering( + ctx: lowering.LoweringRuleContext, + src, + dst, + barrier, + *flat_transforms, + src_transforms_treedef, + dst_transforms_treedef, + barrier_transforms_treedef, +): + flat_src_transforms, flat_dst_transforms, flat_barrier_transforms = ( + util.split_list( + flat_transforms, + [ + src_transforms_treedef.num_leaves, + dst_transforms_treedef.num_leaves, + ], + ) + ) + copy_params = _extract_copy_params( + src_transforms_treedef.unflatten(flat_src_transforms) + ) + dst = lowering._handle_indexing( + dst, dst_transforms_treedef.unflatten(flat_dst_transforms) + ) + barrier_indexer = _extract_barrier_indexer( + barrier_transforms_treedef.unflatten(flat_barrier_transforms) + ) + if barrier_indexer is not None: + barrier = barrier.__getitem__( + *map(lowering._as_index, barrier_indexer.indices) + ) + ctx.launch_ctx.async_copy( + src_ref=src, dst_ref=dst, barrier=barrier, **copy_params + ) + return () + + +def copy_gmem_to_smem( src: pallas_core.AbstractMemoryRef, dst: pallas_core.AbstractMemoryRef, *, barrier: pallas_core.AbstractMemoryRef, ) -> None: + """Asynchronously copies a GMEM reference to a SMEM reference. + + See also: + :func:`jax.experimental.mosaic.gpu.wait_barrier` + """ if src.memory_space is not gpu_core.GMEM: raise TypeError(f"src must be a GMEM reference, got {src.memory_space}") if dst.memory_space is not gpu_core.SMEM: raise ValueError(f"dst must be a SMEM reference, got {dst.memory_space}") - async_copy_p.bind(src, dst, barrier) + src, src_transforms = state_primitives.get_ref_and_transforms( + src, None, "copy_gmem_to_smem" + ) + dst, dst_transforms = state_primitives.get_ref_and_transforms( + dst, None, "copy_gmem_to_smem" + ) + flat_src_transforms, src_transforms_treedef = tree_util.tree_flatten( + src_transforms + ) + flat_dst_transforms, dst_transforms_treedef = tree_util.tree_flatten( + dst_transforms + ) + barrier, barrier_transforms = state_primitives.get_ref_and_transforms( + barrier, None, "copy_gmem_to_smem" + ) + flat_barrier_transforms, barrier_transforms_treedef = tree_util.tree_flatten( + barrier_transforms + ) + copy_gmem_to_smem_p.bind( + src, + dst, + barrier, + *flat_src_transforms, + *flat_dst_transforms, + *flat_barrier_transforms, + src_transforms_treedef=src_transforms_treedef, + dst_transforms_treedef=dst_transforms_treedef, + barrier_transforms_treedef=barrier_transforms_treedef, + ) return None +def _extract_barrier_indexer(transforms) -> indexing.NDIndexer | None: + if not transforms: + return None + match transforms: + case [indexing.NDIndexer(indices=[idx]) as indexer]: + if not isinstance(idx, indexing.Slice): + return indexer + if indexing.Slice.from_slice(slice(None), *indexer.shape) == idx: + # Special-case: the whole slice. + return None + else: + raise ValueError( + f"Barrier can only be indexed with an integer, got {idx}" + ) + case [indexing.NDIndexer()]: + raise NotImplementedError("Barrier does not support multiple indices") + case []: + return None + case _: + raise ValueError("Barrier does not support arbirary transforms") + + class WaitEffect(jax_core.Effect): ... +effects.control_flow_allowed_effects.add_type(WaitEffect) -wait_effect = WaitEffect() +_wait_effect = WaitEffect() -wait_p = jax_core.Primitive("wait") -wait_p.multiple_results = True +wait_barrier_p = jax_core.Primitive("wait") +wait_barrier_p.multiple_results = True -@wait_p.def_effectful_abstract_eval -def _wait_abstract_eval(*avals, **params): +@wait_barrier_p.def_effectful_abstract_eval +def _wait_barrier_abstract_eval(*avals, **params): del avals, params # Unused. - return (), {wait_effect} + return (), {_wait_effect} -@lowering.register_lowering_rule(wait_p) -def _wait_lowering_rule( - ctx: lowering.LoweringRuleContext, barrier=None, allow_groups=None, +@lowering.register_lowering_rule(wait_barrier_p) +def _wait_barrier_lowering( + ctx: lowering.LoweringRuleContext, + barrier, + *flat_transforms, + transforms_treedef, ): - if barrier is not None: - barrier.wait() - else: - assert allow_groups is not None - ctx.launch_ctx.await_async_copy(allow_groups=allow_groups) + del ctx # Unused. + transforms = transforms_treedef.unflatten(flat_transforms) + indexer = _extract_barrier_indexer(transforms) + if indexer is not None: + barrier = barrier.__getitem__(*map(lowering._as_index, indexer.indices)) + barrier.wait() return () -def wait_smem_to_gmem(allow_groups: int) -> None: - """Waits until there are no more than the given number of SMEM->GMEM copies in flight.""" - wait_p.bind(allow_groups=allow_groups) - - def wait_barrier(barrier: pallas_core.AbstractMemoryRef) -> None: """Waits on the given barrier.""" - wait_p.bind(barrier) + barrier, transforms = state_primitives.get_ref_and_transforms( + barrier, None, "wait_barrier" + ) + flat_transforms, transforms_treedef = tree_util.tree_flatten(transforms) + wait_barrier_p.bind( + barrier, *flat_transforms, transforms_treedef=transforms_treedef + ) + + +wait_smem_to_gmem_p = jax_core.Primitive("wait_smem_to_gmem") +wait_smem_to_gmem_p.multiple_results = True + + +@wait_smem_to_gmem_p.def_effectful_abstract_eval +def _wait_smem_to_gmem_abstract_eval(n): + del n # Unused. + return (), {_wait_effect} + + +@lowering.register_lowering_rule(wait_smem_to_gmem_p) +def _wait_smem_to_gmem_lowering(ctx: lowering.LoweringRuleContext, n): + ctx.launch_ctx.await_async_copy(allow_groups=n) + return () + + +def wait_smem_to_gmem(n: int) -> None: + """Waits until there are no more than ``n`` SMEM->GMEM copies in flight.""" + wait_smem_to_gmem_p.bind(n) class _WGMMAPipelineEffect(effects.Effect): @@ -119,40 +314,76 @@ class _WGMMAPipelineEffect(effects.Effect): wgmma_ref_p = jax_core.Primitive("wgmma_ref") wgmma_ref_p.multiple_results = True -def wgmma(acc, a, b, *, rhs_transpose: bool = False, swizzle: int = 128): - """Asynchronous warp group matmul. - The sm90 wgmma instruction, essentially acc[...] += a @ b. Requires - that accumulator is an accumualtion register reference. +def wgmma( + acc: gpu_core.WGMMAAbstractAccumulatorRef, + a: pallas_core.TransformedRef, + b: pallas_core.TransformedRef, +) -> None: + """Performs and asynchronous warp group matmul-accumulate on the given references. + + Conceptually, this is equivalent to doing ``acc[...] += a[...] @ b[...]``, + except that the computation is performed asynchronously. Args: - acc: The accumulator register. - a: The left hand side operand. - b: The right hand side operand. - transpose: Whether to transpose b. - n_tile: The number of tiles to use. - swizzle: The swizzle pattern. + acc: The accumulator reference. Needs to be allocated via + :func:`jax.experimental.pallas.run_scoped` called with a + :func:`jax.experimental.pallas.mosaic_gpu.WGMMAAccumulatorRef`. + a: The left hand side operand reference. + b: The right hand side operand reference. + + See also: + :func:`jax.experimental.pallas.mosaic_gpu.wgmma_wait` """ - if not isinstance(acc.aval, gpu_core.WGMMAAbstractAccumulatorRef): - raise TypeError(f"Expected WGMMAAbstractAccumulatorRef got {acc}") - - ma, ka, tma, tka = a.shape - kb, nb, tkb, tnb = b.shape - mc, nc = acc.shape - - if rhs_transpose: - kb, nb, tkb, tnb = nb, kb, tnb, tkb - - if tma * ma != mc or nb * tnb != nc or ka != kb or tka != tkb: - raise ValueError(f"Incompatible shapes: {a.shape=}, {b.shape=}, {acc.shape=}, {rhs_transpose=}") + m, n = acc.shape + m2, k = a.shape + k2, n2 = b.shape + + if m != m2 or n != n2 or k != k2: + raise ValueError( + f"Incompatible shapes for matrix multiplication: lhs={a.shape}," + f" rhs={b.shape=}, acc={acc.shape}" + ) + + if (dtype := a.dtype) != b.dtype: + raise ValueError(f"Mixed input dtypes for matrix multiplication unsupported: lhs={a.dtype}, rhs={b.dtype}") + + # Infer swizzle from a. + if not a.transforms or not isinstance( + (swizzle_transform := a.transforms[0]), gpu_core.UnswizzleRef + ): + raise ValueError("WGMMA lhs must be a tiled and swizzled reference.") + + swizzle = swizzle_transform.swizzle + swizzle_elems = swizzle // dtype.itemsize + if a.transforms[1:] != (gpu_core.UntileRef((64, swizzle_elems)),): + raise ValueError( + f"WGMMA lhs must be tiled with 64x{swizzle_elems} tiles for element type" + f" {dtype}." + ) + + rhs_transpose_transform = gpu_core.TransposeRef((1, 0, 2, 3)) + rhs_tiling = gpu_core.UntileRef((swizzle_elems, swizzle_elems)) + if b.transforms == (swizzle_transform, rhs_tiling): + rhs_transpose = False + elif b.transforms == (swizzle_transform, rhs_transpose_transform, rhs_tiling): + rhs_transpose = True + else: + raise ValueError( + f"WGMMA rhs must have {swizzle=} and be tiled with" + f" {swizzle_elems}x{swizzle_elems} tiles for element type {dtype} (and" + " optionally transposed)." + ) - return wgmma_ref_p.bind(acc, a, b, swizzle=swizzle, rhs_transpose=rhs_transpose) + wgmma_ref_p.bind(acc, a.ref, b.ref, swizzle=swizzle, rhs_transpose=rhs_transpose) @wgmma_ref_p.def_effectful_abstract_eval -def _wgmma_ref_effectful_abstract_eval(acc, *args, **kwargs): - del acc, args, kwargs - return [], { +def _wgmma_ref_effectful_abstract_eval(acc_aval, a_aval, b_aval, **params): + del a_aval, b_aval, params + if not isinstance(acc_aval, gpu_core.WGMMAAbstractAccumulatorRef): + raise TypeError(f"Expected WGMMAAbstractAccumulatorRef got {acc_aval}") + return (), { _wgmma_pipeline_effect, state.WriteEffect(0), state.ReadEffect(0), @@ -162,8 +393,9 @@ def _wgmma_ref_effectful_abstract_eval(acc, *args, **kwargs): @discharge.register_discharge_rule(wgmma_ref_p) -def _wgmma_ref_discharge_rule( - in_avals, out_avals, +def _wgmma_ref_discharge( + in_avals, + out_avals, acc, a, b, @@ -183,8 +415,9 @@ def _wgmma_ref_discharge_rule( # Functional WGMMA, returns a shaped array. Internal. wgmma_p = jax_core.Primitive("wgmma") + @lowering.register_lowering_rule(wgmma_p) -def _wgmma_lowering_rule( +def _wgmma_lowering( ctx: lowering.LoweringRuleContext, acc, a, @@ -205,6 +438,7 @@ def _wgmma_lowering_rule( nvvm_dialect.wgmma_commit_group_sync_aligned() return new_acc + @wgmma_p.def_effectful_abstract_eval def _wgmma_effectful_abstract_eval(acc, *args, **kwargs): del args, kwargs @@ -217,22 +451,26 @@ def _wgmma_effectful_abstract_eval(acc, *args, **kwargs): wgmma_wait_p = jax_core.Primitive("wgmma_wait") wgmma_wait_p.multiple_results = True -def wgmma_wait(i: int): - """Wait until all but the last `i` WGMMA operations are done.""" - return wgmma_wait_p.bind(i) + +def wgmma_wait(n: int): + """Waits until there is no more than ``n`` WGMMA operations in flight.""" + return wgmma_wait_p.bind(n) @wgmma_wait_p.def_effectful_abstract_eval def wgmma_wait_effectful_abstract_eval(_): return [], {_wgmma_pipeline_effect} + @lowering.register_lowering_rule(wgmma_wait_p) -def _wgmma_wait_lowering_rule(ctx: lowering.LoweringRuleContext, allow_groups): +def _wgmma_wait_lowering(ctx: lowering.LoweringRuleContext, allow_groups): del ctx nvvm_dialect.wgmma_wait_group_sync_aligned(allow_groups) return () + wgmma_accumulator_deref_p = jax_core.Primitive("wgmma_accumulator_deref_p") + def wgmma_accumulator_deref(acc): """Dereferences an accumulator register.""" @@ -248,13 +486,15 @@ def _wgmma_accumulator_deref_abstract_eval(acc): assert isinstance(ret, jax_core.ShapedArray), acc return ret, {_wgmma_pipeline_effect} + @discharge.register_discharge_rule(wgmma_accumulator_deref_p) -def _wgmma_accumulator_deref_discharge_rule(in_avals, out_avals, acc): +def _wgmma_accumulator_deref_discharge(in_avals, out_avals, acc): del in_avals, out_avals return (None,), wgmma_accumulator_deref_p.bind(acc) + @lowering.register_lowering_rule(wgmma_accumulator_deref_p) -def _wgmma_accumulator_deref_lowering_rule(ctx: lowering.LoweringRuleContext, acc): +def _wgmma_accumulator_deref_lowering(ctx: lowering.LoweringRuleContext, acc): del ctx nvvm_dialect.wgmma_wait_group_sync_aligned(0) return acc.value diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 1c10d2bda9e9..44dad819bc09 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -30,6 +30,7 @@ from jax._src import core as jax_core from jax._src import effects from jax._src import linear_util as lu +from jax._src import state from jax._src import tree_util from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -39,6 +40,7 @@ from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils from jax._src.state import discharge as state_discharge +from jax._src.state import types as state_types from jax._src.util import ( safe_map, safe_zip, @@ -207,6 +209,7 @@ def _pallas_call_impl_interpret( print(discharged_jaxpr) out = _initialize_output_vals(grid_mapping.block_mappings_output, args, input_output_aliases) + # TODO(b/370563936): Fix correctness issue w/ io aliasing scalars = args[grid_mapping.slice_index_ops] block_args = args[len(scalars):] # invars: [*scalar_prefetch, *consts, *inputs, *outputs, *scratch] @@ -935,14 +938,17 @@ def g(): with pallas_core.tracing_grid_env(batched_grid_mapping.grid, ()): kernel_src_info: pallas_core.SrcInfoStr = "" - jaxpr = _trace_kernel_to_jaxpr( + jaxpr, consts = _trace_kernel_to_jaxpr( when_wrapped_kernel, kernel_src_info, batched_grid_mapping, tuple(flat_kernel_avals), kernel_in_tree, + tuple(() for _ in flat_kernel_avals), interpret=interpret, ) + if consts: + raise NotImplementedError("consts not supported in pallas_call") assert ragged_axis_length is not None args = (ragged_axis_length, *args) @@ -987,6 +993,85 @@ def checkify_pallas_kernel_body_jaxpr( body_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals) return checked_jaxpr, out_tree, error_effects +def pallas_call_checkify_oob_grid(error: checkify.Error, + enabled_errors, + args: jax_core.Value, + grid_mapping: GridMapping, + input_output_aliases) -> checkify.Error: + if checkify.OOBError not in enabled_errors: + return error + dynamic_grid_args, args = split_list( + args, [grid_mapping.num_dynamic_grid_bounds] + ) + output_args = _initialize_output_vals(grid_mapping.block_mappings_output, + args, input_output_aliases) + scalars, input_args, _ = split_list( + args, [grid_mapping.num_index_operands, + grid_mapping.num_inputs], + ) + dynamic_grid_args_iter = iter(dynamic_grid_args) + grid = tuple( + a if a is not pallas_core.dynamic_grid_dim + else next(dynamic_grid_args_iter) + for a in grid_mapping.grid + ) + grid_start_indices = (jnp.int32(0),) * len(grid) + if grid: + num_iterations = reduce(jnp.multiply, grid) + else: + # Base case is always one iteration when grid is () + num_iterations = 1 + + is_indexing_dim = [ + tuple(b is pallas_core.mapped for b in bm.block_shape) + for bm in grid_mapping.block_mappings + ] + block_shapes = [ + None if iid is None + else tuple(1 if i else b for i, b in zip(iid, bm.block_shape)) + for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings) + ] + # The scan carry: (i, loop_idx, *consts, *ins, *outs, *scratch) + # i:int32 is the interation index + # loop_idx: tuple[int32] are the program ids for each grid axis + def cond(carry): + i, *_ = carry + return i < num_iterations + def body(carry): + i, loop_idx = carry + if grid_mapping.local_grid_env is not None: + local_grid_env = grid_mapping.local_grid_env(loop_idx, grid) + else: + local_grid_env = tuple( + pallas_core.GridAxis(idx, b) + for dim, (idx, b) in enumerate(zip(loop_idx, grid)) + if dim not in grid_mapping.vmapped_dims + ) + with pallas_core.grid_env(local_grid_env): + start_indices = [ + None if bm is None else bm.compute_start_indices_interpret(loop_idx, *scalars) + for bm in grid_mapping.block_mappings] + # We perform a dynamic slice on the i/o blocks, which will be checked by + # checkify for OOB accesses. + map(_maybe_dynamic_slice, start_indices, block_shapes, + [*input_args, *output_args], is_indexing_dim) + return (i + 1, _get_next_indices(grid, loop_idx)) + def f(_): + lax.while_loop( + cond, body, (jnp.int32(0), grid_start_indices) + ) + flat_args, jaxpr_in_tree = jax.tree_util.tree_flatten((jnp.int32(0),)) + wrapped_loop, _ = api_util.flatten_fun_nokwargs( + lu.wrap_init(f), jaxpr_in_tree) + with pallas_core.tracing_grid_env(grid_mapping.grid, ()): + avals_in = map(jax_core.get_aval, flat_args) + traced_loop, _, consts, () = pe.trace_to_jaxpr_dynamic( + wrapped_loop, list(avals_in)) + traced_loop = jax_core.ClosedJaxpr(traced_loop, consts) + out_error, _ = checkify.checkify_jaxpr( + traced_loop, checkify.index_checks, error, flat_args) + return out_error + def pallas_call_checkify_rule(error: checkify.Error, enabled_errors, *args: jax_core.Value, @@ -996,6 +1081,10 @@ def pallas_call_checkify_rule(error: checkify.Error, grid_mapping: GridMapping, out_avals: tuple[jax_core.AbstractValue, ...], **kwargs): + # Check for OOB accesses in the grid. + error = pallas_call_checkify_oob_grid(error, enabled_errors, + args, grid_mapping, + input_output_aliases) # We implement the checkify rule in 4 steps: # 1) First, trace the kernel body to get the expected error shapes. # 2) Checkify the kernel body to obtain a jaxpr with errors as inputs @@ -1136,19 +1225,35 @@ def _ensure_2d_error_shape(arg): return new_error, results checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule + +# All of those shenanigans are because we can't make TransformedRef a PyTree, +# because they should appear as atomic JAX values to the users. +@lu.transformation +def wrap_with_transforms(transforms, *args): + new_args = tuple( + state_types.TransformedRef(a, t) if t else a + for a, t in zip(args, transforms) + ) + res = yield new_args, {} + yield res + + @weakref_lru_cache -def _trace_kernel_to_jaxpr(fun: Callable, - name_and_src_info: pallas_core.NameAndSrcInfo, - grid_mapping: GridMapping, - kernel_avals: tuple[pallas_core.AbstractMemRef, ...], - kernel_in_tree: tree_util.PyTreeDef, - interpret: bool, - ) -> jax_core.ClosedJaxpr: +def _trace_kernel_to_jaxpr( + fun: Callable, + name_and_src_info: pallas_core.NameAndSrcInfo, + grid_mapping: GridMapping, + kernel_avals: tuple[pallas_core.AbstractMemRef, ...], + kernel_in_tree: tree_util.PyTreeDef, + kernel_in_transforms: tuple[tuple[pallas_core.Transform, ...], ...], + interpret: bool, +) -> tuple[jax_core.ClosedJaxpr, tuple[jax.Array, ...]]: if interpret: kernel_avals = tuple(map(_logical_aval_to_interpret_mode_aval, kernel_avals)) wrapped_kernel_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( lu.wrap_init(fun), kernel_in_tree) + wrapped_kernel_fun = wrap_with_transforms(wrapped_kernel_fun, kernel_in_transforms) debug = pe.debug_info(fun, kernel_in_tree, out_tree_thunk, False, "pallas_call") with grid_mapping.trace_env(): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun, @@ -1156,17 +1261,18 @@ def _trace_kernel_to_jaxpr(fun: Callable, if consts: consts_avals = [jax_core.raise_to_shaped(jax_core.get_aval(c)) for c in consts] - raise ValueError( - f"The kernel function in the pallas_call {name_and_src_info} " - f"captures constants {consts_avals}. " - "You should pass them as inputs") + if any(not isinstance(aval, state.AbstractRef) for aval in consts_avals): + raise ValueError( + f"The kernel function in the pallas_call {name_and_src_info} " + f"captures constants {consts_avals}. " + "You should pass them as inputs") kernel_out_tree = out_tree_thunk() if kernel_out_tree != tree_util.tree_structure(None): raise ValueError( f"The kernel function in the pallas_call {name_and_src_info} " f"should return None. It returns a PyTree: {kernel_out_tree}") - return jaxpr + return jaxpr, tuple(consts) _PALLAS_USE_MOSAIC_GPU = config.bool_flag( @@ -1191,6 +1297,8 @@ def _unsupported_lowering_error(platform: str) -> Exception: def _pallas_call_lowering( ctx: mlir.LoweringRuleContext, *in_nodes, interpret: bool, **params ): + if params['jaxpr'].constvars: + raise ValueError('Cannot lower a pallas_call with constants.') if interpret: # If we are in interpret mode, we don't care what platform we are on. impl = partial(_pallas_call_impl_interpret, **params) @@ -1268,6 +1376,133 @@ def _convert_out_shape_to_aval(out_shape: Any) -> jax_core.AbstractValue: return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype) +def _get_memory_space_from_ref(ref_aval: state.AbstractRef) -> Any: + if isinstance(ref_aval, pallas_core.AbstractMemoryRef): + return ref_aval.memory_space + return pallas_core.MemorySpace.ANY + + +@state_discharge.register_discharge_rule(pallas_call_p) +def _pallas_call_state_discharge_rule( + avals_in, + avals_out, + *args, + jaxpr: jax_core.Jaxpr, + input_output_aliases: tuple[tuple[int, int], ...], + name_and_src_info: pallas_core.NameAndSrcInfo, + grid_mapping: GridMapping, + debug: bool, + interpret: bool, + compiler_params: Any, + cost_estimate: CostEstimate | None, + out_avals: tuple[jax_core.AbstractValue, ...], +): + del avals_out + assert all(isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars) + num_refs = len(jaxpr.constvars) + ref_avals, rest_in_avals = split_list(avals_in, [num_refs]) + assert all(isinstance(ref_aval, state.AbstractRef) for ref_aval in ref_avals) + ref_avals = [ + pallas_core.AbstractMemoryRef( + ref_aval.inner_aval, pallas_core.MemorySpace.ANY + ) + for ref_aval in ref_avals + ] + ref_block_specs = [ + pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY) + ] * num_refs + ref_block_mappings = [ + block_spec.to_block_mapping( + origin="", # TODO(sharadmv): enable origins for refs + array_aval=ref_aval.inner_aval, + index_map_avals=grid_mapping.index_map_avals, + index_map_tree=grid_mapping.index_map_tree, + grid=grid_mapping.grid, + mapped_dims=grid_mapping.mapped_dims, + ) for ref_aval, block_spec in zip(ref_avals, ref_block_specs) + ] + in_block_mappings, out_block_mappings = split_list( + grid_mapping.block_mappings, [grid_mapping.num_inputs] + ) + new_block_mappings = ( + *ref_block_mappings, + *in_block_mappings, + *ref_block_mappings, + *out_block_mappings, + ) + new_grid_mapping = grid_mapping.replace( + block_mappings=new_block_mappings, + num_inputs=grid_mapping.num_inputs + num_refs, + num_outputs=grid_mapping.num_outputs + num_refs) + new_input_output_aliases = [ + (i + grid_mapping.num_index_operands, i) for i in range(num_refs) + ] + for i, o in input_output_aliases: + new_input_output_aliases.append((i + num_refs, o + num_refs)) + ref_out_avals = [ref_aval.inner_aval for ref_aval in ref_avals] + new_out_avals = (*ref_out_avals, *out_avals) + ref_args, dynamic_grid_bounds, index_operands, rest_args = split_list( + args, + [ + num_refs, + grid_mapping.num_dynamic_grid_bounds, + grid_mapping.num_index_operands, + ], + ) + def _rewritten_body(*args): + index_args, in_args, out_args, rest_args = split_list( + args, [new_grid_mapping.num_index_operands, new_grid_mapping.num_inputs, + new_grid_mapping.num_outputs]) + ref_in_args, in_args = split_list(in_args, [num_refs]) + ref_out_args, out_args = split_list(out_args, [num_refs]) + # We don't care about ref_out_args because they are aliased to ref_in_args + del ref_out_args + jax_core.eval_jaxpr( + jaxpr, ref_in_args, *index_args, *in_args, *out_args, *rest_args + ) + return [] + index_map_avals, jaxpr_in_avals, jaxpr_out_avals, jaxpr_rest_avals = ( + split_list( + [v.aval for v in jaxpr.invars], + [ + grid_mapping.num_index_operands, + grid_mapping.num_inputs, + grid_mapping.num_outputs, + ], + ) + ) + new_jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(_rewritten_body), + [ + *index_map_avals, + *ref_avals, + *jaxpr_in_avals, + *ref_avals, + *jaxpr_out_avals, + *jaxpr_rest_avals, + ], + ) + out_flat = pallas_call_p.bind( + *consts, + *dynamic_grid_bounds, + *index_operands, + *ref_args, + *rest_args, + jaxpr=new_jaxpr, + input_output_aliases=new_input_output_aliases, + grid_mapping=new_grid_mapping, + name_and_src_info=name_and_src_info, + debug=debug, + interpret=interpret, + compiler_params=compiler_params, + cost_estimate=cost_estimate, + out_avals=new_out_avals, + ) + refs_out, rest = split_list(out_flat, [num_refs]) + updated_vals_in = refs_out + [None] * len(rest_in_avals) + return updated_vals_in, rest + + def pallas_call( kernel: Callable[..., None], out_shape: Any, @@ -1406,16 +1641,25 @@ def wrapped(*args): for p in in_paths) out_origins = tuple(f"outputs{tree_util.keystr(p)}" for p in out_paths) # TODO(necula): check that input_output_aliases is well-formed: no duplicates, etc. - kernel_avals, grid_mapping = pallas_core.get_grid_mapping( + kernel_args, grid_mapping = pallas_core.get_grid_mapping( grid_spec, flat_in_avals, in_tree, in_origins, flat_out_avals, out_tree, out_origins) - flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten(kernel_avals) + flat_kernel_args, kernel_in_tree = tree_util.tree_flatten(kernel_args) + flat_kernel_avals = tuple( + x.ref if isinstance(x, state_types.TransformedRef) else x + for x in flat_kernel_args + ) + # Note that only a subset of all transforms can be found here, and they are + # never expected to contains any arrays. + kernel_arg_transforms = tuple( + x.transforms if isinstance(x, state_types.TransformedRef) else () + for x in flat_kernel_args + ) with pallas_core.interpret_mode_env(interpret): - jaxpr = _trace_kernel_to_jaxpr( - kernel, kernel_src_info, - grid_mapping, tuple(flat_kernel_avals), kernel_in_tree, - interpret=interpret) + jaxpr, consts = _trace_kernel_to_jaxpr( + kernel, kernel_src_info, grid_mapping, tuple(flat_kernel_avals), + kernel_in_tree, kernel_arg_transforms, interpret=interpret) for i_idx, o_idx in input_output_aliases.items(): if i_idx not in range(len(flat_in_avals)): raise ValueError( @@ -1440,6 +1684,7 @@ def wrapped(*args): index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands]) with pallas_core.interpret_mode_env(interpret): out_flat = pallas_call_p.bind( + *consts, *dynamic_grid_bounds, *index_args, *rest_args, diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 40caae76bd8f..3bf815cd3cdd 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -819,15 +819,13 @@ def debug_print_lowering_rule(ctx, *args, **params): run_scoped_p.multiple_results = True -def run_scoped(f: Callable[..., Any], *types, **kw_types) -> Any: - """Call the function with allocated references. +def run_scoped(f: Callable[..., Any], *types: Any, **kw_types: Any) -> Any: + """Calls the function with allocated references and returns the result. - Args: - f: The function that generates the jaxpr. - *types: The types of the function's positional arguments. - **kw_types: The types of the function's keyword arguments. + The positional and keyword arguments describe which reference types + to allocate for each argument. Each backend has its own set of reference + types in addition to :class:`jax.experimental.pallas.MemoryRef`. """ - flat_types, in_tree = tree_util.tree_flatten((types, kw_types)) flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree) avals = [t.get_ref_aval() for t in flat_types] diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 9db5e4081239..2140962e2953 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -286,10 +286,6 @@ def lower_jaxpr_to_triton_module( raise NotImplementedError( "scratch memory not implemented in the Triton backend" ) - with grid_mapping.trace_env(): - jaxpr, _ = pe.dce_jaxpr( - jaxpr, [True] * len(jaxpr.outvars), instantiate=True - ) with _new_ir_context(), ir.Location.unknown(): module = ir.Module.create() attrs = module.operation.attributes @@ -2094,10 +2090,8 @@ def _dot_general_lowering( dimension_numbers, precision, preferred_element_type, - algorithm, - transpose_algorithm, ): - del preferred_element_type, algorithm, transpose_algorithm # Unused. + del preferred_element_type # Unused. ((a_contract_dim,), (b_contract_dim,)), batch_dims = dimension_numbers assert batch_dims == ((), ()) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 0a75128477ce..f1baf48f6857 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -62,7 +62,6 @@ from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src import sharding from jax._src.mesh import AbstractMesh from jax._src.sharding_impls import ( @@ -322,28 +321,11 @@ def _cpp_pjit_evict_fn(self): _cpp_pjit_cache_explicit_attributes = xc._xla.PjitFunctionCache(capacity=8192) -if xla_extension_version < 286: - def _get_cpp_global_cache(pjit_has_explicit_sharding): - if pjit_has_explicit_sharding: - return xc._xla.PjitFunctionCache() - else: - return _cpp_pjit_cache_fun_only - - def _pjit_explicit_sharding_and_layout( - in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat, - device, backend) -> bool: - return (device is not None or - backend is not None or - any(not is_unspecified(i) for i in in_shardings_flat) or - any(not is_unspecified(o) for o in out_shardings_flat) or - any(i is not None for i in in_layouts_flat) or - any(o is not None for o in out_layouts_flat)) -else: - def _get_cpp_global_cache(contains_explicit_attributes: bool): # type: ignore - if contains_explicit_attributes: - return _cpp_pjit_cache_explicit_attributes - else: - return _cpp_pjit_cache_fun_only +def _get_cpp_global_cache(contains_explicit_attributes: bool): + if contains_explicit_attributes: + return _cpp_pjit_cache_explicit_attributes + else: + return _cpp_pjit_cache_fun_only def _cpp_pjit(fun: Callable, jit_info: PjitInfo): @@ -364,35 +346,24 @@ def cache_miss(*args, **kwargs): return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler) - if xla_extension_version >= 286: - cache_key = pxla.JitGlobalCppCacheKeys( - donate_argnums=jit_info.donate_argnums, - donate_argnames=jit_info.donate_argnames, - device=jit_info.device, backend=jit_info.backend, - in_shardings_treedef=jit_info.in_shardings_treedef, - in_shardings_leaves=jit_info.in_shardings_leaves, - out_shardings_treedef=jit_info.out_shardings_treedef, - out_shardings_leaves=jit_info.out_shardings_leaves, - in_layouts_treedef=jit_info.in_layouts_treedef, - in_layouts_leaves=jit_info.in_layouts_leaves, - out_layouts_treedef=jit_info.out_layouts_treedef, - out_layouts_leaves=jit_info.out_layouts_leaves, - use_resource_env=jit_info.use_resource_env) - cpp_pjit_f = xc._xla.pjit( - fun_name(fun), fun, cache_miss, jit_info.static_argnums, - jit_info.static_argnames, cache_key, tree_util.dispatch_registry, # type: ignore - pxla.cc_shard_arg, - _get_cpp_global_cache(cache_key.contains_explicit_attributes)) - else: - has_explicit_sharding = _pjit_explicit_sharding_and_layout( - jit_info.in_shardings_leaves, jit_info.out_shardings_leaves, - jit_info.in_layouts_leaves, jit_info.out_layouts_leaves, - jit_info.device, jit_info.backend) - cpp_pjit_f = xc._xla.pjit( - fun_name(fun), fun, cache_miss, jit_info.static_argnums, - jit_info.static_argnames, jit_info.donate_argnums, - tree_util.dispatch_registry, pxla.cc_shard_arg, - _get_cpp_global_cache(has_explicit_sharding)) + cache_key = pxla.JitGlobalCppCacheKeys( + donate_argnums=jit_info.donate_argnums, + donate_argnames=jit_info.donate_argnames, + device=jit_info.device, backend=jit_info.backend, + in_shardings_treedef=jit_info.in_shardings_treedef, + in_shardings_leaves=jit_info.in_shardings_leaves, + out_shardings_treedef=jit_info.out_shardings_treedef, + out_shardings_leaves=jit_info.out_shardings_leaves, + in_layouts_treedef=jit_info.in_layouts_treedef, + in_layouts_leaves=jit_info.in_layouts_leaves, + out_layouts_treedef=jit_info.out_layouts_treedef, + out_layouts_leaves=jit_info.out_layouts_leaves, + use_resource_env=jit_info.use_resource_env) + cpp_pjit_f = xc._xla.pjit( + fun_name(fun), fun, cache_miss, jit_info.static_argnums, + jit_info.static_argnames, cache_key, tree_util.dispatch_registry, + pxla.cc_shard_arg, + _get_cpp_global_cache(cache_key.contains_explicit_attributes)) cpp_pjitted_f = wraps(fun)(cpp_pjit_f) cpp_pjitted_f._fun = fun @@ -1302,6 +1273,9 @@ def _create_pjit_jaxpr( ) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: del ignored_inline # just for explain_cache_miss + if config.no_tracing.value: + raise RuntimeError(f"re-tracing function {fun.f} for `jit`, but " + "'no_tracing' is set") with dispatch.log_elapsed_time( "Finished tracing + transforming {fun_name} for pjit in {elapsed_time:.9f} sec", fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT): @@ -1752,26 +1726,18 @@ def call_impl_cache_miss(*args_, **kwargs_): jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline) donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d) - if xla_extension_version >= 286: - cache_key = pxla.JitGlobalCppCacheKeys( - donate_argnums=donated_argnums, donate_argnames=None, - device=None, backend=None, - in_shardings_treedef=None, in_shardings_leaves=in_shardings, - out_shardings_treedef=None, out_shardings_leaves=out_shardings, - in_layouts_treedef=None, in_layouts_leaves=in_layouts, - out_layouts_treedef=None, out_layouts_leaves=out_layouts, - use_resource_env=resource_env is not None) - return xc._xla.pjit( - name, f, call_impl_cache_miss, [], [], cache_key, - tree_util.dispatch_registry, pxla.cc_shard_arg, - _get_cpp_global_cache(cache_key.contains_explicit_attributes))(*args) - else: - has_explicit_sharding = _pjit_explicit_sharding_and_layout( - in_shardings, out_shardings, in_layouts, out_layouts, None, None) - return xc._xla.pjit( - name, f, call_impl_cache_miss, [], [], donated_argnums, - tree_util.dispatch_registry, pxla.cc_shard_arg, - _get_cpp_global_cache(has_explicit_sharding))(*args) + cache_key = pxla.JitGlobalCppCacheKeys( + donate_argnums=donated_argnums, donate_argnames=None, + device=None, backend=None, + in_shardings_treedef=None, in_shardings_leaves=in_shardings, + out_shardings_treedef=None, out_shardings_leaves=out_shardings, + in_layouts_treedef=None, in_layouts_leaves=in_layouts, + out_layouts_treedef=None, out_layouts_leaves=out_layouts, + use_resource_env=resource_env is not None) + return xc._xla.pjit( + name, f, call_impl_cache_miss, [], [], cache_key, + tree_util.dispatch_registry, pxla.cc_shard_arg, + _get_cpp_global_cache(cache_key.contains_explicit_attributes))(*args) pjit_p.def_impl(_pjit_call_impl) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index b69e78fe9ddf..b43cad745f3a 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -307,6 +307,21 @@ def is_fully_replicated(self) -> bool: def with_memory_kind(self, kind: str) -> NamedSharding: return NamedSharding(self.mesh, self.spec, memory_kind=kind) + def normalized_spec(self, ndim: int) -> PartitionSpec: + out = [] # type: ignore + for p in self._parsed_pspec: + if p is None: + raise ValueError("UNCONSTRAINED is not supported yet.") + if not p: + out.append(None) + elif isinstance(p, tuple) and len(p) == 1: + out.append(p[0]) + else: + out.append(p) + if len(out) < ndim: + out.extend([None] * (ndim - len(out))) + return PartitionSpec(*out) + def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return named_sharding_to_xla_hlo_sharding(self, num_dimensions) diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 7970440d29a6..74024b449edd 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -97,11 +97,35 @@ def __call__(self, in_avals: Sequence[core.AbstractValue], _discharge_rules: dict[core.Primitive, DischargeRule] = {} +class PartialDischargeRule(Protocol): + """A partial discharge rule. + + Exactly like a discharge rule only it accepts a `should_discharge` + argument that indicates which inputs should be discharged and the + return value returns a tuple of which the first element is the new + inputs or none but only the ones that correspond to `True` entries + in `should_charge`. + """ + + def __call__(self, should_discharge: Sequence[bool], + in_avals: Sequence[core.AbstractValue], + out_avals: Sequence[core.AbstractValue], *args: Any, + **params: Any) -> tuple[Sequence[Any | None], Sequence[Any]]: + ... + +_partial_discharge_rules: dict[core.Primitive, PartialDischargeRule] = {} + def register_discharge_rule(prim: core.Primitive): def register(f: DischargeRule): _discharge_rules[prim] = f return register +def register_partial_discharge_rule(prim: core.Primitive): + def register(f: PartialDischargeRule): + _partial_discharge_rules[prim] = f + return register + + def _eval_jaxpr_discharge_state( jaxpr: core.Jaxpr, should_discharge: Sequence[bool], consts: Sequence[Any], *args: Any): @@ -116,22 +140,33 @@ def _eval_jaxpr_discharge_state( if d and isinstance(v.aval, AbstractRef)} for eqn in jaxpr.eqns: + should_discharge = [id(v.aval) in refs_to_discharge for v in eqn.invars] if eqn.primitive is core.mutable_array_p: [invar], [outvar] = eqn.invars, eqn.outvars ans = env.read(invar) refs_to_discharge.add(id(outvar.aval)) - elif (any(id(v.aval) in refs_to_discharge for v in eqn.invars) - or core.internal_mutable_array_effect in eqn.effects ): - if eqn.primitive not in _discharge_rules: + elif (any(should_discharge) + or core.internal_mutable_array_effect in eqn.effects + ): + if eqn.primitive in _partial_discharge_rules: + rule: DischargeRule = partial(_partial_discharge_rules[eqn.primitive], should_discharge) + elif eqn.primitive in _discharge_rules: + rule = _discharge_rules[eqn.primitive] + else: raise NotImplementedError("No state discharge rule implemented for " f"primitive: {eqn.primitive}") invals = map(env.read, eqn.invars) in_avals = [v.aval for v in eqn.invars] out_avals = [v.aval for v in eqn.outvars] - new_invals, ans = _discharge_rules[eqn.primitive]( + new_invals, ans = rule( in_avals, out_avals, *invals, **eqn.params) - for new_inval, invar in zip(new_invals, eqn.invars): + for invar, should, new_inval in zip(eqn.invars, should_discharge, new_invals): if new_inval is not None: + if not should: + raise ValueError( + f"Did not ask for inval to be discharged but it was. ({invar=}," + f" {new_inval=})" + ) env.write(invar, new_inval) # type: ignore[arg-type] else: # Default primitive rule, similar to `core.eval_jaxpr`. Note that here diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index acf1c7216240..cb653547baff 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -256,3 +256,10 @@ def get_indexer_shape(self) -> tuple[int | Array, ...]: # In NDIndexers, the int_indexer_shape is *always* at the front of the # result. return (*self.int_indexer_shape, *slice_shape) + + def transform_shape(self, shape: None | tuple[int | Array, ...]) -> None | tuple[int | Array, ...]: + del shape # Unused + return self.get_indexer_shape() + + def transform_dtype(self, dtype): + return dtype diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 773302c9f637..b91c2a13cf7c 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -22,6 +22,7 @@ from jax._src import core from jax._src import dispatch from jax._src import pretty_printer as pp +from jax._src import traceback_util from jax._src import tree_util from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -49,6 +50,7 @@ map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip +traceback_util.register_exclusion(__file__) ## get/swap/addupdate implementations @@ -175,25 +177,17 @@ def _shape_after_transforming( shape: tuple[int | Array, ...], transforms: tuple[Transform, ...] ) -> tuple[int | Array, ...]: for transform in transforms: - match transform: - case indexing.NDIndexer(): - # Run some simple checks that all the indexers have consistent shapes - if not transform.is_dynamic_size: - assert transform.shape == shape, (transform.shape, shape) - shape = transform.get_indexer_shape() - case RefBitcaster(): - shape = transform.shape - case _: - raise ValueError(f"Unsupported transform: {transform}") + shape = transform.transform_shape(shape) # type: ignore + assert shape is not None return shape def _dtype_after_transforming( dtype: Any, transforms: tuple[Transform, ...] ) -> Any: - for transform in reversed(transforms): - if isinstance(transform, RefBitcaster): - return transform.dtype + for transform in transforms: + dtype = transform.transform_dtype(dtype) + assert dtype is not None return dtype @@ -334,7 +328,7 @@ def pp_transform(context: core.JaxprPpContext, transform: Transform) -> pp.Doc: case RefBitcaster(): return pp_bitcaster(context, transform) case _: - raise ValueError(f"Unsupported transform: {transform}") + return pp.text(f"[{transform}]") def _pp_transforms( diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index e64d6258a808..7ec271b9af33 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -18,12 +18,14 @@ from collections.abc import Sequence import dataclasses import math -from typing import Any, Union +from typing import Any, Union, Protocol +from jax._src.typing import DTypeLike from jax._src import core from jax._src import dtypes from jax._src import effects from jax._src import pretty_printer as pp +from jax._src import traceback_util from jax._src import tree_util from jax._src.state import indexing from jax._src.typing import Array @@ -33,6 +35,7 @@ map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip +traceback_util.register_exclusion(__file__) _ref_effect_color = pp.Color.GREEN @@ -105,8 +108,39 @@ def tree_unflatten(cls, metadata, arrays): assert not arrays return cls(*metadata) + def transform_shape( + self, shape: tuple[int | Array, ...] | None + ) -> tuple[int | Array, ...] | None: + del shape # Unused + return self.shape + + def transform_dtype(self, dtype): + del dtype # Unused + return self.dtype + + +class Transform(Protocol): + + def transform_shape( + self, shape: tuple[int | Array, ...] | None + ) -> tuple[int | Array, ...] | None: + """Transform the shape. + + Can return None if the input shape is not known, but must return a concrete + result when the input shape is known. + """ + return shape + + def transform_dtype( + self, dtype: DTypeLike | None + ) -> DTypeLike | None: + """Transform the dtype. + + Can return None if the input dtype is not known, but must return a concrete + result when the input dtype is known. + """ + return dtype -Transform = indexing.NDIndexer | RefBitcaster @dataclasses.dataclass class RefIndexer: @@ -122,30 +156,51 @@ def __getitem__(self, slc): return TransformedRef(self.ref_or_view, (indexer,)) -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class TransformedRef: ref: Any transforms: tuple[Transform, ...] @property def is_dynamic_size(self): - return self.transforms[-1].is_dynamic_size + return any(not isinstance(i, int) for i in self.shape) @property def shape(self) -> tuple[int | Array, ...]: - assert ( - len(self.transforms) > 0 - ), "Should not be able to create a trivial TransformedRef" - if isinstance(self.transforms[-1], indexing.NDIndexer): - return self.transforms[-1].get_indexer_shape() - return self.transforms[-1].shape + unprocessed, shape = 0, None + # We first go backwards to find the first transform that knows its output + # shape. It's possible none of them do! + for unprocessed, t in enumerate(reversed(self.transforms), 1): + if (shape := t.transform_shape(None)) is not None: + unprocessed -= 1 + break + if shape is None: + shape = self.ref.shape + if not unprocessed: + return shape + # If there are any unprocessed transforms left, we apply them to the shape + # we've found previuously. + for t in self.transforms[-unprocessed:]: + shape = t.transform_shape(shape) + assert shape is not None + return shape @property def dtype(self): - for transform in reversed(self.transforms): - if isinstance(transform, RefBitcaster): - return transform.dtype - return self.ref.dtype + # The structure of this method is analogous to `shape`. See comments there. + unprocessed, dtype = 0, None + for unprocessed, t in enumerate(reversed(self.transforms), 1): + if (dtype := t.transform_dtype(None)) is not None: + unprocessed -= 1 + break + if dtype is None: + dtype = self.ref.dtype + if not unprocessed: + return dtype + for t in self.transforms[-unprocessed:]: + dtype = t.transform_dtype(dtype) + assert dtype is not None + return dtype @property def at(self) -> RefIndexer: diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 97b6a2cfd32a..ba340b8ba537 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -274,6 +274,7 @@ def _tpu_custom_call_lowering( def _lower_tpu_kernel( module: ir.Module, hardware_generation: int, + target_shape: tuple[int, int], ) -> ir.Module: """Runs MLIR passes lowering the given module to an MLIR module. @@ -283,6 +284,7 @@ def _lower_tpu_kernel( Args: module: The MLIR module to lower. hardware_generation: The TPU hardware generation to target. + target_shape: The target shape of (sublane_count, lane_count). Returns: An MLIR module implementing the kernel. @@ -312,11 +314,16 @@ def _lower_tpu_kernel( pipeline.run(module.operation) dump_mlir(module, "post-hlo-conversion") + sl_cnt, l_cnt = target_shape # Note: we don't pass the TpuTilingFlags here, since we don't know the # tiling decisions made by the compiler / what flags are enabled at this # point, so we assume everything can be tiled up to default tiling. pipeline = [ - f"func.func(tpu-infer-memref-layout{{hardware-generation={hardware_generation}}})" + "func.func(tpu-infer-memref-layout{" + f" hardware-generation={hardware_generation}" + f" sublane-count={sl_cnt}" + f" lane-count={l_cnt}" + "})" ] pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") pipeline.run(module.operation) @@ -357,14 +364,16 @@ def _lower_tpu_kernel( dump_mlir(module, "post-canonicalize-mosaic") pipeline = [ - "func.func(tpu-infer-vector-layout{sublane-count=8 lane-count=128})", + ( + "func.func(tpu-infer-vector-layout{" + f" sublane-count={sl_cnt} lane-count={l_cnt}" + "})" + ), ] pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") pipeline.run(module.operation) dump_mlir(module, "post-infer-vector-layout") - sl_cnt = 8 - l_cnt = 128 mxu_size = 128 if hardware_generation < 6 else 256 pipeline = [ "func.func(tpu-apply-vector-layout{" @@ -414,7 +423,10 @@ def _lower_mosaic_module_to_asm( "tpu_custom_call cannot be lowered on a machine without TPUs " "when mosaic_use_python_pipeline=True.") hardware_generation = int(device_kind[len("TPU v")]) - module = _lower_tpu_kernel(module, hardware_generation) + # TODO(b/369418606): Infer the target shape from the hardware generation. + module = _lower_tpu_kernel( + module, hardware_generation, target_shape=(8, 128) + ) needs_hlo_passes = False needs_layout_passes = False prev_allow_unregistered_dialects = ctx.allow_unregistered_dialects diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index 49162809a325..7d60f62e230f 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Primitives for calling Python functions on the host from JAX accelerator code. +"""Backwards compatibility shims for the deprecated host_callback APIs. .. warning:: The host_callback APIs are deprecated as of March 20, 2024. @@ -19,2128 +19,16 @@ `new JAX external callbacks `_ See https://github.com/jax-ml/jax/issues/20385. -This module introduces the host callback functions :func:`call`, -:func:`id_tap`, and :func:`id_print`, that send their arguments from the device -to the host and invoke user-defined Python functions on the host, optionally -returning results back to the device computation. - -We show below how these functions can be used. We start with :func:`call`, -and we discuss examples of calling from JAX to arbitrary Python functions -on the CPU, e.g., to use NumPy CPU custom kernels. Then we -show uses of :func:`id_tap` and :func:`id_print`, which have the restriction -that they cannot return values from the host to the device. -These primitives are generally faster -because they are executed asynchronously with the device code. -In particular, they can be used to tap into and to debug JAX code. - -Using :func:`call` to call a host function and return results to device ------------------------------------------------------------------------ - -Use :func:`call` to invoke a computation on the host and return -NumPy arrays to the device computation. -Host computation is useful, e.g., when a device computation needs some data -that requires I/O on the host, or it needs a library that is available on the -host and you do not want to code it in JAX. -For example, eigen decomposition for general matrices in JAX does not work on TPU. -We can call the Numpy implementation from any JAX accelerator computation, -using a host computation:: - - # This function runs on the host - def host_eig(m: np.ndarray) -> np.ndarray: - return np.linalg.eigvals(m) - - # This function is used in JAX - def device_fun(m): - # We send "m" to the host, asking it to call "host_eig" and return the result. - # We have to specify the result shape and dtype, either in the form of an - # example return value or any object that has `shape` and `dtype` attributes, - # e.g., a NumPy array or a `jax.ShapeDtypeStruct`. - return hcb.call(host_eig, m, - # Given an input of shape (..., d, d), eig output has shape (..., d) - result_shape=jax.ShapeDtypeStruct(m.shape[:-1], m.dtype)) - - -The :func:`call` function and the Python host function both take a single argument -and return a single result, but those can be pytrees. Note that we must tell -the :func:`call` what shape and dtype to expect from the host invocation, using -the ``result_shape`` keyword argument. -This is important because the device code is compiled with that expectation. -There will be an error raised at runtime if the actual invocation produces a -different result shape. In general, **such errors and also exceptions raised -by the host computation may be difficult to debug**. See the Debugging section -below. -This is a problem for :func:`call` but not for :func:`id_tap` because for the -latter the device code does not expect a returned value. - -The :func:`call` API can be used inside a jit or pmap computation or inside -cond/scan/while control flow. When used inside :func:`jax.pmap`, there will be -separate calls to the host from each of the participating devices:: - - def host_sin(x, *, device): - # The ``device`` argument is passed due to ``call_with_device=True`` below. - print(f"Invoking host_sin with {x.shape} on {device}") - return np.sin(x) - - # Use pmap to run the computation on two devices - jax.pmap(lambda x: hcb.call(host_sin, x, - result_shape=x, - # Ask that the `host_sin` function be passed `device=dev` - call_with_device=True))( - np.ones((2, 4), dtype=np.float32)) - - # prints (in arbitrary order) - # Invoking host_sin with (4,) on cpu:0 - # Invoking host_sin with (4,) on cpu:1 - -Note that :func:`call` does not support any JAX transformations, but as we -show below one can make use of the -existing support for `Custom differentiation in JAX `_. - -Using :func:`id_tap` to call a Python function on the host, with no returned values ------------------------------------------------------------------------------------ - -The :func:`id_tap` and :func:`id_print` are special cases of :func:`call`, when -you just want the side effects of your Python callback. These functions have -the advantage that once the arguments have been sent to the host, the device -computation can proceed without waiting for the Python callback to return. -For :func:`id_tap` you can specify your Python callback to be called, while -:func:`id_print` uses a built-in callback that prints the arguments to -`stdout` on the host. -The Python function passed -to :func:`id_tap` takes two positional arguments (the value tapped -from the device computation along with a ``transforms`` tuple, -described below). Optionally, the function may be passed a keyword argument -``device`` with the Device from which the value was tapped. - -A few examples:: - - def host_func(arg, transforms): - ...do something with arg... - - # calls host_func(2x, []) on host - id_tap(host_func, 2 * x) - - # calls host_func((2x, 3x), []) - id_tap(host_func, (2 * x, 3 * x)) # The argument can be a pytree - - # calls host_func(2x, [], device=jax.devices()[0]) - id_tap(host_func, 2 * x, tap_with_device=True) # Pass the device to the tap - - # calls host_func(2x, [], what='activation') - id_tap(functools.partial(host_func, what='activation'), 2 * x) - - # calls host_func(dict(x=x, y=y), what='data') - id_tap(lambda tap, transforms: host_func(tap, what='data'), dict(x=x, y=y)) - -The above examples can all be adapted to use :func:`id_print` instead, with -the difference that :func:`id_print` prints on the host the positional argument, -along with any additional kwargs and the automatic kwarg ``transforms``. - -Using :func:`barrier_wait` to wait until all callbacks have executed --------------------------------------------------------------------- - -If your Python callbacks have side-effects you may need to wait until the -computation has finished to ensure that the side-effects have been observed. -You can use the :func:`barrier_wait` function for that purpose:: - - accumulator = [] - def host_log(arg, transforms): - # We just record the arguments in a list - accumulator.append(arg) - - - def device_fun(x): - id_tap(host_log, x) - id_tap(host_log, 2. * x) - - jax.jit(device_fun)(1.) - jax.jit(device_fun)(1.) - - # At this point, we have started two computations, each with two - # taps, but they may not have yet executed. - barrier_wait() - # Now we know that all the computations started before `barrier_wait` - # on all devices, have finished, and all the callbacks have finished - # executing. - -Note that :func:`barrier_wait` will start one -tiny computation with one tap on each of the `jax.local_devices()` and -will wait for all these taps to be received. - -An alternative to using :func:`barrier_wait` is to just wait for the end -of the computation, if all the callbacks are :func:`call`:: - - accumulator = p[] - def host_log(arg): - # We just record the arguments in a list - accumulator.append(arg) - return 0. # return something - - - def device_fun(c): - y = call(host_log, x, result_shape=jax.ShapeDtypeStruct((), np.float32)) - z = call(host_log, 2. * x, result_shape=jax.ShapeDtypeStruct((), np.float32)) - return y + z # return something that uses both results - - res1 = jax.jit(device_fun)(1.) - res2 = jax.jit(device_fun)(1.) - res1.block_until_ready() - res2.block_until_ready() - -Behavior under parallelization transformations ----------------------------------------------- - -In presence of :func:`jax.pmap` the code will run on multiple devices and -each device will tap its values independently. -It may be helpful to use the ``tap_with_device`` option for :func:`id_print` -or :func:`id_tap`, so that you see which device is sending which data:: - - jax.pmap(power3, devices=jax.local_devices()[:2])(np.array([3., 4.]) - # device=cpu:0 what=x,x^2: (3., 9.) # from the first device - # device=cpu:1 what=x,x^2: (4., 16.) # from the second device - -When using :func:`jax.pmap` with multiple devices on multiple hosts, every -host will receive callbacks from all of its local devices, with an operand -that corresponds to each device slice. For a -:func:`call`, the callback must return to each device only the slice of the -result that pertains to the corresponding device. - -When using the experimental :func:`pjit.pjit` the code will run on multiple -devices on different shards of the input. The current implementation of -host callbacks will ensure that a single device will collect and outfeed -the entire operand, in a single callback. The callback function is supposed -to return the entire array, which will then be sent in a single infeed to the -same device that issued the outfeed. This device is then responsible for -sending the required shards to the other devices:: - - with jax.sharding.Mesh(jax.local_devices()[:2], ["d"]): - pjit.pjit(power3, in_shardings=(P("d"),), - out_shardings=(P("d"),))(np.array([3., 4.])) - - # device=TPU:0 what=x,x^2: ( [3., 4.], - # [9., 16.] ) - -Note that the collection of the operand on one device may result in OOM if -the operand was sharded across devices. - -When using :func:`pjit.pjit` with multiple devices on multiple hosts, only -the host for the device 0 (w.r.t. the mesh) will receive the callback, with -the operand collected -from all participating devices on all hosts. For a :func:`call`, the callback -must return the entire array for all devices on all hosts. - -Behavior under JAX autodiff transformations -------------------------------------------- - -When used under a JAX autodiff transformation, the host callback functions -operate on the primal values only. Consider the following example:: - - def power3(x): - y = x * x - # Print both 'x' and 'x^2'. Must pack as a tuple. - hcb.id_print((x, y), what="x,x^2") - return y * x - - power3(3.) - # what: x,x^2 : (3., 9.) - -(You can see these examples tested in `host_callback_test.HostCallbackTapTest.test_tap_transforms`.) - -When used under :func:`jax.jvp` there will be one callback with the primal -values only:: - - jax.jvp(power3, (3.,), (0.1,)) - # what: x,x^2 : (3., 9.) - -Similarly for :func:`jax.grad`, we get a callback from the forward computation -only:: - - jax.grad(power3)(3.) - # what: x,x^2 : (3., 9.) - -If you want to invoke the callback on the tangents during a :func:`jax.jvp`, -you can use a custom_jvp. For example, you can define a function that does -nothing interesting except that its custom_jvp will print the tangents:: - - @jax.custom_jvp - def print_tangents(arg): - return None - - @print_tangents.defjvp - def print_tangents_jvp(primals, tangents): - arg_dot, = tangents - hcb.id_print(arg_dot, what="tangents") - return primals, tangents - -Then you use this function in the places where you want to tap the tangents:: - - def power3_with_tangents(x): - y = x * x - # Print both 'x' and 'x^2'. Must pack as a tuple. - hcb.id_print((x, y), what="x,x^2") - print_tangents((x, y)) - return y * x - - jax.jvp(power3_with_tangents, (3.,), (0.1,)) - # what: x,x^2 : (3., 9.) - # what: tangents : (0.1, 0.6) - -You can do a similar thing for the cotangents during :func:`jax.grad`. This -time you must be careful to use in the rest of the computation the values whose -cotangents you want to tap. Hence we make the ``print_cotangents`` return -its argument:: - - @jax.custom_vjp - def print_cotangents(arg): - # Must return the argument for which we want the cotangent. - return arg - - # f_fwd: a -> (b, residual) - def print_cotangents_fwd(arg): - return print_cotangents(arg), None - # f_bwd: (residual, CT b) -> [CT a] - def print_cotangents_bwd(residual, ct_b): - hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream) - return ct_b, - - print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd) - - def power3_with_cotangents(x): - y = x * x - # Print both 'x' and 'x^2'. Must pack as a tuple. - hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream) - (x1, y1) = print_cotangents((x, y)) - # Must use the output of print_cotangents - return y1 * x1 - - jax.grad(power3_with_cotangents)(3.) - # what: x,x^2 : (3., 9.) - # what: cotangents : (9., 3.) - -If you use :func:`ad_checkpoint.checkpoint` to rematerialize the residuals -for the backward pass, then the callbacks from the primal computation will -be called twice:: - - jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.) - # what: x,x^2 : (3., 9.) - # what: x,x^2 : (27., 729.) - # what: x,x^2 : (3., 9.) - -The callbacks are, in order from: the primal computation of the inner ``power3``, -the primal computation of the outer ``power3``, and the rematerialization -of the residuals for the inner ``power3``. - - -Behavior under jax.vmap ------------------------ - -The host callback functions :func:`id_print` and :func:`id_tap` support the -vectorization transformation :func:`jax.vmap`. - -For :func:`jax.vmap` the arguments to the callback are batched, -and the callback function is -passed an additional special ``transforms`` containing a list of transformation descriptors -in the form ``("batch", {"batch_dims": ...})``, where ``...``` denotes the -batched dimensions for the tapped values (one entry per argument, ` -`None`` denotes an argument that was broadcast). - - jax.vmap(power3)(np.array([2., 3.])) - # transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 : ([2., 3.], [4., 9.]) - -See documentation for :func:`id_tap`, :func:`id_print`, and :func:`call`. - -For more usage example, see tests/host_callback_test.py. - -Using :func:`call` to call a TensorFlow function, with reverse-mode autodiff support ------------------------------------------------------------------------------------- - -Another possible use for host computation is to invoke a library written for -another framework, such as TensorFlow. -In this case it becomes interesting to support JAX autodiff for host callbacks -by deferring to the autodiff mechanism in TensorFlow, -using the :func:`jax.custom_vjp` mechanism. - -This is relatively easy to do, once one understands both the JAX custom VJP -and the TensorFlow autodiff mechanisms. -The code for how this can be done is shown in the ``call_tf_full_ad`` -function in `host_callback_to_tf_test.py `_. -This example supports arbitrary higher-order differentiation as well. - -Note that if you just want to call TensorFlow functions from JAX, you can also -use the `jax2tf.call_tf function `_. - -Using :func:`call` to call a JAX function on another device, with reverse-mode autodiff support ------------------------------------------------------------------------------------------------- - -It should not be surprising that we can use host computation to invoke a JAX -computation on another device. The arguments are sent from the accelerator to -the host, and then to the outside device on which the JAX host -computation will run, and then the results are sent back to the original accelerator. - -The code for how this can be done is shown in the ``call_jax_other_device function`` -in `host_callback_test.py `_. - -Low-level details and debugging -------------------------------- - -The host callback functions will be executed for each device in the order in -which the send operations were performed on the device. - -The host callback functions for multiple devices may be interleaved. -The data from the devices is received by separate threads managed by the JAX -runtime (one thread per device). The runtime maintains a buffer of -configurable size (see the flag ``--jax_host_callback_max_queue_byte_size``). -When the buffer is full, all the receiving threads are paused -which eventually pauses the computation on devices. The runtime has one -additional thread for each device to invoke the Python user functions with the -received data. If the processing of the callbacks is slow, it may actually -lead to the runtime buffer filling up, and eventually pausing the computation -on the devices when they need to send something. -For more details on the outfeed receiver runtime mechanism see -`runtime code -`_. - -In order to pause the execution until all data from computations already -started on devices has arrived and has been processed, use :func:`barrier_wait`. - -Exceptions from the user-defined callback functions are logged along with their -stack traces, but the receiving threads are not stopped. Instead the last -exception is recorded and the subsequent :func:`barrier_wait` will -raise :exc:`CallbackException` if any exception had occurred -in one of the tap functions. This exception will include the text and the -stack trace of the last exception encountered. - -One further complication arises for callback functions that must return -results to the call origin device, such as :func:`call()`. This is handled -differently on CPU/GPU devices compared to TPU devices. - -On CPU/GPU devices, in order to avoid the device computation -being stuck waiting for a result that will never arrive, in case of any -error during the processing of the callback (whether raised by the user-code -itself or due to a mismatch of the returned value and the expected return_shape) -we send the device a "fake" result of shape ``int8[12345]``. -This will make the device -computation abort because the received data is different than the one that -it expects. On CPU the runtime will crash with a distinctive error message: - -``` -Check failed: buffer->length() == buffer_length (12345 vs. ...) -``` - -On GPU, the failure is more user-friendly and will be surfaced to the Python -program as: - -``` -RET_CHECK failure ... Mismatch between infeed source buffer shape s8[12345] ... -``` - -To debug the underlying cause for these messages, see the Debugging section. - -On TPU devices, there is currently no shape check for infeed, so we take the -safer route of not sending this fake result in case of errors. This means -that the computation will hang, and no exception will be raised (but any -exceptions in the callback functions will still appear in the logs). - -The current implementation uses the outfeed mechanism provided by XLA. The -mechanism itself is quite primitive in the sense that a receiver must know -exactly the shape of each incoming packet, and how many packets are expected. -This makes it hard to use for multiple kinds of data in the same computation, -and it is practically impossible to use it under conditionals or in loops -of non-constant iteration count. Furthermore, code that uses the outfeed -mechanism directly cannot be transformed by JAX. All these limitations are -addressed by the host callback functions. The tapping API introduced here -makes it easy to share the outfeed mechanism for multiple purposes, while -supporting all transformations. - -**Note that after you have used the host callback functions, you cannot -use lax.outfeed directly**. You may want to :func:`stop_outfeed_receiver` -if you later need to use lax.outfeed. - -Since the actual calls to your callback functions are made from the C++ -receiver, it may be hard to debug the calls. In particular, the stack trace -will not include the calling code. You can use the flag -``jax_host_callback_inline`` (or the environment variable -``JAX_HOST_CALLBACK_INLINE``) to ensure that the calls to the callbacks are -inlined. This works only if the calls are outside a staging context -(:func:`~jax.jit` or a control-flow primitive). - -The C++ `receiver -`_ -is started automatically on the first call to :func:`id_tap`. In order to stop -it properly, upon start an ``atexit`` handler is registered to call -:func:`barrier_wait` with the logging name "at_exit". - -There are a few environment variables that you can use to turn on logging -for the C++ outfeed `receiver backend -`_. - - * ``TF_CPP_MIN_LOG_LEVEL=0``: will turn on INFO logging, needed for all below. - * ``TF_CPP_MIN_VLOG_LEVEL=3``: will make all VLOG logging up to level 3 behave - like INFO logs. This may be too much, but you will see which modules are - logging relevant info, and then you can select which modules to log from. - * ``TF_CPP_VMODULE==3`` (the module name can be either C++ or - Python, without the extension). - -You should also use the ``--verbosity=2`` flag so that you see the logs -from Python. - -For example, you can try to enable logging in the ``host_callback`` module: -``TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=host_callback=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple`` - -If you want to enable logging in lower-level implementation modules try: -``TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple`` - -(For bazel tests use --test_arg=--vmodule=... - -Still to do: - * More performance tests. - * Explore implementation with outside compilation for TPU. - * Explore implementation with XLA CustomCall for CPU and GPU. - """ from __future__ import annotations -import atexit -import enum -from collections.abc import Callable, Sequence -import functools -import itertools -import logging -import math -import threading -import traceback -from typing import Any, cast - -import jax -from jax._src import api -from jax._src import core -from jax._src import config -from jax import custom_derivatives -from jax._src import dtypes -from jax import lax -from jax.experimental import pjit -from jax.experimental import io_callback -from jax._src.interpreters import ad, batching, pxla -from jax._src.interpreters import mlir -from jax._src.interpreters import partial_eval as pe -from jax._src.interpreters import xla -from jax._src import ad_checkpoint -from jax._src import compiler -from jax._src import dispatch -from jax._src import pretty_printer as pp -from jax._src import sharding_impls -from jax._src import source_info_util -from jax._src import tree_util -from jax._src import util -from jax._src import xla_bridge as xb -from jax._src.lib import xla_client -from jax._src.lib import xla_extension -from jax._src.lib import xla_extension_version -from jax._src.lib.mlir import ir -from jax._src.lib.mlir.dialects import hlo - -import numpy as np - - -_HOST_CALLBACK_INLINE = config.bool_flag( - 'jax_host_callback_inline', - config.bool_env('JAX_HOST_CALLBACK_INLINE', False), - help='Inline the host_callback, if not in a staged context.' -) -_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE = config.int_flag( - 'jax_host_callback_max_queue_byte_size', - config.int_env('JAX_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE', int(256 * 1e6)), - help=('The size in bytes of the buffer used to hold outfeeds from each ' - 'device. When this capacity is reached consuming outfeeds from the ' - 'device is paused, thus potentially pausing the device computation, ' - 'until the Python callback consume more outfeeds.'), - lower_bound=int(16 * 1e6) -) -_HOST_CALLBACK_OUTFEED = config.bool_flag( - 'jax_host_callback_outfeed', - config.bool_env('JAX_HOST_CALLBACK_OUTFEED', False), - help=( - 'Use outfeed implementation for host_callback, even on CPU and GPU. ' - 'If false, use the CustomCall implementation. ' - 'Has no effect on TPU, since only the outfeed mechanism is implemented.' - ) -) -_HOST_CALLBACK_LEGACY = config.bool_flag( - 'jax_host_callback_legacy', - config.bool_env('JAX_HOST_CALLBACK_LEGACY', True), - help=( - 'Use old implementation of host_callback, documented in the module docstring.' - 'If False, use the jax.experimental.io_callback implementation. ' - 'See https://github.com/jax-ml/jax/issues/20385.' - ) -) - -logger = logging.getLogger(__name__) - - -def _use_outfeed(platform: str) -> bool: - return (platform in ("tpu", "gpu", "cuda", "rocm") or - _HOST_CALLBACK_OUTFEED.value) - - -def _raise_if_using_outfeed_with_pjrt_c_api(backend: xb.XlaBackend): - """Should be called whenever outfeed (or infeed) will be used.""" - if xb.using_pjrt_c_api(backend): - raise NotImplementedError( - "host_callback functionality isn't supported with PJRT C API. " - "See https://jax.readthedocs.io/en/latest/debugging/index.html and " - "https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html" - " for alternatives. Please file a feature request at " - "https://github.com/jax-ml/jax/issues if none of the alternatives are " - "sufficient.") - - -xops = xla_client._xla.ops - -XlaOp = xla_client.XlaOp -XlaShape = xla_client.Shape -XlaBuilder = xla_client.XlaBuilder -XlaDevice = xla_client.Device -XlaLocalClient = xla_client.Client -DType = Any - -class CallbackFlavor(enum.Enum): - """Specifies which flavor of callback to use under JAX_HOST_CALLBACK_LEGACY=False. - - See https://github.com/jax-ml/jax/issues/20385. - """ - IO_CALLBACK = 1 # uses jax.experimental.io_callback - PURE = 2 # uses jax.pure_callback - DEBUG = 3 # uses jax.debug.callback, valid only when there are no results - - -def _deprecated_id_tap(tap_func, - arg, - *, - result=None, - tap_with_device=False, - device_index=0, - callback_flavor=CallbackFlavor.IO_CALLBACK, - **kwargs): - """Host-callback tap primitive, like identity function with a call to ``tap_func``. - - .. warning:: - The host_callback APIs are deprecated as of March 20, 2024. - The functionality is subsumed by the - `new JAX external callbacks `_ - See https://github.com/jax-ml/jax/issues/20385. - - ``id_tap`` behaves semantically like the identity function but has the - side-effect that a user-defined Python function is called with the runtime - value of the argument. - - Args: - tap_func: tap function to call like ``tap_func(arg, transforms)``, with - ``arg`` as described below and where ``transforms`` is the sequence of - applied JAX transformations in the form ``(name, params)``. If the - `tap_with_device` optional argument is True, then the invocation also - includes the device from which the value is tapped as a keyword argument: - ``tap_func(arg, transforms, device=dev)``. - arg: the argument passed to the tap function, can be a pytree of JAX - types. - result: if given, specifies the return value of ``id_tap``. This value is - not passed to the tap function, and in fact is not sent from the device to - the host. If the ``result`` parameter is not specified then the return - value of ``id_tap`` is ``arg``. - tap_with_device: if True then the tap function is invoked with the - device from which the tap originates as a keyword argument. - device_index: specifies from which device the tap function is invoked in a - SPMD program. Works only when using the outfeed implementation mechanism, - i.e., does not work on CPU unless --jax_host_callback_outfeed=True. - callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies - the flavor of callback to use. - See https://github.com/jax-ml/jax/issues/20385. - - Returns: - ``arg``, or ``result`` if given. - - The order of execution is by data dependency: after all the arguments and - the value of ``result`` if present, are computed and before the returned - value is used. At least one of the returned values of ``id_tap`` must be - used in the rest of the computation, or else this operation has no effect. - - Tapping works even for code executed on accelerators and even for code under - JAX transformations. - - For more details see the :mod:`jax.experimental.host_callback` module documentation. - """ - if kwargs: - msg = ( - "Support for **kwargs in ``id_tap`` has been removed. Instead, " - "pre-apply keyword arguments, either by using a closure or by passing " - "``functools.partial(tap_func, **kwargs)``.") - raise TypeError(msg) - - if result is not None: - flat_results, _ = tree_util.tree_flatten(result) - for r in flat_results: - dispatch.check_arg(r) - - call_res = _call( - tap_func, - arg, - call_with_device=tap_with_device, - result_shape=None, - identity=True, - device_index=device_index, - callback_flavor=callback_flavor) - - if result is not None: - return result - else: - return call_res - - -def _deprecated_id_print(arg, - *, - result=None, - tap_with_device=False, - device_index=0, - output_stream=None, - threshold=None, - callback_flavor=CallbackFlavor.IO_CALLBACK, - **kwargs): - """Like :func:`id_tap` with a printing tap function. - - .. warning:: - The host_callback APIs are deprecated as of March 20, 2024. - The functionality is subsumed by the - `new JAX external callbacks `_ - See https://github.com/jax-ml/jax/issues/20385. - - On each invocation of the printing tap, the ``kwargs`` if present - will be printed first (sorted by keys). Then arg will be printed, - with the arrays stringified with ``numpy.array2string``. - - See the :func:`id_tap` documentation. - - Additional keyword arguments: - - * ``tap_with_device`` if True, will print also the device from which - the value originates. - * ``output_stream`` if given then it will be used instead of the - built-in ``print``. The string will be passed as - ``output_stream.write(s)``. - * ``threshold`` is passed to ``numpy.array2string``. - * ``callback_flavor``: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies - the flavor of callback to use. - See https://github.com/jax-ml/jax/issues/20385. - - For more details see the :mod:`jax.experimental.host_callback` module documentation. - """ - printer = functools.partial(_print_tap_func, - output_stream=output_stream, - threshold=threshold, **kwargs) - return _deprecated_id_tap( - printer, - arg, - result=result, - tap_with_device=tap_with_device, - device_index=device_index, - callback_flavor=callback_flavor) - - -def _deprecated_call(callback_func: Callable, arg, *, - result_shape=None, - call_with_device=False, - device_index=0, - callback_flavor=CallbackFlavor.IO_CALLBACK): - """Make a call to the host, and expect a result. - - .. warning:: - The host_callback APIs are deprecated as of March 20, 2024. - The functionality is subsumed by the - `new JAX external callbacks `_ - See https://github.com/jax-ml/jax/issues/20385. - - Args: - callback_func: The Python function to invoke on the host as - ``callback_func(arg)``. If the ``call_with_device`` optional argument is True, - then the invocation also includes the ``device`` kwarg with the device - from which the call originates: ``callback_func(arg, device=dev)``. This function - must return a pytree of numpy ndarrays. - - arg: the argument passed to the callback function, can be a pytree of JAX - types. - - result_shape: a value that describes the expected shape and dtype of the - result. This can be a numeric scalar, from which a shape and dtype are - obtained, or an object that has ``.shape`` and ``.dtype`` attributes. - If the result of the callback is a pytree, then ``result_shape`` should - also be a pytree with the same structure. In particular, ``result_shape`` - can be `()` or `None` if the function does not have any results. - The device code containing ``call`` is compiled with the expected result shape and dtype, - and an error will be raised at runtime if the actual ``callback_func`` - invocation returns a different kind of result. - - call_with_device: if True then the callback function is invoked with the - device from which the call originates as a keyword argument. - - device_index: specifies from which device the tap function is invoked in a - SPMD program. Works only when using the outfeed implementation mechanism, - i.e., does not work on CPU unless --jax_host_callback_outfeed=True. - callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies - the flavor of callback to use. - See https://github.com/jax-ml/jax/issues/20385. - - Returns: - the result of the ``callback_func`` invocation. - - For more details see the :mod:`jax.experimental.host_callback` module documentation. - """ - if (not _HOST_CALLBACK_LEGACY.value and - callback_flavor is CallbackFlavor.DEBUG and - result_shape is not None): - raise NotImplementedError( - "When using JAX_HOST_CALLBACK_LEGACY=False you can use the `DEBUG` " - "flavor of callback only when the `result_shape` is None. " - "See https://github.com/jax-ml/jax/issues/20385." - ) - return _call(callback_func, arg, result_shape=result_shape, - call_with_device=call_with_device, identity=False, - device_index=device_index, callback_flavor=callback_flavor) - - -# We need the wrapper function to have hash and equality defined since it is -# used as a primitive keyword argument, and we want a compilation cache hit if -# the user uses the same function twice. -class _CallbackWrapper: - def __init__(self, callback_func, identity, call_with_device): - self.callback_func = callback_func - self.identity = identity - self.call_with_device = call_with_device - if not _HOST_CALLBACK_LEGACY.value and call_with_device: - raise NotImplementedError( - "When using JAX_HOST_CALLBACK_LEGACY=False, the host_callback APIs" - " do not support `tap_with_device` and `call_with_device`. " - "See https://github.com/jax-ml/jax/issues/20385.") - - def __hash__(self): - return hash((self.callback_func, self.identity, self.call_with_device)) - - def __eq__(self, other): - return (self.callback_func == other.callback_func and - self.identity == other.identity and - self.call_with_device == other.call_with_device) - - def __call__(self, *args, **kwargs): - if _HOST_CALLBACK_LEGACY.value: - return self._call_legacy(*args, **kwargs) - else: - if self.identity: - # For id_tap, we pass empty transforms, for backwards compatibility - return self.callback_func(args[0], ()) - return self.callback_func(*args, **kwargs) - - def _call_legacy(self, arg, device, transforms): - if self.identity: - # For id_tap, we pass the transforms, for backwards compatibility - if self.call_with_device: - return self.callback_func(arg, transforms, device=device) - else: - return self.callback_func(arg, transforms) - else: - if self.call_with_device: - return self.callback_func(arg, device=device) - else: - return self.callback_func(arg) - - -# Helper function to implement both `call` and `id_tap`. The two cases are -# differentiated by the `identity` flag. -def _call(callback_func: Callable, - arg, - *, - result_shape=None, - call_with_device=False, - device_index=0, - identity=False, - callback_flavor=CallbackFlavor.IO_CALLBACK): - if _HOST_CALLBACK_LEGACY.value: - # Lazy initialization - _initialize_outfeed_receiver( - max_callback_queue_size_bytes=_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE.value) - api.check_callable(callback_func) - flat_args, arg_treedef = tree_util.tree_flatten(arg) - for arg_ in flat_args: - dispatch.check_arg(arg_) - # See definition of outside_call_p for what parameters it takes - params: dict[str, Any] = {} - # TODO: wrap function - params["callback"] = _CallbackWrapper(callback_func, identity, - call_with_device) - params["identity"] = identity - params["arg_treedef"] = arg_treedef - params["device_index"] = device_index - - if not identity: - # Turn abstract values into ShapesDtypeStruct - flat_results_shape, result_treedef = tree_util.tree_flatten(result_shape) - try: - flat_results_aval = [core.ShapedArray(np.shape(r), dtypes.dtype(r, canonicalize=True)) - for r in flat_results_shape] - except Exception: - msg = ("result_shape should be a pytree of values with structure " - "matching the expected result of the callback function. The " - "values must be either numeric scalars, or must have 'shape' and " - f"'dtype' attributes. Got {result_shape}") - raise ValueError(msg) - - params["result_treedef"] = result_treedef - params["flat_results_aval"] = tuple(flat_results_aval) - - if _HOST_CALLBACK_LEGACY.value: - flat_results = outside_call_p.bind(*flat_args, **params) - return result_treedef.unflatten(flat_results) if not identity else arg_treedef.unflatten(flat_results) - else: - callback_device = jax.local_devices()[device_index] - sharding = jax.sharding.SingleDeviceSharding(callback_device) - callback_func = _CallbackWrapper(callback_func, identity, - call_with_device) - if callback_flavor is CallbackFlavor.DEBUG: - assert identity - jax.debug.callback(callback_func, arg) - return arg - elif callback_flavor is CallbackFlavor.PURE: - call_res = jax.pure_callback(callback_func, result_shape, arg, - sharding=sharding) - else: - call_res = io_callback(callback_func, result_shape, arg, - sharding=sharding, - ordered=True) - return call_res if not identity else arg - - -# We need the lock for when we use the CustomCall implementation of callbacks. -# The outfeed implementation is driven by a single thread from C++. -_print_tap_lock = threading.Lock() - - -def _print_tap_func( - arg, transforms, *, device=None, - output_stream=None, threshold=1024, **kwargs): - """The consumer for id_print. - - We provide this as a simple tapping function for printing. - This is **experimental** and may not want to add many features to it; - it should be easy for the user to roll their own printing function. - - Args: - device: the device from which the value originates (only if - ``tap_with_device`` was used for :func:`id_print`). - output_stream: a function whose `write` method is called with the strings to - be output. - threshold: the value of numpy.array2string threshold parameter. - **kwargs: all other keyword args are printed before printing `arg`. - """ - def emit_str(s: str): - if output_stream is not None: - output_stream.write(s + "\n") - else: - print(s) - - if transforms: - kwargs['transforms'] = [(name, params) if params else name - for name, params in transforms] - if device is not None: - kwargs['device'] = device - kv_pairs = " ".join([ - f"{k}: {v}" for k, v in sorted(kwargs.items()) - ]) - - def pp_val(arg) -> pp.Doc: - if isinstance(arg, tuple): - return pp.group(pp.concat([ - pp.text("( "), - pp.nest(2, pp.join(pp.brk(), [pp_val(e) for e in arg])), - pp.text(" )") - ])) - elif isinstance(arg, list): - return pp.group(pp.concat([ - pp.text("[ "), - pp.nest(2, pp.join(pp.brk(), [pp_val(e) for e in arg])), - pp.text(" ]") - ])) - elif isinstance(arg, dict): - return pp.group(pp.concat([ - pp.text("{ "), - pp.nest(2, pp.join(pp.brk(), [ - pp.text(f"{k}=") + pp_val(v) for k, v in sorted(arg.items()) - ])), - pp.text(" }") - ])) - elif isinstance(arg, np.ndarray): - return pp.text(np.array2string(arg, threshold=threshold)) - else: - return pp.text(str(arg)) - - with _print_tap_lock: - if kv_pairs: - emit_str(kv_pairs) - emit_str(str(pp_val(arg))) - - -def _values_to_avals(vals) -> Sequence[core.ShapedArray]: - return tuple(core.raise_to_shaped(core.get_aval(v)) for v in vals) - -### The outside_call primitive -""" -This primitive is used to implement the `call` and `id_tap` functions. -It takes several positional arguments that are the flattened -according to `arg_treedef`. -The result of the primitive is computed based on the `identity` parameter, -as follows: - - * if `identity` is True, then the results are the same as the - positional arguments of the primitive (except perhaps the last couple of - arguments, see `has_token`). In this case, `result_treedef` and - `flat_results_aval` are ignored, and `args_treedef` describes the result also. - * if `identity` is False, then the results are those from - the call to the outside computation: - - flatten(callback(arg_treedef.unflatten(args), device=...)) - - In this case, the callback results must match `result_treedef` - and `flat_results_aval`. - -It takes the following parameters: - - * callback: the function to invoke with the unflattened arguments, - the device and the transforms: `callback(arrays, device, transforms)` - * arg_treedef: the treedef for the argument. - * identity: see description above. - * result_treedef, flat_results_aval: describes the expected result of the - callback. Only used when not `identity`. - * transforms: a tuple of the transformations that have been applied. Each - element of the tuple is itself a tuple with the first element the name - of the transform. The remaining elements depend on the transform. For - example, for `batch`, the parameters are the dimensions that have been - batched, and for `mask` the logical shapes. These are unpacked by - _outside_call_run_callback before passing to the user function. - * has_token: a boolean, when True it means that the last positional argument - is the current token. In this case, the result of the primitive is - going to be the non-token positional arguments, along with the updated - token. The tokens and this parameter are added after all the JAX - transformations, just before staging XLA. - * device_index: an integer, denotes from which device the invocation is from. - Works only when using the outfeed implementation mechanism, i.e., does - not work on CPU unless --jax_host_callback_outfeed=True. -""" -outside_call_p = core.Primitive("outside_call") -outside_call_p.multiple_results = True -core.outfeed_primitives.add(outside_call_p) - - -def _outside_call_abstract_eval(*args_a: pe.AbstractValue, - identity, **params) -> Sequence[pe.AbstractValue]: - if identity: - # Do some validation here - assert "result_treedef" not in params - assert "flat_results_aval" not in params - return args_a - assert params["device_index"] is not None - assert params["result_treedef"] is not None - assert params["flat_results_aval"] is not None - flat_results_aval = params["flat_results_aval"] - if "has_token" in params and params["has_token"]: - assert len(args_a) >= 2 - return flat_results_aval + args_a[-2:] - else: - return flat_results_aval - - -outside_call_p.def_abstract_eval(_outside_call_abstract_eval) - - -def _outside_call_impl(*args, **params): - assert "has_token" not in params - if _HOST_CALLBACK_INLINE.value: - device_index = params["device_index"] - device = xb.devices()[device_index] - results = _outside_call_run_callback(args, device, send_infeed=False, **params) - return results - else: - # We use the jitted-version of the primitive even for eager execution, both - # so that we do not duplicate logic, but also so that all outfeed is received - # by the outfeed_listeners, in the same thread from a given device. If we were - # to process the tap here, it would be coming from the main thread. Also, - # even in eager execution some primitives, such as while, are compiled. - # It would be confusing to process a sequence "id_tap; while" in two - # different threads. - return dispatch.apply_primitive(outside_call_p, *args, **params) - - -outside_call_p.def_impl(_outside_call_impl) - - -def _with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs): - """Builds op_fn(*args, **kwargs) with sharding annotation.""" - builder.set_sharding(sharding_proto) - try: - return op_fn(*args, **kwargs) - finally: - builder.clear_sharding() - -def _outside_call_translation_rule(ctx, - avals_in, - avals_out, - *args_op: XlaOp, - has_token, - identity, - device_index, - flat_results_aval=(), - **params): - # We expect the current tokens at the end, inserted by _rewrite_jaxpr. - assert has_token - use_outfeed = _use_outfeed(ctx.platform) - assert use_outfeed, 'Should be using MLIR path for `CustomCall` lowering' - current_token = args_op[-2] - current_itoken = args_op[-1] - comp = ctx.builder - assert comp.get_shape(current_token).is_token() and comp.get_shape(current_itoken).is_token(), ( - "The last two arguments must be tokens") - - args_to_outfeed = args_op[:-2] - # Some platforms refuse to infeed empty arrays. We generate constants - # instead. - non_empty_flat_results_aval = list(filter(lambda aval: not (_aval_is_empty(aval)), - flat_results_aval)) - need_callback_results_on_device = (not identity and - len(non_empty_flat_results_aval) > 0) - send_infeed = use_outfeed and need_callback_results_on_device - generated_infeed = False # Keep track if we emitted an infeed op - - _raise_if_using_outfeed_with_pjrt_c_api(xb.get_backend(ctx.platform)) - callback_id = _register_callback( - functools.partial( - _outside_call_run_callback, - send_infeed=send_infeed, - identity=identity, - flat_results_aval=flat_results_aval, - **params)) - next_token = _callback_handler_data.receiver.add_outfeed( - comp, current_token, callback_id, args_to_outfeed, device_index) - if identity: - results = list(args_to_outfeed) - next_itoken = current_itoken - else: - empty_results = [ - xops.ConstantLiteral(comp, np.zeros(aval.shape, aval.dtype)) - for aval in flat_results_aval - if _aval_is_empty(aval) - ] - if non_empty_flat_results_aval: - assert need_callback_results_on_device - after_outfeed_itoken = xops.AfterAll(comp, [current_itoken, next_token]) - # We shard the infeed as AssignedDevice(device_index). This must match the - # outfeed (from outfeed_receiver.cc). Since `lax.infeed` does not support - # this kind of sharding, we use a custom translation for infeed. - array_sharding_proto = xla_client.OpSharding() - array_sharding_proto.type = xla_client.OpSharding.Type.MAXIMAL - array_sharding_proto.tile_assignment_dimensions = [1] - array_sharding_proto.tile_assignment_devices = [device_index] - - token_sharding_proto = xla_client.OpSharding() - token_sharding_proto.type = xla_client.OpSharding.Type.REPLICATED - infeed_sharding_proto = xla.tuple_sharding_proto( - [array_sharding_proto] * len(non_empty_flat_results_aval) + - [token_sharding_proto]) - - shape = [ - shape.with_major_to_minor_layout_if_absent() - for x in non_empty_flat_results_aval - for shape in xla.aval_to_xla_shapes(x) - ] - - build_infeed = functools.partial(xops.InfeedWithToken, - after_outfeed_itoken, - xla_client.Shape.tuple_shape(shape)) - outs_and_token = _with_sharding_proto(comp, infeed_sharding_proto, - build_infeed) - outs = xops.GetTupleElement(outs_and_token, 0) - next_itoken = xops.GetTupleElement(outs_and_token, 1) - non_empty_results = [ - xops.GetTupleElement(outs, i) - for i in range(len(non_empty_flat_results_aval)) - ] - generated_infeed = True - results = [ - empty_results.pop(0) - if _aval_is_empty(result_aval) else non_empty_results.pop(0) - for result_aval in flat_results_aval - ] - else: - results = empty_results - next_itoken = current_itoken - - assert generated_infeed == send_infeed, ( - f"generated_infeed ({generated_infeed}) != send_infeed ({send_infeed})") - assert identity or len(results) == len(flat_results_aval), ( - f"got {len(results)} but expected {len(flat_results_aval)}. " - f"identity = {identity}") - return results + [next_token, next_itoken] - -if xla_extension_version < 287: - xla.register_translation(outside_call_p, _outside_call_translation_rule) - - -def _outside_call_outfeed_lowering(ctx: mlir.LoweringRuleContext, - *args_op, - identity, - device_index, - flat_results_aval=(), - **params): - # We expect the current tokens at the end, inserted by _rewrite_jaxpr. - current_token = args_op[-2] - current_itoken = args_op[-1] - - args_to_outfeed = args_op[:-2] - # Some platforms refuse to infeed empty arrays. We generate constants - # instead. - non_empty_flat_results_aval = list(filter(lambda aval: not (_aval_is_empty(aval)), - flat_results_aval)) - need_callback_results_on_device = (not identity and - len(non_empty_flat_results_aval) > 0) - send_infeed = need_callback_results_on_device - generated_infeed = False # Keep track if we emitted an infeed op - for platform in ctx.module_context.platforms: - _raise_if_using_outfeed_with_pjrt_c_api( - xb.get_backend(platform) - ) - callback_id = _register_callback( - functools.partial( - _outside_call_run_callback, - send_infeed=send_infeed, - identity=identity, - flat_results_aval=flat_results_aval, - **params)) - - outfeed_sharding = xla_client.OpSharding() - outfeed_sharding.type = xla_client.OpSharding.Type.MAXIMAL - outfeed_sharding.tile_assignment_dimensions = [1] - outfeed_sharding.tile_assignment_devices = [device_index] - - # next_token = _callback_handler_data.receiver.add_outfeed( - # comp, current_token, callback_id, args_to_outfeed, device_index) - - xla_shapes = util.flatten( - xla.aval_to_xla_shapes(aval) for aval in ctx.avals_in[:-2]) - _callback_handler_data.receiver.register_outfeed(callback_id, xla_shapes) - outfeed_header_start = 271828 # Must match kOutfeedHeaderStart in C++ - header = mlir.ir_constant(np.array([outfeed_header_start, callback_id], - dtype=np.uint32)) - header_outfeed = hlo.OutfeedOp([header], current_token, - outfeed_config=ir.StringAttr.get('')) - mlir.set_sharding(header_outfeed, outfeed_sharding) - next_token, = header_outfeed.results - data_outfeed = hlo.OutfeedOp(args_to_outfeed, next_token, - outfeed_config=ir.StringAttr.get('')) - mlir.set_sharding(data_outfeed, outfeed_sharding) - next_token, = data_outfeed.results - - - if identity: - results = list(args_to_outfeed) - next_itoken = current_itoken - else: - empty_results = [ - mlir.ir_constant(np.zeros(aval.shape, aval.dtype)) - for aval in flat_results_aval - if _aval_is_empty(aval) - ] - if non_empty_flat_results_aval: - assert need_callback_results_on_device - after_outfeed_itoken = hlo.AfterAllOp([current_itoken, next_token]) - # We shard the infeed as AssignedDevice(device_index). This must match the - # outfeed (from outfeed_receiver.cc). Since `lax.infeed` does not support - # this kind of sharding, we use a custom translation for infeed. - array_sharding_proto = xla_client.OpSharding() - array_sharding_proto.type = xla_client.OpSharding.Type.MAXIMAL - array_sharding_proto.tile_assignment_dimensions = [1] - array_sharding_proto.tile_assignment_devices = [device_index] - - token_sharding_proto = xla_client.OpSharding() - token_sharding_proto.type = xla_client.OpSharding.Type.REPLICATED - infeed_sharding_proto = xla.tuple_sharding_proto( - [array_sharding_proto] * len(non_empty_flat_results_aval) + - [token_sharding_proto]) - - output_types = map(mlir.aval_to_ir_types, non_empty_flat_results_aval) - flat_output_types = util.flatten(output_types) - - layouts = ir.ArrayAttr.get([ - ir.ArrayAttr.get( - [mlir.i64_attr(i) - for i in range(len(aval.shape) - 1, -1, -1)]) - for aval in non_empty_flat_results_aval - ]) - infeed = hlo.InfeedOp(flat_output_types + [hlo.TokenType.get()], - after_outfeed_itoken, - infeed_config=ir.StringAttr.get(''), - layout=layouts) - mlir.set_sharding(infeed, infeed_sharding_proto) - non_empty_results = list(infeed.results[:-1]) - next_itoken = infeed.results[-1] - generated_infeed = True - results = [ - empty_results.pop(0) - if _aval_is_empty(result_aval) else non_empty_results.pop(0) - for result_aval in flat_results_aval - ] - else: - results = empty_results - next_itoken = current_itoken - - assert generated_infeed == send_infeed, ( - f"generated_infeed ({generated_infeed}) != send_infeed ({send_infeed})") - assert identity or len(results) == len(flat_results_aval), ( - f"got {len(results)} but expected {len(flat_results_aval)}. " - f"identity = {identity}") - return results + [next_token, next_itoken] - - -def _outside_call_lowering(ctx: mlir.LoweringRuleContext, - *args, - has_token: bool, - identity: bool, - device_index: int, - flat_results_aval=(), - **params): - """MLIR Lowering for `CustomCall`-based HCB.""" - if len(ctx.module_context.platforms) > 1: - raise NotImplementedError("multi-platform lowering for host_callback") - platform = ctx.module_context.platforms[0] - use_outfeed = _use_outfeed(platform) - if use_outfeed: - if xla_extension_version < 287: - return mlir.xla_fallback_lowering(outside_call_p)( - ctx, - *args, - has_token=has_token, - identity=identity, - device_index=device_index, - flat_results_aval=flat_results_aval, - **params, - ) - else: - return _outside_call_outfeed_lowering( - ctx, *args, - has_token=has_token, - identity=identity, - flat_results_aval=flat_results_aval, - device_index=device_index, - **params, - ) - else: - # TODO(necula): It seems that on CPU, with custom call, the device_index - # does not work, and the callback is always run on device_index=0 - if (device_index != 0 and "cpu" in ctx.module_context.platforms): - raise ValueError( - "The device_index feature on CPU works only when using outfeed.") - - # We expect the current tokens at the end, inserted by _rewrite_jaxpr. - assert has_token - current_token = args[-2] - current_itoken = args[-1] - assert current_token.type == hlo.TokenType.get(), "The last two arguments must be tokens" - assert current_itoken.type == hlo.TokenType.get(), "The last two arguments must be tokens" - - args_to_outfeed = args[:-2] - # TODO(necula): this is a weak attempt to get the device. This works - # inside pmap, but does not work when we just execute on a single device, - # because in such executions we always get replica_id == 0. - replica_id = hlo.ReplicaIdOp() - callback_operands = [replica_id, *args_to_outfeed] - callback_operand_avals = [ - core.ShapedArray((), np.uint32), *ctx.avals_in[:-2]] - if identity: - callback_flat_results_aval = [] - else: - callback_flat_results_aval = [*flat_results_aval] - - def wrapped_callback(*args): - replica_id, *arrays = args - result_arrays = _outside_call_run_callback( - arrays, - xb.local_devices()[replica_id], - send_infeed=False, - # The same parameters as outside_call_p - identity=identity, - flat_results_aval=flat_results_aval, - **params) - if identity: - # For identity, we do not pass the any results back to the device - result_arrays = () - return result_arrays - - if isinstance( - ctx.module_context.axis_context, - (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), - ): - # Apply maximal sharding so pjit only executes the callback on device device_index. - sharding = xla_client.OpSharding() - sharding.type = xla_client.OpSharding.Type.MAXIMAL - sharding.tile_assignment_dimensions = [1] - sharding.tile_assignment_devices = [device_index] - else: - sharding = None - results, next_token, keep_alive = mlir.emit_python_callback(ctx, - wrapped_callback, current_token, callback_operands, - callback_operand_avals, callback_flat_results_aval, # type: ignore[arg-type] - has_side_effect=True, sharding=sharding) - _callback_handler_data.keep_alives.append(keep_alive) - # We must put the two tokens at the end - if identity: - results = list(args_to_outfeed) - next_itoken = current_itoken - - assert identity or len(results) == len(flat_results_aval), ( - f"got {len(results)} but expected {len(flat_results_aval)}. " - f"identity = {identity}") - return list(results) + [next_token, next_itoken] - -if xla_extension_version < 287: - mlir.register_lowering(outside_call_p, _outside_call_lowering, platform="cpu") -else: - mlir.register_lowering(outside_call_p, _outside_call_lowering) - -def _outside_call_run_callback( - arrays, device, *, - send_infeed=True, - # The same parameters as outside_call_p - callback, arg_treedef, - identity, result_treedef=None, flat_results_aval=None, - transforms=(), has_token=False): - """Performs the callback: - callback(arg, device, transforms) - - Called during the device computation once we have the argument, either from - an inlined callback or from an XLA computation outfeed. - - Returns the flat list of result arrays. If `send_infeed` then it will also send - the flat list of results to the device. - """ - - def _unpack_transforms(transforms) -> tuple[tuple[str, dict[str, Any]], ...]: - def _unpack_transform(name, *params): - if name == "batch": - return name, dict(batch_dims=params[0]) - elif name == "mask": - return name, dict(logical_shapes=5) - else: - assert not params, f"{name}, {params}" - return name, {} - - return tuple(_unpack_transform(*t) for t in transforms) - - try: - arg = api.tree_unflatten(arg_treedef, arrays) - unpacked_transforms = _unpack_transforms(transforms) - logger.debug( - "Outside call invoking call_func %s, device=%s, transforms=%s", - callback, device, unpacked_transforms - ) - res = callback(arg, device, unpacked_transforms) - if identity: - return tuple(arrays) - - else: # Check the type of the callback results - assert result_treedef is not None - assert flat_results_aval is not None - actual_flat_results, actual_result_treedef = tree_util.tree_flatten(res) - if actual_result_treedef != result_treedef: - msg = (f"Callback func {callback} should have returned a result " - f"with pytree {result_treedef} but returned " - f"{actual_result_treedef}") - raise TypeError(msg) - - canonical_flat_results = tuple(util.safe_map(xla.canonicalize_dtype, actual_flat_results)) - actual_flat_results_aval = _values_to_avals(canonical_flat_results) - logger.debug( - "Outside call %s result %s. Sending to infeed for device %s.", - callback, flat_results_aval, device, - ) - - if not all(ea.strip_weak_type() == ra.strip_weak_type() - for ea, ra in util.safe_zip(flat_results_aval, - actual_flat_results_aval)): - msg = (f"Callback func {callback} should have returned a result " - "with abstract values " - f"{result_treedef.unflatten(flat_results_aval)} " - f"but returned {actual_result_treedef.unflatten(actual_flat_results_aval)}") - raise TypeError(msg) - - if send_infeed: - # Do not send the 0-sized arrays - non_empty_canonical_flat_results = tuple(filter(lambda r: not _aval_is_empty(r), - canonical_flat_results)) - device.transfer_to_infeed(non_empty_canonical_flat_results) - return canonical_flat_results - - except Exception as e: - logger.error("Outside call %s threw exception %s.", callback, e) - if send_infeed: - # Prepare some results to send in case of error. We are sending something - # with a distinctive shape (int8[12345]), one that is unlikely to be what the device - # expects. This should have the effect to abort the device computation, - # with an error message that we recognize. On TPU there seem to be no - # such check, and if we send anything at all the device computation will - # use some garbage data. So, on TPU we prefer to not send anything and let - # the computation hang. - # TODO: implement a proper error handling for TPU - if device.platform != "tpu": - canonical_flat_results = [xla.canonicalize_dtype(np.arange(12345, dtype=np.int8))] - logger.debug("Outside call consumer %s exception %s. Sending to infeed the error result.", - callback, e) - device.transfer_to_infeed(tuple(canonical_flat_results)) - else: - logger.debug("Outside call consumer %s exception %s. On TPU we do not send infeed.", - callback, e) - raise e # Let the exception propagate - - -def _add_transform(params: dict, name: str, *transform_params) -> dict: - """Adds the `transform` to the params["transforms"]. - - Uses a tuple representation internally, will be unpacked before the - callback by _ConsumerCallable. - """ - new_transform = (name, *transform_params) - return dict( - params, transforms=(params.get("transforms", ()) + (new_transform,))) - - -def _aval_is_empty(aval) -> bool: - return math.prod(aval.shape) == 0 - -def _instantiate_zeros(tan, arg): - del arg - return ad.instantiate_zeros(tan) - -def _outside_call_jvp_rule(primals, tangents, **params): - assert "has_token" not in params - if not params["identity"]: - raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.") - out_primals_tapped = outside_call_p.bind(*primals, **params) - return tuple(out_primals_tapped), tangents - - -ad.primitive_jvps[outside_call_p] = _outside_call_jvp_rule - -def _outside_call_transpose_rule(cts, *args, **params): - if not params["identity"]: - raise NotImplementedError("differentiation rules are implemented only for id_tap, not for call.") - assert "has_token" not in params - assert len(cts) == len(args) - cts_instantiated = tuple(map(_instantiate_zeros, cts, args)) - - # The args have been prepared by the id_tap_jvp_rule: tapped_primals, tapped_tangents, rest_primals, rest_tangents - transforms = params.get("transforms", ()) - if not transforms or transforms[-1] != ("jvp",): - # TODO: I should understand better when can this happen. It seems to arise - # in scan. - return outside_call_p.bind( - *cts_instantiated, - **_add_transform(params, "transpose")) - - assert False - - -ad.primitive_transposes[outside_call_p] = _outside_call_transpose_rule - - -def _outside_call_batching_rule(batched_args, batch_dims, **params): - if not params["identity"]: - raise NotImplementedError("batching rules are implemented only for id_tap, not for call.") - assert "has_token" not in params - new_params = _add_transform(params, "batch", batch_dims) - res = outside_call_p.bind(*batched_args, **new_params) - return res, batch_dims - - -batching.primitive_batchers[outside_call_p] = _outside_call_batching_rule - -#### -#### Jaxpr rewriting logic to thread the tokens through stateful primitives. -#### - - -def _rewrite_closed_jaxpr(cjaxpr: core.ClosedJaxpr, has_input_token: bool, - has_output_token: bool) -> core.ClosedJaxpr: - """Rewrites a ClosedJaxpr to thread the token, if needed.""" - new_jaxpr = _rewrite_jaxpr(cjaxpr.jaxpr, has_input_token, has_output_token) - return core.ClosedJaxpr(new_jaxpr, cjaxpr.consts) - - -def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool, - has_output_token: bool) -> core.Jaxpr: - """Rewrite a Jaxpr to thread the token, if needed.""" - assert has_input_token or not has_output_token - - if not has_input_token and not core.jaxpr_uses_outfeed(jaxpr): - return jaxpr - - mk_new_var = core.gensym() - - eqns: list[core.JaxprEqn] = [] - # store the incoming tokens - last_token_var = mk_new_var(core.abstract_token) - last_itoken_var = mk_new_var(core.abstract_token) - if has_input_token: - invars = jaxpr.invars + [last_token_var, last_itoken_var] - else: - invars = jaxpr.invars - # We need tokens but none is given in input; make one depending on all invars - eqns.append( - core.new_jaxpr_eqn(jaxpr.invars, [last_token_var], - lax.create_token_p, {}, core.no_effects, source_info_util.current())) - eqns.append( - core.new_jaxpr_eqn(jaxpr.invars, [last_itoken_var], - lax.create_token_p, {}, core.no_effects, source_info_util.current())) - - for eqn in jaxpr.eqns: - if not core.primitive_uses_outfeed(eqn.primitive, eqn.params): - eqns.append(eqn) - else: - output_token_var = mk_new_var(last_token_var.aval) - output_itoken_var = mk_new_var(last_itoken_var.aval) - _rewrite_eqn(eqn, eqns, last_token_var, output_token_var, - last_itoken_var, output_itoken_var, mk_new_var) - last_token_var = output_token_var - last_itoken_var = output_itoken_var - - outvars = jaxpr.outvars + ([last_token_var, last_itoken_var] if has_output_token else []) - new_jaxpr = core.Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr.effects) - return new_jaxpr - - -def _rewrite_eqn(eqn: core.JaxprEqn, eqns: list[core.JaxprEqn], - input_token_var: core.Var, output_token_var: core.Var, - input_itoken_var: core.Var, output_itoken_var: core.Var, - mk_new_var: Callable[[core.AbstractValue], core.Var]): - """Rewrite an `eqn` and append equations to `eqns`. - - This is only called if the current primitive uses outfeed. - Assume that the current token is in `input_token_var` and the resulting - token must end in `output_token_var`. - - Append the result of rewriting to `eqns`. - """ - if eqn.primitive is outside_call_p: - assert "has_token" not in eqn.params - eqns.append(eqn.replace(invars=eqn.invars + [input_token_var, input_itoken_var], - outvars=eqn.outvars + [output_token_var, output_itoken_var], - params=dict(eqn.params, has_token=True))) - elif eqn.primitive is lax.while_p: - cond_jaxpr, _, body_jaxpr, _ = util.split_dict( - eqn.params, - ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"]) - if core.jaxpr_uses_outfeed(cond_jaxpr.jaxpr): - _rewrite_while_outfeed_cond(eqn, eqns, input_token_var, output_token_var, - input_itoken_var, output_itoken_var, - mk_new_var) - return - - eqns.append( - eqn.replace( - invars=eqn.invars + [input_token_var, input_itoken_var], - outvars=eqn.outvars + [output_token_var, output_itoken_var], - params=dict( - eqn.params, - body_jaxpr=_rewrite_closed_jaxpr(body_jaxpr, True, True), - cond_jaxpr=_rewrite_closed_jaxpr(cond_jaxpr, True, False)))) - elif eqn.primitive is lax.cond_p: - branches, = util.split_dict(eqn.params, ["branches"]) - index, *operands = eqn.invars - new_invars = [index, *operands, input_token_var, input_itoken_var] - eqns.append( - eqn.replace( - invars=new_invars, outvars=eqn.outvars + [output_token_var, output_itoken_var], - params=dict( - eqn.params, - branches=tuple( - _rewrite_closed_jaxpr(jaxpr, True, True) - for jaxpr in branches)))) - elif eqn.primitive is lax.scan_p: - num_consts, num_carry, carry_jaxpr, linear, _, _, _, _ = util.split_dict( - eqn.params, - ["num_consts", "num_carry", "jaxpr", "linear", "reverse", "length", - "unroll", "_split_transpose"]) - # We add the tokens right at the end of carry - nr_const_and_carry = num_consts + num_carry - new_invars = eqn.invars[0:nr_const_and_carry] + [ - input_token_var, input_itoken_var] + eqn.invars[nr_const_and_carry:] - new_jaxpr = _rewrite_closed_jaxpr(carry_jaxpr, True, True) - # The rewrite has put the token at end, it has to be at end of carry - new_jaxpr_invars = new_jaxpr.jaxpr.invars - new_jaxpr_invars = ( - new_jaxpr_invars[0:nr_const_and_carry] + new_jaxpr_invars[-2:] + - new_jaxpr_invars[nr_const_and_carry:-2]) - new_jaxpr = new_jaxpr.replace(jaxpr=new_jaxpr.jaxpr.replace(invars=new_jaxpr_invars)) - - new_jaxpr_outvars = new_jaxpr.jaxpr.outvars - new_jaxpr_outvars = ( - new_jaxpr_outvars[0:num_carry] + new_jaxpr_outvars[-2:] + - new_jaxpr_outvars[num_carry:-2]) - new_jaxpr = new_jaxpr.replace(jaxpr=new_jaxpr.jaxpr.replace(outvars=new_jaxpr_outvars)) - eqns.append( - eqn.replace( - invars=new_invars, - # Output token is at the end of carry result - outvars=(eqn.outvars[0:num_carry] + [output_token_var, output_itoken_var] + - eqn.outvars[num_carry:]), - params=dict( - eqn.params, - jaxpr=new_jaxpr, - num_carry=num_carry + 2, - linear=linear[0:nr_const_and_carry] + (False, False) + linear[nr_const_and_carry:]))) - elif eqn.primitive is pxla.xla_pmap_p: - # We broadcast the input token into an array of tokens - call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"]) - eqns.append( - eqn.replace( - invars=eqn.invars + [input_token_var, input_itoken_var], - outvars=eqn.outvars + [output_token_var, output_itoken_var], - params=dict( - eqn.params, - call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True), - donated_invars=eqn.params["donated_invars"] + (False, False), - # Sharding/unsharding of tokens in pmap_translation are special - # cased to just pass-through the token - in_axes=eqn.params["in_axes"] + (None, None), - out_axes=eqn.params["out_axes"] + (0, 0)))) - elif eqn.primitive is custom_derivatives.custom_jvp_call_p: - fun_jaxpr = eqn.params["call_jaxpr"] - - def unreachable_thunk(): - assert False, "Should not be reached" - unreachable_thunk.reset_stores = lambda: None - - eqns.append( - eqn.replace( - invars=eqn.invars + [input_token_var, input_itoken_var], - outvars=eqn.outvars + [output_token_var, output_itoken_var], - params=dict( - eqn.params, - call_jaxpr=_rewrite_closed_jaxpr(fun_jaxpr, True, True), - jvp_jaxpr_thunk=unreachable_thunk - ))) - elif eqn.primitive is custom_derivatives.custom_vjp_call_jaxpr_p: - fun_jaxpr = eqn.params["fun_jaxpr"] - new_invars = [*eqn.invars, input_token_var, input_itoken_var] - - def unreachable_thunk(): - assert False, "Should not be reached" - - eqns.append( - eqn.replace( - invars=new_invars, - outvars=eqn.outvars + [output_token_var, output_itoken_var], - params=dict( - eqn.params, - fun_jaxpr=_rewrite_closed_jaxpr(fun_jaxpr, True, True), - fwd_jaxpr_thunk=unreachable_thunk, - # The following are illegal values for the parameters, they - # should not be needed because this rewrite is just before - # compilation to XLA, which does not use those parameters. - bwd="illegal param", - out_trees="illegal param"))) - elif eqn.primitive is pjit.pjit_p: - jaxpr = cast(core.ClosedJaxpr, eqn.params["jaxpr"]) - eqns.append( - eqn.replace( - invars=eqn.invars + [input_token_var, input_itoken_var], - outvars=eqn.outvars + [output_token_var, output_itoken_var], - params=dict( - eqn.params, - jaxpr=_rewrite_closed_jaxpr(jaxpr, True, True), - donated_invars=eqn.params["donated_invars"] + (False, False), - in_shardings=( - eqn.params["in_shardings"] - + (sharding_impls.UNSPECIFIED, sharding_impls.UNSPECIFIED) - ), - out_shardings=( - eqn.params["out_shardings"] - + (sharding_impls.UNSPECIFIED, sharding_impls.UNSPECIFIED) - ), - in_layouts=(eqn.params["in_layouts"] + (None, None)), - out_layouts=(eqn.params["out_layouts"] + (None, None)), - ), - ) - ) - elif eqn.primitive is ad_checkpoint.remat_p: - jaxpr_ = cast(core.Jaxpr, eqn.params["jaxpr"]) - eqns.append( - eqn.replace( - invars=eqn.invars + [input_token_var, input_itoken_var], - outvars=eqn.outvars + [output_token_var, output_itoken_var], - params=dict( - eqn.params, - jaxpr=_rewrite_jaxpr(jaxpr_, True, True), - ))) - else: - raise NotImplementedError(f"outfeed rewrite {eqn.primitive}") - - -def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: list[core.JaxprEqn], - input_token_var: core.Var, - output_token_var: core.Var, - input_itoken_var: core.Var, - output_itoken_var: core.Var, - mk_new_var: Callable): - """Rewrite a while whose cond has outfeed""" - cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts = util.split_dict( - eqn.params, ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"]) - transformed_cond_jaxpr = _rewrite_closed_jaxpr(cond_jaxpr, True, True) - carry_invars = eqn.invars[cond_nconsts + body_nconsts:] - # pred1, token1, itoken1 = rewrite(COND)(cond_consts, carry_invars, input_token, input_itoken) - pred1_and_token1 = [ - mk_new_var(ov.aval) for ov in transformed_cond_jaxpr.jaxpr.outvars - ] - eqns.append( - core.new_jaxpr_eqn( - eqn.invars[0:cond_nconsts] + carry_invars + [input_token_var, input_itoken_var], - pred1_and_token1, core.call_p, - dict( - call_jaxpr=transformed_cond_jaxpr.jaxpr, - name="cond_before"), - transformed_cond_jaxpr.jaxpr.effects, - eqn.source_info)) - # Make a new cond "lambda pred, carry, token, itoken: pred" - new_cond_pred_invar = mk_new_var(cond_jaxpr.out_avals[0]) - new_cond_invars = ( - [new_cond_pred_invar] + [mk_new_var(cv.aval) for cv in carry_invars] + - [mk_new_var(input_token_var.aval), - mk_new_var(input_itoken_var.aval)]) - new_cond_jaxpr = core.ClosedJaxpr( - core.Jaxpr([], new_cond_invars, [new_cond_pred_invar], [], set()), []) - # Make a new body: - # "lambda cond_constvars, body_constvars, pred, carry, token, itoken: - # carry2, token2, itoken2 = rewrite(BODY)(body_constvars, carry, token, itoken) - # pred2, token3, itoken3 = rewrite(COND)(cond_constvars, carry2, token2, itoken2) - # (pred2, carry2, token3, itoken3) - transformed_body_jaxpr = _rewrite_closed_jaxpr(body_jaxpr, True, True) - new_body_invars_cond_constvars = [ - mk_new_var(v.aval) for v in eqn.invars[0:cond_nconsts] - ] - new_body_invars_body_constvars = [ - mk_new_var(v.aval) - for v in eqn.invars[cond_nconsts:cond_nconsts + body_nconsts] - ] - new_body_invars_pred = mk_new_var(cond_jaxpr.out_avals[0]) - new_body_invars_carry = [mk_new_var(cv.aval) for cv in carry_invars] - new_body_invars_token = mk_new_var(input_token_var.aval) - new_body_invars_itoken = mk_new_var(input_itoken_var.aval) - - new_body_carry2 = [mk_new_var(cv.aval) for cv in carry_invars] - new_body_token2 = mk_new_var(input_token_var.aval) - new_body_itoken2 = mk_new_var(input_itoken_var.aval) - new_body_pred2 = mk_new_var(cond_jaxpr.out_avals[0]) - new_body_token3 = mk_new_var(input_token_var.aval) - new_body_itoken3 = mk_new_var(input_itoken_var.aval) - - new_body_eqns = [ - core.new_jaxpr_eqn( - new_body_invars_body_constvars + new_body_invars_carry + - [new_body_invars_token, new_body_invars_itoken], - new_body_carry2 + [new_body_token2, new_body_itoken2], - core.call_p, - dict( - call_jaxpr=transformed_body_jaxpr.jaxpr, - name="body"), - transformed_body_jaxpr.effects, - eqn.source_info), - core.new_jaxpr_eqn( - new_body_invars_cond_constvars + new_body_carry2 + [new_body_token2, new_body_itoken2], - [new_body_pred2, new_body_token3, new_body_itoken3], core.call_p, - dict( - call_jaxpr=transformed_cond_jaxpr.jaxpr, - name="cond_body"), - transformed_cond_jaxpr.effects, - eqn.source_info) - ] - effects = core.join_effects(*(eqn.effects for eqn in new_body_eqns)) - new_body_jaxpr = core.ClosedJaxpr( - core.Jaxpr([], (new_body_invars_cond_constvars + - new_body_invars_body_constvars + [new_body_invars_pred] + - new_body_invars_carry + [new_body_invars_token, new_body_invars_itoken]), - ([new_body_pred2] + new_body_carry2 + [new_body_token3, new_body_itoken3]), - new_body_eqns, effects), []) - - pred_out = mk_new_var(cond_jaxpr.out_avals[0]) - eqns.append( - core.new_jaxpr_eqn( - (eqn.invars[0:cond_nconsts + body_nconsts] + [pred1_and_token1[0]] + - carry_invars + pred1_and_token1[1:]), - ([pred_out] + eqn.outvars + [output_token_var, output_itoken_var]), - lax.while_p, - dict( - cond_jaxpr=new_cond_jaxpr, - cond_nconsts=0, - body_jaxpr=new_body_jaxpr, - body_nconsts=cond_nconsts + body_nconsts), - new_body_jaxpr.effects, - eqn.source_info)) - - -# We need an identity primitive to simplify rewriting -id_p = core.Primitive("id") -id_p.multiple_results = True -id_p.def_impl(lambda *args: args) -id_p.def_abstract_eval(lambda *args: args) -mlir.register_lowering(id_p, lambda ctx, *args: args) - -dispatch.outfeed_rewriter = lambda j: _rewrite_jaxpr(j, False, False) - - -class CallbackException(Exception): - """Signals that some callback function had exceptions. - - Raised by :func:`barrier_wait`. See the :mod:`jax.experimental.host_callback` - module documentation for details. - """ - pass - -TapFunctionException = CallbackException # For backwards compatibility - -class _CallbackHandlerData: - """Keep track of the outfeed receiver data.""" - receiver: Any - initialized: bool - on_exit: bool - lock: threading.Lock - last_callback_exception: tuple[Exception, str] | None - clients: tuple[XlaLocalClient, ...] - devices: tuple[XlaDevice, ...] - consumer_registry: dict[Callable, int] - consumer_registry_by_id: dict[int, Callable] - - def __init__(self): - self.receiver = None # Initialize lazily, when first needed - self.initialized = False - self.on_exit = False - self.lock = threading.Lock() - self.last_callback_exception = None - self.clients = () - self.devices = () - # The consumer registries must be live for the lifetime of the program, - # because we may have cached compilations that embed consumer ids, and we - # do not want the id reused for other shapes. - # Used only for the outfeed mechanism. - self.callback_registry = {} - self.callback_registry_by_id = {} - # For now we keep here the keep_alives for the emit_python_callback. This is - # a leak. We ought to attach these to the executable. - self.keep_alives = [] - - def stop(self): - """Wait for all pending outfeeds and stop the receiver.""" - self.receiver = None # GC will trigger the destructor - self.initialized = False - self.clients = () - self.devices = () - # Do not clear the consumer registries. - - -_callback_handler_data = _CallbackHandlerData() - - -# This function is called from C++; it must not allow exceptions through. -def _callback_input_received(device, consumer_id, arrays: tuple): - array_repr = ", ".join([f"({a.dtype}{a.shape})" for a in arrays]) - logger.debug("Callback input received on device %s for consumer %s arrays: %s", - device, consumer_id, array_repr) - callback = _callback_handler_data.callback_registry_by_id.get(consumer_id) - assert callback is not None, "We should have crashed in the runtime" - try: - return callback(arrays, device) - except Exception as e: - formatted_e = traceback.format_exc() - logger.error("Postponing exception raised in callback function: %s", formatted_e) - _callback_handler_data.last_callback_exception = (e, formatted_e) - - -def _register_callback(callback: Callable) -> int: - """Registers a callback function, cache by hash of callback. - - The callback is a function to be invoked as `callback(arrays, device)`. - """ - callback_id = _callback_handler_data.callback_registry.get(callback) - if callback_id is not None: - return callback_id - callback_id = hash(callback) & 0xFFFFFFFC # pybind11 has trouble here with large ints - callback_id += 1 # Reserve the consumer ID 0 - assert callback_id not in _callback_handler_data.callback_registry, ( - "callback id collision") - _callback_handler_data.callback_registry[callback] = callback_id - _callback_handler_data.callback_registry_by_id[callback_id] = callback - return callback_id - - -def _initialize_outfeed_receiver( - max_callback_queue_size_bytes: int = int(256 * 1e6)): - """Creates and starts the outfeed_receiver. - - This function is called lazily only when we compile an id_tap. - - Args: - * clients: the list of clients (backends) on whose devices to listen on. - * max_callback_queue_size_bytes: an optional integer to bound the maximum - size of arrays in the callback queue. When this limit is reached the - device listener pauses. - """ - outfeed_receiver_module = xla_extension.outfeed_receiver - - with _callback_handler_data.lock: - if _callback_handler_data.initialized: - return - - # By default, all devices on all supported backends. - clients = [backend for name, backend in xb.backends().items() - if name in ("cpu", "cuda", "rocm", "tpu")] - devices = list( - itertools.chain(*[backend.local_devices() for backend in clients])) - _callback_handler_data.clients = clients # type: ignore[assignment] - _callback_handler_data.devices = devices # type: ignore[assignment] - clients_with_outfeed = [c for c in clients if _use_outfeed(c.platform)] - for client in clients_with_outfeed: - _raise_if_using_outfeed_with_pjrt_c_api(client) - if clients_with_outfeed: - devices_with_outfeed = list( - itertools.chain(*[backend.local_devices() for backend in clients_with_outfeed])) - if logger.isEnabledFor(logging.DEBUG): - device_repr = ", ".join([str(d) for d in devices_with_outfeed]) - logger.debug("Starting outfeed_receiver for %s. max_callback_queue_size_bytes=%s", - device_repr, max_callback_queue_size_bytes) - _callback_handler_data.receiver = outfeed_receiver_module.start( - _callback_input_received, tuple(clients_with_outfeed), - max_callback_queue_size_bytes, - compiler.get_compile_options(1, 1).executable_build_options) - - def exit_handler(): - # Prevent logging usage during compilation, gives errors under pytest - dispatch._on_exit = True - if not _callback_handler_data.on_exit: - _callback_handler_data.on_exit = True - _deprecated_barrier_wait("at_exit") - - atexit.register(exit_handler) # We wait as long as we have callbacks - _callback_handler_data.initialized = True - - -def _deprecated_barrier_wait(logging_name: str | None = None): - """Blocks the calling thread until all current outfeed is processed. - - Waits until all callbacks from computations already running on all devices - have been received and processed by the Python callbacks. Raises - CallbackException if there were exceptions while processing the callbacks. - - This works by enqueueing a special tap computation to all devices to which - we are listening for outfeed. Once all those tap computations are done, we - return from barrier_wait. - - Note: If any of the devices are busy and cannot accept new computations, - this will deadlock. - - Args: - logging_name: an optional string that will be used in the logging statements - for this invocation. See `Debugging` in the module documentation. - - For more details see the :mod:`jax.experimental.host_callback` module documentation. - """ - if not _HOST_CALLBACK_LEGACY.value: - jax.effects_barrier() - return - - logging_name = logging_name or "" - logger.debug("barrier_wait[%s]: start", logging_name) - - lock = threading.Lock() - cv = threading.Condition(lock=lock) - devices_at_barrier = [] # Protected by lock - def barrier_tap_received(dev_idx, _): - device = _callback_handler_data.devices[dev_idx] - logger.debug( - "barrier_wait[%s]: at barrier_tap for device %s. Thread %s", - logging_name, device, threading.current_thread() - ) - with lock: - devices_at_barrier.append(device) - if logger.isEnabledFor(logging.DEBUG): - waiting_for_devices = [d for d in _callback_handler_data.devices - if d not in devices_at_barrier] - logger.debug( - "barrier_wait[%s]: still waiting for %s devices at barrier (%s)", - logging_name, len(waiting_for_devices), waiting_for_devices - ) - cv.notify() - - for d_idx, d in enumerate(_callback_handler_data.devices): - logger.debug("barrier_wait[%s]: enqueueing barrier on device %s", logging_name, d) - x_on_dev = api.device_put(d_idx, device=d) - api.jit(lambda x: _deprecated_id_tap(barrier_tap_received, x), device=d)(x_on_dev) - - logger.debug("barrier_wait[%s]: waiting for callbacks", logging_name) - - with lock: - cv.wait_for(lambda: len(devices_at_barrier) == len(_callback_handler_data.devices)) - - logger.debug("barrier_wait[%s]: done", logging_name) - - if _callback_handler_data.last_callback_exception is not None: - last_exception, formatted_last_exception = _callback_handler_data.last_callback_exception - _callback_handler_data.last_callback_exception = None - raise CallbackException( - "There were exceptions during callback processing. " - f"Last one was: {formatted_last_exception}") from last_exception - - -def _deprecated_stop_outfeed_receiver(): - """Stops the outfeed receiver runtime. - - .. warning:: - The host_callback APIs are deprecated as of March 20, 2024. - The functionality is subsumed by the - `new JAX external callbacks `_ - - This waits for all outfeeds from computations already running on all devices, - and then stops the outfeed receiver runtime. The runtime will be restarted - next time you use a tap function. - - It should not be necessary to use this function, unless you want to start - using lax.outfeed directly after having used host callbacks. - """ - _callback_handler_data.stop() - -_deprecation_msg = ( - "The host_callback APIs are deprecated as of March 20, 2024. The functionality " - "is subsumed by the new JAX external callbacks. " - "See https://github.com/jax-ml/jax/issues/20385.") +def call(*_, **__): + raise NotImplementedError( + "jax.experimental.host_callback has been deprecated since March 2024 and " + "is now no longer supported. " + "See https://github.com/jax-ml/jax/issues/20385" + ) -_deprecations = { - # Added March 20, 2024 - "id_tap": (_deprecation_msg, _deprecated_id_tap), - "id_print": (_deprecation_msg, _deprecated_id_print), - "call": (_deprecation_msg, _deprecated_call), - "barrier_wait": (_deprecation_msg, _deprecated_barrier_wait), - "stop_outfeed_receiver": (_deprecation_msg, _deprecated_stop_outfeed_receiver), -} -import typing -if typing.TYPE_CHECKING: - id_tap = _deprecated_id_tap - id_print = _deprecated_id_print - call = _deprecated_call - barrier_wait = _deprecated_barrier_wait - stop_outfeed_receiver = _deprecated_stop_outfeed_receiver -else: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del typing +id_tap = call diff --git a/jax/experimental/jax2tf/examples/serving/model_server_request.py b/jax/experimental/jax2tf/examples/serving/model_server_request.py index 0c64b7b55f8e..7f69f080603c 100644 --- a/jax/experimental/jax2tf/examples/serving/model_server_request.py +++ b/jax/experimental/jax2tf/examples/serving/model_server_request.py @@ -92,7 +92,7 @@ def serving_call_mnist(images): # You can see the name of the input ("inputs") in the SavedModel dump. data = f'{{"inputs": {images_json}}}' predict_url = f"http://{_PREDICTION_SERVICE_ADDR.value}/v1/models/{_MODEL_SPEC_NAME.value}:predict" - response = requests.post(predict_url, data=data) + response = requests.post(predict_url, data=data, timeout=60) if response.status_code != 200: msg = (f"Received error response {response.status_code} from model " f"server: {response.text}") diff --git a/jax/experimental/jax2tf/examples/tf_js/quickdraw/input_pipeline.py b/jax/experimental/jax2tf/examples/tf_js/quickdraw/input_pipeline.py index 1aeb18a08151..77d3af16e1e3 100644 --- a/jax/experimental/jax2tf/examples/tf_js/quickdraw/input_pipeline.py +++ b/jax/experimental/jax2tf/examples/tf_js/quickdraw/input_pipeline.py @@ -39,7 +39,7 @@ def download_dataset(dir_path, nb_classes): continue with open(cls_file_path, "wb") as save_file: try: - response = requests.get(url + cls_filename.replace('_', ' ')) + response = requests.get(url + cls_filename.replace('_', ' '), timeout=60) save_file.write(response.content) print(f'Successfully fetched {cls_filename}') except: diff --git a/jax/experimental/jax2tf/impl_no_xla.py b/jax/experimental/jax2tf/impl_no_xla.py index 3c51e5d63f25..310cbaab6d59 100644 --- a/jax/experimental/jax2tf/impl_no_xla.py +++ b/jax/experimental/jax2tf/impl_no_xla.py @@ -364,15 +364,12 @@ def _conv_general_dilated( def _dot_general(lhs, rhs, *, dimension_numbers, precision: tuple[PrecisionType, PrecisionType] | None, preferred_element_type: DType | None, - algorithm: Any, transpose_algorithm: Any, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): """Implementation of lax.dot_general_p in terms of tf.linalg.einsum.""" # Unused arguments. del precision del preferred_element_type - del algorithm - del transpose_algorithm lhs, rhs, convert_result = jax2tf._dot_general_convert_to_common_dtype( lhs, _in_avals[0], rhs, _in_avals[1], _out_aval) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index f01a3ab7a036..84207921535e 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -334,7 +334,9 @@ def convert(fun_jax: Callable, """ if not enable_xla: if allow_enable_xla_false(): - warnings.warn("jax2tf.convert with enable_xla=False is deprecated.") + warnings.warn("jax2tf.convert with enable_xla=False is deprecated.", + DeprecationWarning, + stacklevel=2) else: raise ValueError("jax2tf.convert with enable_xla=False is not supported.") @@ -346,7 +348,9 @@ def convert(fun_jax: Callable, if not native_serialization: warnings.warn( - "jax2tf.convert with native_serialization=False is deprecated.") + "jax2tf.convert with native_serialization=False is deprecated.", + DeprecationWarning, + stacklevel=2) if native_serialization and not enable_xla: raise ValueError( "native_serialization is not supported with enable_xla=False") @@ -1555,6 +1559,9 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs): "bitcast", "repeat", "roll", + # temporary pending cudnn fix, see https://github.com/jax-ml/jax/pull/23740 + "bias_fwd", + "bias_bwd", ] tf_impl[random_internal.random_clone_p] = lambda x: x @@ -1567,7 +1574,7 @@ def _add(x: TfVal, y: TfVal) -> TfVal: tf_impl[ad_util.add_jaxvals_p] = _add -tf_impl[dispatch.device_put_p] = lambda *xs, devices=None, srcs=None: xs +tf_impl[dispatch.device_put_p] = lambda *xs, devices=None, srcs=None, copy_semantics=None: xs tf_impl[lax_internal.copy_p] = lambda x: x def _shard_alike(*args: TfVal, **_): @@ -2176,12 +2183,9 @@ def gen_conv(lhs, rhs, preferred_element_type: DType | None): def _dot_general(lhs, rhs, *, dimension_numbers, precision: tuple[PrecisionType, PrecisionType] | None, preferred_element_type: DType | None, - algorithm: Any, transpose_algorithm: Any, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): """Implementation of lax.dot_general_p in terms of tf.linalg.einsum.""" - del algorithm, transpose_algorithm # unused - # TODO(b/293247337): we ought to turn on this safety check, but this leads to # failures. Since we are going to turn on native serializaton soon, wait # until then to turn on this check. @@ -2873,6 +2877,9 @@ def _gather_dimensions_proto(indices_shape, dimension_numbers): proto.offset_dims.extend(dimension_numbers.offset_dims) proto.collapsed_slice_dims.extend(dimension_numbers.collapsed_slice_dims) proto.start_index_map.extend(dimension_numbers.start_index_map) + proto.operand_batching_dims.extend(dimension_numbers.operand_batching_dims) + proto.start_indices_batching_dims.extend( + dimension_numbers.start_indices_batching_dims) assert indices_shape proto.index_vector_dim = len(indices_shape) - 1 return proto @@ -2984,6 +2991,9 @@ def _scatter_dimensions_proto(indices_shape, dimension_numbers): proto.inserted_window_dims.extend(dimension_numbers.inserted_window_dims) proto.scatter_dims_to_operand_dims.extend( dimension_numbers.scatter_dims_to_operand_dims) + proto.input_batching_dims.extend(dimension_numbers.operand_batching_dims) + proto.scatter_indices_batching_dims.extend( + dimension_numbers.scatter_indices_batching_dims) assert indices_shape proto.index_vector_dim = len(indices_shape) - 1 return proto @@ -3040,6 +3050,7 @@ def update_computation(arg1: TfVal, arg2: TfVal) -> TfVal: tf_impl_with_avals[lax.scatter_max_p] = _scatter tf_impl_with_avals[lax.scatter_mul_p] = _scatter tf_impl_with_avals[lax.scatter_add_p] = _scatter +tf_impl_with_avals[lax.scatter_sub_p] = _scatter def _cond( @@ -3260,7 +3271,7 @@ def lexicographic_comparator(*tf_args: TfVal) -> TfVal: def _fft(x, *, fft_type, fft_lengths, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): - FFT, IFFT, RFFT, IRFFT = list(map(xla_client.FftType, [0, 1, 2, 3])) + FFT, IFFT, RFFT, IRFFT = list(map(lax.FftType, [0, 1, 2, 3])) tf_funcs = { FFT: [tf.signal.fft, tf.signal.fft2d, tf.signal.fft3d], IFFT: [tf.signal.ifft, tf.signal.ifft2d, tf.signal.ifft3d], diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 6411dc581424..0ec5cb0bb7df 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -317,6 +317,8 @@ def test_function(self): @jtu.sample_product(with_function=[False, True]) def test_gradients_disabled(self, with_function=False): + if tf.version.VERSION.split(".") <= ["2", "17", "0"]: + self.skipTest("This test works only with newer versions of TF") f_tf = jax2tf.convert(jnp.tan, with_gradient=False) if with_function: f_tf = tf.function(f_tf, autograph=False) diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index 78c24b7ea411..34844aa77a89 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -158,11 +158,7 @@ def test_primitive_coverage(self): """Fail if there are JAX primitives that are not implemented.""" # Harvest primitives from XLA translation tables all_primitives = ( - set(xla._translations) - | set(xla._backend_specific_translations["cpu"]) - | set(xla._backend_specific_translations["gpu"]) - | set(xla._backend_specific_translations["tpu"]) - | set(mlir._lowerings) + set(mlir._lowerings) | set(mlir._platform_specific_lowerings["cpu"]) | set(mlir._platform_specific_lowerings["gpu"]) | set(mlir._platform_specific_lowerings["tpu"])) diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 07bd9b5aed22..0ab35efb48ef 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -1824,11 +1824,11 @@ def f2(z, w): # z: f32[a, 5] w: f32[a + b, 5] -> f32[2*a + b, 10] lambda x, fft_type, nr_fft_lengths: lax.fft_p.bind( x, fft_type=fft_type, fft_lengths=tuple( - x.shape[-nr_fft_lengths:] if fft_type != xla_client.FftType.IRFFT else + x.shape[-nr_fft_lengths:] if fft_type != lax.FftType.IRFFT else [(x.shape[-1] - 1) * 2])), arg_descriptors=[ RandArg((3, 4, 5, 6), - np.float32 if fft_type == xla_client.FftType.RFFT else np.complex64), + np.float32 if fft_type == lax.FftType.RFFT else np.complex64), StaticArg(fft_type), StaticArg(nr_fft_lengths)], # All axes but the last one are dynamic. This means that the test @@ -1836,8 +1836,8 @@ def f2(z, w): # z: f32[a, 5] w: f32[a + b, 5] -> f32[2*a + b, 10] polymorphic_shapes=["b0, b1, b2, ..."], tol=1e-4) - for fft_type in (xla_client.FftType.FFT, xla_client.FftType.IFFT, - xla_client.FftType.RFFT, xla_client.FftType.IRFFT) + for fft_type in (lax.FftType.FFT, lax.FftType.IFFT, + lax.FftType.RFFT, lax.FftType.IRFFT) for nr_fft_lengths in (1, 2) ], PolyHarness("full", "", diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index f5944c862480..dac0c16fe4f7 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -47,6 +47,7 @@ fori, memref_fold, memref_slice, + memref_reshape, memref_transpose, memref_unfold, memref_unsqueeze, diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index ae6c40b9416d..3c419ed62a15 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -214,10 +214,16 @@ def load_strided(cls, ref: ir.Value, *, is_signed: bool | None = None): raise TypeError(ref.type) ref_ty = ir.MemRefType(ref.type) - ref_1d = mgpu.memref_fold(ref, 0, len(ref_ty.shape)) layout = WGStridedFragLayout.from_memref_type(ref_ty) vec_ty = ir.VectorType.get((layout.vec_size,), ref_ty.element_type) - vecs = [vector.load(vec_ty, ref_1d, [vec_idx]) for vec_idx in layout.linear_thread_vec_idxs()] + try: + # Flattening the reference potentially produces simpler PTX but + # if the ref is not already 1D and has strided dimensions + # flattening won't work. + ref_ = mgpu.memref_fold(ref, 0, len(ref_ty.shape)) + vecs = [vector.load(vec_ty, ref_, [vec_idx]) for vec_idx in layout.linear_thread_vec_idxs()] + except NotImplementedError: + vecs = [vector.load(vec_ty, ref, vec_idx) for vec_idx in layout.thread_vec_idxs()] return cls(_registers=np.array(vecs), _layout=layout, _is_signed=is_signed) @classmethod @@ -835,12 +841,22 @@ def _store_untiled_splat(self, ref: ir.Value): def _store_untiled_wg_strided(self, ref: ir.Value): ref_ty = ir.MemRefType(ref.type) + try: + # Flattening the reference potentially produces simpler PTX but + # if the ref is not already 1D and has strided dimensions + # flattening won't work. We use a different variable for ref in + # case `NotImplementedError` is thrown by + # .linear_thread_vec_idxs(). + ref_ = mgpu.memref_fold(ref, 0, len(ref_ty.shape)) + idxs = ([i] for i in self.layout.linear_thread_vec_idxs()) + except NotImplementedError: + ref_ = ref + idxs = self.layout.thread_vec_idxs() ref_shape = tuple(ref_ty.shape) if ref_shape != self.shape: raise ValueError((ref_shape, self.shape)) - smem_1d = mgpu.memref_fold(ref, 0, len(ref_ty.shape)) - for idx, reg in zip(self.layout.linear_thread_vec_idxs(), self.registers.flat): - vector.store(reg, smem_1d, [idx]) + for idx, reg in zip(idxs, self.registers.flat): + vector.store(reg, ref_, idx) def _store_untiled_wgmma(self, ref: ir.Value): """Stores accumulator to a 2D memref. Not optimized at the moment.""" diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index a59ddbea5565..2709d0075a6d 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -20,7 +20,7 @@ import enum import functools import math -from typing import Any, Literal +from typing import Any, Literal, cast import jax from jax import numpy as jnp @@ -316,13 +316,15 @@ def memref_slice(ref: ir.Value, index) -> ir.Value: # dynamic, but we can at least catch some OOB slices). memref_strides, offset = ref_ty.get_strides_and_offset() + dynamic_offset = ir.ShapedType.get_dynamic_stride_or_offset() new_offset = offset - for idx, stride in zip(base_indices, memref_strides): - if isinstance(idx, int): - new_offset += idx * stride - else: - new_offset = ir.ShapedType.get_dynamic_stride_or_offset() - break + if new_offset != dynamic_offset: + for idx, stride in zip(base_indices, memref_strides): + if isinstance(idx, int): + new_offset += idx * stride + else: + new_offset = dynamic_offset + break new_strides = [ s for s, squeeze in zip(memref_strides, is_squeezed) if not squeeze ] @@ -357,6 +359,80 @@ def _is_contiguous_shape_slice( return True +def _reshape(ref: ir.Value, sh0: list[int], sh1: list[int]): + """Reshapes using only "parallel" folds/unfolds. + + This function uses folds/unfolds that are "parallel" in that they + only act on original dimensions, i.e. they won't fold into an + intermediate dimension that they will then unfold. + """ + + i0, i1 = 0, 0 + def fold_until(shape, off , target) -> tuple[int, int]: + assert shape[off] < target + dim = 1 + for to in range(off, len(shape)): + dim *= shape[to] + if dim == target: + return to + 1, dim + if dim > target: + # TODO(cperivol): Implement dependent fold-unfolds for subsections + # of the shape eg (..., 4,5,5, ...) -> (..., 10,10, ...) could be + # supported without touching any other dimensions. + raise NotImplementedError(f"Can't reshape {sh0} to {sh1} bu composing independent folds/unfolds.") + + raise AssertionError(f"Unreachable: number of elements don't match in each shape ({sh0} ans {sh1})") + + while i0 < len(sh0) and i1 < len(sh1): + if sh0[i0] > sh1[i1]: + # How many dimensions following i1 should we unfold i0 into. + idx, _ = fold_until(sh1, i1, sh0[i0]) + ref = memref_unfold(ref, i0, sh1[i1:idx]) + sh0[i0:i0+1] = sh1[i1:idx] + i0 += idx - i1 + i1 = idx + elif sh0[i0] < sh1[i1]: + # How many dimensions after i0 should we fold to make dim at i1. + idx, dim = fold_until(sh0, i0, sh1[i1]) + sh0[i0:idx] = [dim] + ref = memref_fold(ref, i0, idx - i0) + i0 += 1 + i1 += 1 + else: + i0 += 1 + i1 += 1 + + # Fold the trailing ones + if i0 < len(sh0): + assert i1 == len(sh1) + ref = memref_fold(ref, i0 - 1, len(sh0) - i0 + 1) + + if i1 < len(sh1): + assert i0 == len(sh0) + ref = memref_unfold(ref, i0 - 1, [sh0[i0 - 1]] + [1] * (len(sh1) - i1)) + + return ref + + +def memref_reshape(ref: ir.Value, shape: tuple[int, ...]) -> ir.Value: + """Reshape by means of folding and unfolding. + + The use of memref fold/unfold may avoid some possible issues with + strided memrefs. + """ + + ref_ty = ir.MemRefType(ref.type) + if math.prod(ref_ty.shape) != math.prod(shape): + raise ValueError("Cannot reshape to a different size") + if not all(dim > 0 for dim in shape): + raise ValueError( + "Shapes must havbe only positive dimensions (no -1 or 0 dimensions" + f" allowed) {shape}" + ) + + return _reshape(ref, list(ref_ty.shape), list(shape)) + + def memref_fold(ref: ir.Value, dim, fold_rank) -> ir.Value: ref_ty = ir.MemRefType(ref.type) new_shape = list(ref_ty.shape) @@ -486,7 +562,7 @@ def parse_indices( slice_shape.append(1) is_squeezed.append(True) elif isinstance(idx, slice): - if idx.step is not None: + if idx.step is not None and idx.step != 1: raise NotImplementedError("Strided slices not implemented") base_indices.append(idx.start or 0) slice_shape.append((idx.stop or bound) - (idx.start or 0)) diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index bb733e794c5f..0a82137f8dd6 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -18,17 +18,17 @@ https://jax.readthedocs.io/en/latest/pallas.html. """ -from jax._src.deprecations import register as _register_deprecation from jax._src.pallas.core import Blocked from jax._src.pallas.core import BlockSpec from jax._src.pallas.core import CompilerParams from jax._src.pallas.core import CostEstimate from jax._src.pallas.core import GridSpec from jax._src.pallas.core import IndexingMode +from jax._src.pallas.core import MemorySpace +from jax._src.pallas.core import MemoryRef from jax._src.pallas.core import no_block_spec from jax._src.pallas.core import Unblocked from jax._src.pallas.core import unblocked -from jax._src.pallas.core import MemorySpace from jax._src.pallas.pallas_call import pallas_call from jax._src.pallas.pallas_call import pallas_call_p from jax._src.pallas.primitives import atomic_add @@ -58,8 +58,5 @@ from jax._src.state.indexing import Slice from jax._src.state.primitives import broadcast_to -ANY = MemorySpace.ANY - -_register_deprecation("pallas-block-spec-order") -del _register_deprecation +ANY = MemorySpace.ANY diff --git a/jax/experimental/pallas/gpu.py b/jax/experimental/pallas/gpu.py index 4f38192e3a14..0ee84c8453ec 100644 --- a/jax/experimental/pallas/gpu.py +++ b/jax/experimental/pallas/gpu.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Triton-specific Pallas APIs.""" +from jax._src import deprecations -from jax._src.pallas.triton.core import TritonCompilerParams -from jax._src.pallas.triton.primitives import approx_tanh -from jax._src.pallas.triton.primitives import debug_barrier -from jax._src.pallas.triton.primitives import elementwise_inline_asm +deprecations.warn( + "pallas-gpu-triton", + "The ``jax.experimental.pallas.gpu`` submodule is deprecated. " + " Use ``jax.experimental.pallas.triton`` instead.", + stacklevel=1, +) + +from jax.experimental.pallas.triton import * # noqa: F403 diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py new file mode 100644 index 000000000000..273a7279e717 --- /dev/null +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -0,0 +1,39 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Experimental GPU backend for Pallas targeting H100. + +These APIs are highly unstable and can change weekly. Use at your own risk. +""" + +from jax._src.pallas.mosaic_gpu.core import Barrier +from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec +from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams +from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace +from jax._src.pallas.mosaic_gpu.core import SwizzleTransform +from jax._src.pallas.mosaic_gpu.core import TilingTransform +from jax._src.pallas.mosaic_gpu.core import TransposeTransform +from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef +from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC +from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem +from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem +from jax._src.pallas.mosaic_gpu.primitives import wait_barrier +from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem +from jax._src.pallas.mosaic_gpu.primitives import wgmma +from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait + +#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.GMEM`. +GMEM = GPUMemorySpace.GMEM +#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.SMEM`. +SMEM = GPUMemorySpace.SMEM diff --git a/jax/experimental/pallas/ops/gpu/attention.py b/jax/experimental/pallas/ops/gpu/attention.py index 8e28be840d37..66c9dea39734 100644 --- a/jax/experimental/pallas/ops/gpu/attention.py +++ b/jax/experimental/pallas/ops/gpu/attention.py @@ -21,7 +21,7 @@ import jax from jax import lax from jax.experimental import pallas as pl -from jax.experimental.pallas import gpu as plgpu +from jax.experimental.pallas import triton as plgpu import jax.numpy as jnp import numpy as np diff --git a/jax/experimental/pallas/ops/gpu/decode_attention.py b/jax/experimental/pallas/ops/gpu/decode_attention.py index dde80d4603cc..a7e1b33e1f35 100644 --- a/jax/experimental/pallas/ops/gpu/decode_attention.py +++ b/jax/experimental/pallas/ops/gpu/decode_attention.py @@ -21,7 +21,7 @@ import jax from jax import lax from jax.experimental import pallas as pl -from jax.experimental.pallas import gpu as plgpu +from jax.experimental.pallas import triton as plgpu import jax.numpy as jnp diff --git a/jax/experimental/pallas/ops/gpu/layer_norm.py b/jax/experimental/pallas/ops/gpu/layer_norm.py index e531395079ba..7d11e4faf299 100644 --- a/jax/experimental/pallas/ops/gpu/layer_norm.py +++ b/jax/experimental/pallas/ops/gpu/layer_norm.py @@ -24,7 +24,7 @@ from jax._src.lax.control_flow.for_loop import for_loop from jax.experimental import pallas as pl -from jax.experimental.pallas import gpu as plgpu +from jax.experimental.pallas import triton as plgpu def layer_norm_forward_kernel( x_ref, weight_ref, bias_ref, # Input arrays diff --git a/jax/experimental/pallas/ops/gpu/rms_norm.py b/jax/experimental/pallas/ops/gpu/rms_norm.py index 3e373b895b8d..98b26e2d7c82 100644 --- a/jax/experimental/pallas/ops/gpu/rms_norm.py +++ b/jax/experimental/pallas/ops/gpu/rms_norm.py @@ -26,7 +26,7 @@ from jax._src.lax.control_flow.for_loop import for_loop from jax.experimental import pallas as pl -from jax.experimental.pallas import gpu as plgpu +from jax.experimental.pallas import triton as plgpu def rms_norm_forward_kernel( x_ref, weight_ref, bias_ref, # Input arrays diff --git a/jax/experimental/pallas/ops/gpu/softmax.py b/jax/experimental/pallas/ops/gpu/softmax.py index 33b416d165d7..7fc6a0f50cb4 100644 --- a/jax/experimental/pallas/ops/gpu/softmax.py +++ b/jax/experimental/pallas/ops/gpu/softmax.py @@ -18,7 +18,7 @@ import jax import jax.numpy as jnp from jax.experimental import pallas as pl -from jax.experimental.pallas import gpu as plgpu +from jax.experimental.pallas import triton as plgpu def _vmappable_softmax_kernel( diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index 8a1a223ae36e..40fa9dc45ec2 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -22,6 +22,8 @@ from jax._src.pallas.mosaic.core import SemaphoreType from jax._src.pallas.mosaic.core import TPUMemorySpace from jax._src.pallas.mosaic.core import TPUCompilerParams +from jax._src.pallas.mosaic.core import runtime_assert_enabled +from jax._src.pallas.mosaic.core import _ENABLE_RUNTIME_ASSERT as enable_runtime_assert from jax._src.pallas.mosaic.lowering import LoweringException from jax._src.pallas.mosaic.pipeline import ARBITRARY from jax._src.pallas.mosaic.pipeline import BufferedRef diff --git a/jax/experimental/pallas/triton.py b/jax/experimental/pallas/triton.py new file mode 100644 index 000000000000..bcee04374ad2 --- /dev/null +++ b/jax/experimental/pallas/triton.py @@ -0,0 +1,20 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Triton-specific Pallas APIs.""" + +from jax._src.pallas.triton.core import TritonCompilerParams +from jax._src.pallas.triton.primitives import approx_tanh +from jax._src.pallas.triton.primitives import debug_barrier +from jax._src.pallas.triton.primitives import elementwise_inline_asm diff --git a/jax/experimental/rnn.py b/jax/experimental/rnn.py index f5131365cb50..4aa863708189 100644 --- a/jax/experimental/rnn.py +++ b/jax/experimental/rnn.py @@ -84,7 +84,7 @@ """ from functools import partial import math -from typing import Any +from typing import cast, Any import jax import numpy as np @@ -228,10 +228,10 @@ def _lstm_cudnn_allow_tf32(precision: lax.PrecisionLike) -> bool: # # but we prefer to still invoke it here for consistency precision = lax.canonicalize_precision(precision) - if precision is None: + if precision is None or not (isinstance(precision, tuple) and len(precision) == 2): return True # cuDNN allows only one precision specifier per RNN op - precision, _ = precision + precision, _ = cast(tuple[lax.Precision, lax.Precision], precision) if precision == lax.Precision.HIGHEST: return False elif precision == lax.Precision.HIGH: diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 0dace1977dc0..7a04c88c519b 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -663,8 +663,8 @@ def _shard_map_lowering_shardy( # Nested `ManualComputationOp`s cannot refer to axes that are already # manual. So figure out what axes are free thus far and get the new axis # context. - free_axis = frozenset(mesh.axis_names) - ctx.module_context.axis_context.manual_axes - new_axis_context = sharding_impls.SPMDAxisContext(mesh, free_axis - auto) + free_axes = frozenset(mesh.axis_names) - ctx.module_context.axis_context.manual_axes + new_axis_context = sharding_impls.SPMDAxisContext(mesh, free_axes - auto) else: new_axis_context = sharding_impls.SPMDAxisContext( mesh, frozenset(mesh.axis_names) - auto) @@ -676,9 +676,10 @@ def _shard_map_lowering_shardy( manual_axes_size = np.prod([mesh_shape[a] for a in manual_axes]) if manual_axes_size == 1: # No need for a `ManualComputationOp` if all manual axes are size 1. - out_nodes, _ = mlir.jaxpr_subcomp( - sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *args, - dim_var_values=ctx.dim_var_values) + with core.extend_axis_env_nd(tuple(mesh.shape.items())): + out_nodes, _ = mlir.jaxpr_subcomp( + sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *args, + dim_var_values=ctx.dim_var_values) return out_nodes in_shardings = sdy.TensorShardingPerValueAttr.get(map( @@ -1010,8 +1011,8 @@ def _debug_callback_eager_rule(mesh, *args, callback: Callable[..., Any], return [] eager_rules[debugging.debug_callback_p] = _debug_callback_eager_rule -def _device_put_eager_rule(mesh, *xs, srcs, devices): - del mesh, srcs +def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics): + del mesh, srcs, copy_semantics for device in devices: if device is not None: raise ValueError("device_put with explicit device not allowed within " diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index b20ed8da0326..9f2f0f69be63 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -609,8 +609,7 @@ def _bcoo_transpose_batch_rule(batched_args, batch_dims, *, permutation: Sequenc bcoo_dot_general_p = core.Primitive('bcoo_dot_general') def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers: DotDimensionNumbers, - precision: None = None, preferred_element_type: None = None, - algorithm: None = None, transpose_algorithm: None = None) -> BCOO | Array: + precision: None = None, preferred_element_type: None = None) -> BCOO | Array: """A general contraction operation. Args: @@ -621,8 +620,6 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers: (lhs_batch_dims, rhs_batch_dims))`. precision: unused preferred_element_type: unused - algorithm: unused - transpose_algorithm: unused Returns: An ndarray or BCOO-format sparse array containing the result. If both inputs @@ -630,7 +627,7 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers: the result will be dense, of type ndarray. """ # TODO(jakevdp) make use of these? - del precision, algorithm, transpose_algorithm # unused + del precision # unused if isinstance(lhs, BCOO) and isinstance(rhs, BCOO): shape = _dot_general_validated_shape(lhs.shape, rhs.shape, dimension_numbers) @@ -1056,9 +1053,7 @@ def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers) indices, ct = _bcoo_extract_transpose(ct, indices, mat, assume_unique=True) kwds = {'dimension_numbers': dimension_numbers, 'precision': None, - 'preferred_element_type': None, - 'algorithm': None, - 'transpose_algorithm': None} + 'preferred_element_type': None} A, B = ad.get_primitive_transpose(lax.dot_general_p)(ct, A, B, **kwds) return A, B, indices diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py index 1b877aec9c75..7275d6bb20aa 100644 --- a/jax/experimental/sparse/bcsr.py +++ b/jax/experimental/sparse/bcsr.py @@ -463,9 +463,7 @@ def _bcsr_extract_batching_rule(batched_args, batch_dims): def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *, dimension_numbers: DotDimensionNumbers, precision: None = None, - preferred_element_type: None = None, - algorithm: None = None, - transpose_algorithm: None = None) -> Array: + preferred_element_type: None = None) -> Array: """A general contraction operation. Args: @@ -476,15 +474,13 @@ def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *, (lhs_batch_dims, rhs_batch_dims))`. precision: unused preferred_element_type: unused - algorithm: unused - transpose_algorithm: unused Returns: An ndarray or BCSR-format sparse array containing the result. If both inputs are sparse, the result will be sparse, of type BCSR. If either input is dense, the result will be dense, of type ndarray. """ - del precision, algorithm, transpose_algorithm # unused + del precision # unused if isinstance(rhs, (np.ndarray, jax.Array)): if isinstance(lhs, (np.ndarray, jax.Array)): return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers, diff --git a/jax/experimental/sparse/util.py b/jax/experimental/sparse/util.py index c79dee09cec2..9aa9e42f2a60 100644 --- a/jax/experimental/sparse/util.py +++ b/jax/experimental/sparse/util.py @@ -113,5 +113,4 @@ def _dot_general_validated_shape( rhs = core.ShapedArray(rhs_shape, np.float32) return _dot_general_shape_rule( lhs, rhs, dimension_numbers=dimension_numbers, - precision=None, preferred_element_type=None, algorithm=None, - transpose_algorithm=None) + precision=None, preferred_element_type=None) diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index bb72abb2ec32..7f42cfca5fe8 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -20,8 +20,7 @@ Precision as Precision, PrecisionLike as PrecisionLike, DotAlgorithm as DotAlgorithm, - DotAlgorithmLike as DotAlgorithmLike, - DotTransposeAlgorithmLike as DotTransposeAlgorithmLike, + DotAlgorithmPreset as DotAlgorithmPreset, RandomAlgorithm as RandomAlgorithm, RoundingMethod as RoundingMethod, abs as abs, @@ -280,6 +279,8 @@ scatter_mul as scatter_mul, scatter_mul_p as scatter_mul_p, scatter_p as scatter_p, + scatter_sub as scatter_sub, + scatter_sub_p as scatter_sub_p, slice as slice, slice_in_dim as slice_in_dim, slice_p as slice_p, @@ -339,6 +340,7 @@ from jax._src.lax.fft import ( fft as fft, fft_p as fft_p, + FftType as FftType, ) from jax._src.lax.parallel import ( all_gather as all_gather, diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index a51625eb072e..aaf3791037d0 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -12,30 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +from jax._src.lax.fft import FftType as _FftType from jax._src.lib import xla_client as _xc -dtype_to_etype = _xc.dtype_to_etype -execute_with_python_values = _xc.execute_with_python_values get_topology_for_devices = _xc.get_topology_for_devices heap_profile = _xc.heap_profile mlir_api_version = _xc.mlir_api_version -ops = _xc.ops -register_custom_call_target = _xc.register_custom_call_target -shape_from_pyval = _xc.shape_from_pyval ArrayImpl = _xc.ArrayImpl Client = _xc.Client CompileOptions = _xc.CompileOptions DeviceAssignment = _xc.DeviceAssignment -FftType = _xc.FftType Frame = _xc.Frame HloSharding = _xc.HloSharding OpSharding = _xc.OpSharding -PaddingType = _xc.PaddingType -PrimitiveType = _xc.PrimitiveType -Shape = _xc.Shape Traceback = _xc.Traceback -XlaBuilder = _xc.XlaBuilder -XlaComputation = _xc.XlaComputation _deprecations = { # Added Aug 5 2024 @@ -48,9 +38,9 @@ _xc.bfloat16, ), # Added Sep 26 2024 - "Device" : ( - "jax.lib.xla_client.Device is deprecated; use jax.Device instead.", - _xc.Device + "Device": ( + "jax.lib.xla_client.Device is deprecated; use jax.Device instead.", + _xc.Device, ), "XlaRuntimeError": ( ( @@ -59,6 +49,52 @@ ), _xc.XlaRuntimeError, ), + # Added Oct 10 2024 + "FftType": ( + "jax.lib.xla_client.FftType is deprecated; use jax.lax.FftType.", + _FftType, + ), + "PaddingType": ( + ( + "jax.lib.xla_client.PaddingType is deprecated; this type is unused" + " by JAX so there is no replacement." + ), + _xc.PaddingType, + ), + # Added Oct 11 2024 + "dtype_to_etype": ( + "dtype_to_etype is deprecated; use StableHLO instead.", + _xc.dtype_to_etype, + ), + "ops": ( + "ops is deprecated; use StableHLO instead.", + _xc.ops, + ), + "register_custom_call_target": ( + "register_custom_call_target is deprecated; use the JAX FFI instead " + "(https://jax.readthedocs.io/en/latest/ffi.html)", + _xc.register_custom_call_target, + ), + "shape_from_pyval": ( + "shape_from_pyval is deprecated; use StableHLO instead.", + _xc.shape_from_pyval, + ), + "PrimitiveType": ( + "PrimitiveType is deprecated; use StableHLO instead.", + _xc.PrimitiveType, + ), + "Shape": ( + "Shape is deprecated; use StableHLO instead.", + _xc.Shape, + ), + "XlaBuilder": ( + "XlaBuilder is deprecated; use StableHLO instead.", + _xc.XlaBuilder, + ), + "XlaComputation": ( + "XlaComputation is deprecated; use StableHLO instead.", + _xc.XlaComputation, + ), } import typing as _typing @@ -66,7 +102,17 @@ if _typing.TYPE_CHECKING: _xla = _xc._xla bfloat16 = _xc.bfloat16 + dtype_to_etype = _xc.dtype_to_etype + ops = _xc.ops + register_custom_call_target = _xc.register_custom_call_target + shape_from_pyval = _xc.shape_from_pyval Device = _xc.Device + FftType = _FftType + PaddingType = _xc.PaddingType + PrimitiveType = _xc.PrimitiveType + Shape = _xc.Shape + XlaBuilder = _xc.XlaBuilder + XlaComputation = _xc.XlaComputation XlaRuntimeError = _xc.XlaRuntimeError else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr @@ -74,4 +120,5 @@ __getattr__ = _deprecation_getattr(__name__, _deprecations) del _deprecation_getattr del _typing +del _FftType del _xc diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 20c37c55902c..bd806872990f 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -443,6 +443,7 @@ sin as sin, sinc as sinc, sinh as sinh, + spacing as spacing, sqrt as sqrt, square as square, subtract as subtract, diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index c23f659bd3f9..0ea9b5ee7fd3 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -218,7 +218,7 @@ def cbrt(x: ArrayLike, /) -> Array: ... cdouble: Any def ceil(x: ArrayLike, /) -> Array: ... character = _np.character -def choose(a: ArrayLike, choices: Sequence[ArrayLike], +def choose(a: ArrayLike, choices: Array | _np.ndarray | Sequence[ArrayLike], out: None = ..., mode: str = ...) -> Array: ... def clip( x: ArrayLike | None = ..., @@ -808,6 +808,7 @@ def sort( order: None = ..., ) -> Array: ... def sort_complex(a: ArrayLike) -> Array: ... +def spacing(x: ArrayLike, /) -> Array: ... def split( ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike, diff --git a/jax/sharding.py b/jax/sharding.py index 26c542292e87..9a2d8db218ab 100644 --- a/jax/sharding.py +++ b/jax/sharding.py @@ -31,21 +31,16 @@ from jax._src.mesh import AbstractMesh _deprecations = { - # Added Jun 4, 2024. + # Finalized 2024-10-01; remove after 2025-01-01. "XLACompatibleSharding": ( ( - "jax.sharding.XLACompatibleSharding is deprecated. Use" - " jax.sharding.Sharding instead." + "jax.sharding.XLACompatibleSharding was removed in JAX v0.4.34. " + "Use jax.sharding.Sharding instead." ), - _deprecated_XLACompatibleSharding, + None, ) } -import typing -if typing.TYPE_CHECKING: - XLACompatibleSharding = _deprecated_XLACompatibleSharding -else: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del typing +from jax._src.deprecations import deprecation_getattr as _deprecation_getattr +__getattr__ = _deprecation_getattr(__name__, _deprecations) +del _deprecation_getattr diff --git a/jax/tools/jax_to_ir.py b/jax/tools/jax_to_ir.py index 3076f36935cb..904ce509a87e 100644 --- a/jax/tools/jax_to_ir.py +++ b/jax/tools/jax_to_ir.py @@ -160,7 +160,7 @@ def ordered_wrapper(*args): raise ValueError( 'Conversion to TF graph requires TensorFlow to be installed.') - f = jax2tf.convert(ordered_wrapper, native_serialization=False) + f = jax2tf.convert(ordered_wrapper) f = tf_wrap_with_input_names(f, input_shapes) f = tf.function(f, autograph=False) g = f.get_concrete_function(*args).graph.as_graph_def() diff --git a/jax/version.py b/jax/version.py index 6c64d75b9733..d888836e629d 100644 --- a/jax/version.py +++ b/jax/version.py @@ -21,7 +21,7 @@ import pathlib import subprocess -_version = "0.4.34" +_version = "0.4.35" # The following line is overwritten by build scripts in distributions & # releases. Do not modify this manually, or jax/jaxlib build will fail. _release_version: str | None = None @@ -133,7 +133,7 @@ def make_release_tree(self, base_dir, files): __version__ = _get_version_string() -_minimum_jaxlib_version = "0.4.33" +_minimum_jaxlib_version = "0.4.34" def _version_as_tuple(version_str): return tuple(int(i) for i in version_str.split(".") if i.isdigit()) diff --git a/jax_plugins/BUILD.bazel b/jax_plugins/BUILD.bazel index 6e2cf6aadbaf..08a590b0997a 100644 --- a/jax_plugins/BUILD.bazel +++ b/jax_plugins/BUILD.bazel @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +package( + default_visibility = ["//jax:internal"], +) + licenses(["notice"]) load( diff --git a/jaxlib/BUILD b/jaxlib/BUILD index ab60b3fadd37..8c402cfcefe8 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -239,6 +239,7 @@ pybind_extension( deps = [ "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/status", + "@local_config_rocm//rocm:hip", "@local_config_rocm//rocm:rocm_headers", "@nanobind", "@xla//third_party/python_runtime:headers", diff --git a/jaxlib/ducc_fft.py b/jaxlib/ducc_fft.py deleted file mode 100644 index be9e3aff652f..000000000000 --- a/jaxlib/ducc_fft.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright 2020 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import jaxlib.mlir.ir as ir -import jaxlib.mlir.dialects.stablehlo as hlo - - -from .hlo_helpers import custom_call -from .cpu import _ducc_fft -import numpy as np - -from jaxlib import xla_client - -for _name, _value in _ducc_fft.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="cpu") - -FftType = xla_client.FftType - - -_C2C = 0 -_C2R = 1 -_R2C = 2 - - -def _dynamic_ducc_fft_descriptor( - dtype, ndims: int, fft_type: FftType, fft_lengths: list[int] -) -> bytes: - assert len(fft_lengths) >= 1 - assert len(fft_lengths) <= ndims, (fft_lengths, ndims) - - forward = fft_type in (FftType.FFT, FftType.RFFT) - is_double = np.finfo(dtype).dtype == np.float64 - if fft_type == FftType.RFFT: - ducc_fft_type = _R2C - elif fft_type == FftType.IRFFT: - ducc_fft_type = _C2R - else: - ducc_fft_type = _C2C - - # Builds a PocketFftDescriptor flatbuffer. This descriptor is passed to the - # C++ kernel to describe the FFT to perform. - axes = [ndims - len(fft_lengths) + d for d in range(len(fft_lengths))] - - descriptor = _ducc_fft.dynamic_ducc_fft_descriptor( - ndims=ndims, - is_double=is_double, - fft_type=ducc_fft_type, - axes=axes, - forward=forward) - - return descriptor - - -def dynamic_ducc_fft_hlo( - result_type: ir.Type, - input: ir.Value, *, - input_dtype: np.dtype, ndims:int, input_shape: ir.Value, - strides_in: ir.Value, strides_out: ir.Value, scale: ir.Value, - fft_type: FftType, fft_lengths: list[int], result_shape: ir.Value): - """DUCC FFT kernel for CPU, with support for dynamic shapes.""" - a_type = ir.RankedTensorType(input.type) - - fft_lengths = list(fft_lengths) - descriptor_bytes = _dynamic_ducc_fft_descriptor( - input_dtype, ndims, fft_type, fft_lengths) - - # PocketFft does not allow size 0 dimensions, but we handled this in fft.py - assert 0 not in a_type.shape - - u8_type = ir.IntegerType.get_unsigned(8) - descriptor = hlo.constant( - ir.DenseElementsAttr.get( - np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type)) - layout = tuple(range(ndims - 1, -1, -1)) - return custom_call( - "dynamic_ducc_fft", - result_types=[result_type], - operands=[descriptor, input, input_shape, strides_in, strides_out, scale], - operand_layouts=[[0], layout, [0], [0], [0], [0]], - result_layouts=[layout], - result_shapes=[result_shape]).results diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index 048ea23a9cff..7d50a91cfcda 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -82,7 +82,6 @@ cc_proto_library( xla_py_proto_library( name = "triton_py_pb2", - api_version = 2, visibility = jax_visibility("triton_proto_py_users"), deps = [":triton_proto"], ) diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index 32cd97565f5e..7852da4bc04f 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -21,6 +21,14 @@ limitations under the License. #include #include +#if JAX_GPU_HAVE_64_BIT +#include +#endif + +#ifdef JAX_GPU_CUDA +#include +#endif + #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" @@ -33,14 +41,6 @@ limitations under the License. #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" -#if JAX_GPU_64_BIT -#include -#endif - -#ifdef JAX_GPU_CUDA -#include -#endif - #define JAX_FFI_RETURN_IF_GPU_ERROR(...) \ FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(__VA_ARGS__)) @@ -64,6 +64,28 @@ inline absl::StatusOr AllocateWorkspace(ffi::ScratchAllocator& scratch, return static_cast(maybe_workspace.value()); } +#if JAX_GPU_HAVE_64_BIT + +// Map an FFI buffer element type to the appropriate GPU solver type. +inline absl::StatusOr SolverDataType(ffi::DataType dataType, + std::string_view func) { + switch (dataType) { + case ffi::F32: + return GPU_R_32F; + case ffi::F64: + return GPU_R_64F; + case ffi::C64: + return GPU_C_32F; + case ffi::C128: + return GPU_C_64F; + default: + return absl::InvalidArgumentError(absl::StrFormat( + "Unsupported dtype %s in %s", absl::FormatStreamed(dataType), func)); + } +} + +#endif + #define SOLVER_DISPATCH_IMPL(impl, ...) \ switch (dataType) { \ case ffi::F32: \ @@ -392,11 +414,74 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch, // dispatches dynamically to both syevd and syevj depending on the problem // size and the algorithm selected by the user via the `algorithm` attribute. +#if JAX_GPU_HAVE_64_BIT + +ffi::Error Syevd64Impl(int64_t batch, int64_t n, gpuStream_t stream, + ffi::ScratchAllocator& scratch, bool lower, + ffi::AnyBuffer a, ffi::Result out, + ffi::Result w, + ffi::Result> info) { + FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); + auto dataType = a.element_type(); + FFI_ASSIGN_OR_RETURN(auto aType, SolverDataType(dataType, "syevd")); + FFI_ASSIGN_OR_RETURN(auto wType, SolverDataType(w->element_type(), "syevd")); + + gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR; + gpusolverFillMode_t uplo = + lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER; + + gpusolverDnParams_t params; + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateParams(¶ms)); + std::unique_ptr + params_cleanup( + params, [](gpusolverDnParams_t p) { gpusolverDnDestroyParams(p); }); + + size_t workspaceInBytesOnDevice, workspaceInBytesOnHost; + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevd_bufferSize( + handle.get(), params, jobz, uplo, n, aType, /*a=*/nullptr, n, wType, + /*w=*/nullptr, aType, &workspaceInBytesOnDevice, + &workspaceInBytesOnHost)); + + auto maybe_workspace = scratch.Allocate(workspaceInBytesOnDevice); + if (!maybe_workspace.has_value()) { + return ffi::Error(ffi::ErrorCode::kResourceExhausted, + "Unable to allocate device workspace for syevd"); + } + auto workspaceOnDevice = maybe_workspace.value(); + auto workspaceOnHost = + std::unique_ptr(new char[workspaceInBytesOnHost]); + + const char* a_data = static_cast(a.untyped_data()); + char* out_data = static_cast(out->untyped_data()); + char* w_data = static_cast(w->untyped_data()); + int* info_data = info->typed_data(); + if (a_data != out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); + } + + size_t out_step = n * n * ffi::ByteWidth(dataType); + size_t w_step = n * ffi::ByteWidth(ffi::ToReal(dataType)); + + for (auto i = 0; i < batch; ++i) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevd( + handle.get(), params, jobz, uplo, n, aType, out_data, n, wType, w_data, + aType, workspaceOnDevice, workspaceInBytesOnDevice, + workspaceOnHost.get(), workspaceInBytesOnHost, info_data)); + out_data += out_step; + w_data += w_step; + ++info_data; + } + + return ffi::Error::Success(); +} + +#endif + template ffi::Error SyevdImpl(int64_t batch, int64_t size, gpuStream_t stream, - ffi::ScratchAllocator& scratch, SyevdAlgorithm algorithm, - bool lower, ffi::AnyBuffer a, - ffi::Result out, + ffi::ScratchAllocator& scratch, bool lower, + ffi::AnyBuffer a, ffi::Result out, ffi::Result w, ffi::Result> info) { FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(size)); @@ -408,59 +493,84 @@ ffi::Error SyevdImpl(int64_t batch, int64_t size, gpuStream_t stream, auto a_data = static_cast(a.untyped_data()); auto out_data = static_cast(out->untyped_data()); - auto w_data = static_cast::value*>(w->untyped_data()); + auto w_data = + static_cast::value*>(w->untyped_data()); auto info_data = info->typed_data(); if (a_data != out_data) { JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); } - if (algorithm == SyevdAlgorithm::kJacobi || - (algorithm == SyevdAlgorithm::kDefault && size <= 32)) { - gpuSyevjInfo_t params; - JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateSyevjInfo(¶ms)); - std::unique_ptr params_cleanup( - params, [](gpuSyevjInfo_t p) { gpusolverDnDestroySyevjInfo(p); }); - - if (batch == 1) { - FFI_ASSIGN_OR_RETURN(int lwork, solver::SyevjBufferSize( - handle.get(), jobz, uplo, n, params)); - FFI_ASSIGN_OR_RETURN(auto workspace, - AllocateWorkspace(scratch, lwork, "syevj")); - FFI_RETURN_IF_ERROR_STATUS(solver::Syevj(handle.get(), jobz, uplo, n, - out_data, w_data, workspace, - lwork, info_data, params)); - } else { - FFI_ASSIGN_OR_RETURN( - int lwork, solver::SyevjBatchedBufferSize(handle.get(), jobz, uplo, - n, params, batch)); - FFI_ASSIGN_OR_RETURN( - auto workspace, - AllocateWorkspace(scratch, lwork, "syevj_batched")); - FFI_RETURN_IF_ERROR_STATUS( - solver::SyevjBatched(handle.get(), jobz, uplo, n, out_data, w_data, - workspace, lwork, info_data, params, batch)); - } + + FFI_ASSIGN_OR_RETURN(int lwork, + solver::SyevdBufferSize(handle.get(), jobz, uplo, n)); + FFI_ASSIGN_OR_RETURN(auto workspace, + AllocateWorkspace(scratch, lwork, "syevd")); + int out_step = n * n; + for (auto i = 0; i < batch; ++i) { + FFI_RETURN_IF_ERROR_STATUS(solver::Syevd(handle.get(), jobz, uplo, n, + out_data, w_data, workspace, + lwork, info_data)); + out_data += out_step; + w_data += n; + ++info_data; + } + + return ffi::Error::Success(); +} + +template +ffi::Error SyevdjImpl(int64_t batch, int64_t size, gpuStream_t stream, + ffi::ScratchAllocator& scratch, bool lower, + ffi::AnyBuffer a, ffi::Result out, + ffi::Result w, + ffi::Result> info) { + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(size)); + FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); + + gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR; + gpusolverFillMode_t uplo = + lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER; + + auto a_data = static_cast(a.untyped_data()); + auto out_data = static_cast(out->untyped_data()); + auto w_data = + static_cast::value*>(w->untyped_data()); + auto info_data = info->typed_data(); + if (a_data != out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); + } + + gpuSyevjInfo_t params; + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateSyevjInfo(¶ms)); + std::unique_ptr params_cleanup( + params, [](gpuSyevjInfo_t p) { gpusolverDnDestroySyevjInfo(p); }); + + if (batch == 1) { + FFI_ASSIGN_OR_RETURN(int lwork, solver::SyevjBufferSize( + handle.get(), jobz, uplo, n, params)); + FFI_ASSIGN_OR_RETURN(auto workspace, + AllocateWorkspace(scratch, lwork, "syevj")); + FFI_RETURN_IF_ERROR_STATUS(solver::Syevj(handle.get(), jobz, uplo, n, + out_data, w_data, workspace, + lwork, info_data, params)); } else { FFI_ASSIGN_OR_RETURN( - int lwork, solver::SyevdBufferSize(handle.get(), jobz, uplo, n)); + int lwork, solver::SyevjBatchedBufferSize(handle.get(), jobz, uplo, + n, params, batch)); FFI_ASSIGN_OR_RETURN(auto workspace, - AllocateWorkspace(scratch, lwork, "syevd")); - int out_step = n * n; - for (auto i = 0; i < batch; ++i) { - FFI_RETURN_IF_ERROR_STATUS(solver::Syevd(handle.get(), jobz, uplo, n, - out_data, w_data, workspace, - lwork, info_data)); - out_data += out_step; - w_data += n; - ++info_data; - } + AllocateWorkspace(scratch, lwork, "syevj_batched")); + FFI_RETURN_IF_ERROR_STATUS( + solver::SyevjBatched(handle.get(), jobz, uplo, n, out_data, w_data, + workspace, lwork, info_data, params, batch)); } + return ffi::Error::Success(); } ffi::Error SyevdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, - SyevdAlgorithm algorithm, bool lower, - ffi::AnyBuffer a, ffi::Result out, + SyevdAlgorithm algorithm, bool lower, ffi::AnyBuffer a, + ffi::Result out, ffi::Result w, ffi::Result> info) { auto dataType = a.element_type(); @@ -479,8 +589,18 @@ ffi::Error SyevdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, CheckShape(out->dimensions(), {batch, rows, cols}, "out", "syevd")); FFI_RETURN_IF_ERROR(CheckShape(w->dimensions(), {batch, cols}, "w", "syevd")); FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "syevd")); - SOLVER_DISPATCH_IMPL(SyevdImpl, batch, cols, stream, scratch, algorithm, - lower, a, out, w, info); + if (algorithm == SyevdAlgorithm::kJacobi || + (algorithm == SyevdAlgorithm::kDefault && cols <= 32)) { + SOLVER_DISPATCH_IMPL(SyevdjImpl, batch, cols, stream, scratch, lower, a, + out, w, info); + } else { +#if JAX_GPU_HAVE_64_BIT + return Syevd64Impl(batch, cols, stream, scratch, lower, a, out, w, info); +#else + SOLVER_DISPATCH_IMPL(SyevdImpl, batch, cols, stream, scratch, lower, a, out, + w, info); +#endif + } return ffi::Error::InvalidArgument(absl::StrFormat( "Unsupported dtype %s in syevd", absl::FormatStreamed(dataType))); } @@ -577,7 +697,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(SyrkFfi, SyrkDispatch, // Singular Value Decomposition: gesvd -#if JAX_GPU_64_BIT +#if JAX_GPU_HAVE_64_BIT ffi::Error Gesvd64Impl(int64_t batch, int64_t m, int64_t n, gpuStream_t stream, ffi::ScratchAllocator& scratch, bool full_matrices, @@ -589,30 +709,9 @@ ffi::Error Gesvd64Impl(int64_t batch, int64_t m, int64_t n, gpuStream_t stream, ffi::Result> info) { FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); signed char job = compute_uv ? (full_matrices ? 'A' : 'S') : 'N'; - auto dataType = a.element_type(); - gpuDataType aType, sType; - switch (dataType) { - case ffi::F32: - aType = GPU_R_32F; - sType = GPU_R_32F; - break; - case ffi::F64: - aType = GPU_R_64F; - sType = GPU_R_64F; - break; - case ffi::C64: - aType = GPU_C_32F; - sType = GPU_R_32F; - break; - case ffi::C128: - aType = GPU_C_64F; - sType = GPU_R_64F; - break; - default: - return ffi::Error::InvalidArgument(absl::StrFormat( - "Unsupported dtype %s in gesvd", absl::FormatStreamed(dataType))); - } + FFI_ASSIGN_OR_RETURN(auto aType, SolverDataType(dataType, "syevd")); + FFI_ASSIGN_OR_RETURN(auto sType, SolverDataType(s->element_type(), "syevd")); gpusolverDnParams_t params; JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateParams(¶ms)); @@ -692,7 +791,8 @@ ffi::Error GesvdImpl(int64_t batch, int64_t rows, int64_t cols, AllocateWorkspace(scratch, lwork, "gesvd")); auto a_data = static_cast(a.untyped_data()); auto out_data = static_cast(out->untyped_data()); - auto s_data = static_cast::value*>(s->untyped_data()); + auto s_data = + static_cast::value*>(s->untyped_data()); auto u_data = compute_uv ? static_cast(u->untyped_data()) : nullptr; auto vt_data = compute_uv ? static_cast(vt->untyped_data()) : nullptr; auto info_data = info->typed_data(); @@ -717,7 +817,7 @@ ffi::Error GesvdImpl(int64_t batch, int64_t rows, int64_t cols, return ffi::Error::Success(); } -#endif // JAX_GPU_64_BIT +#endif // JAX_GPU_HAVE_64_BIT ffi::Error GesvdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, bool full_matrices, bool compute_uv, bool transposed, @@ -763,7 +863,7 @@ ffi::Error GesvdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, } FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "gesvd")); -#if JAX_GPU_64_BIT +#if JAX_GPU_HAVE_64_BIT return Gesvd64Impl(batch, m, n, stream, scratch, full_matrices, compute_uv, a, out, s, u, vt, info); #else diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index fa247b08b207..648580f08a92 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -332,7 +332,7 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpuGetDeviceProperties cudaGetDeviceProperties #define gpuLaunchCooperativeKernel cudaLaunchCooperativeKernel -#define JAX_GPU_64_BIT 1 +#define JAX_GPU_HAVE_64_BIT 1 #define GPU_R_32F CUDA_R_32F #define GPU_R_64F CUDA_R_64F @@ -345,6 +345,8 @@ typedef cusolverDnParams_t gpusolverDnParams_t; #define gpusolverDnCreateParams cusolverDnCreateParams #define gpusolverDnDestroyParams cusolverDnDestroyParams +#define gpusolverDnXsyevd_bufferSize cusolverDnXsyevd_bufferSize +#define gpusolverDnXsyevd cusolverDnXsyevd #define gpusolverDnXgesvd_bufferSize cusolverDnXgesvd_bufferSize #define gpusolverDnXgesvd cusolverDnXgesvd @@ -368,7 +370,7 @@ constexpr uint32_t kNumThreadsPerWarp = 32; #define JAX_GPU_PREFIX "hip" #define JAX_GPU_HAVE_SPARSE 1 -#define JAX_GPU_64_BIT 0 +#define JAX_GPU_HAVE_64_BIT 0 #define JAX_GPU_HAVE_FP8 0 typedef hipFloatComplex gpuComplex; diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index ff1e5570bb04..457d9f59d210 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Sequence from functools import partial import importlib import math @@ -22,13 +21,9 @@ import numpy as np -from .gpu_common_utils import GpuLibNotLinkedError - from jaxlib import xla_client -from .hlo_helpers import ( - DimensionSize, ShapeTypePair, mk_result_types_and_shapes, - custom_call, ensure_hlo_s32, hlo_s32, dense_int_array) +from .hlo_helpers import custom_call, dense_int_array try: from .cuda import _blas as _cublas # pytype: disable=import-error @@ -99,138 +94,6 @@ def _real_type(dtype): return np.finfo(dtype).dtype -# TODO(b/357034884): Remove this function after the forward compat window. -def _getrf_hlo(platform, gpu_blas, gpu_solver, dtype, a): - """LU decomposition.""" - a_type = ir.RankedTensorType(a.type) - dims = a_type.shape - assert len(dims) >= 2 - m, n = dims[-2:] - batch_dims = tuple(dims[:-2]) - num_bd = len(batch_dims) - i32_type = ir.IntegerType.get_signless(32) - layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - - if not gpu_blas: - raise GpuLibNotLinkedError() - - batch = math.prod(batch_dims) - if batch > 1 and m == n and m // batch <= 128: - lwork, opaque = gpu_blas.build_getrf_batched_descriptor( - np.dtype(dtype), batch, m) - workspace = ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8)) - kernel = f"{platform}blas_getrf_batched" - else: - lwork, opaque = gpu_solver.build_getrf_descriptor( - np.dtype(dtype), batch, m, n) - workspace = ir.RankedTensorType.get([lwork], a_type.element_type) - kernel = f"{platform}solver_getrf" - - out = custom_call( - kernel, - result_types=[ - a.type, - ir.RankedTensorType.get(batch_dims + (min(m, n),), i32_type), - ir.RankedTensorType.get(batch_dims, i32_type), - workspace, - ], - operands=[a], - backend_config=opaque, - operand_layouts=[layout], - result_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - tuple(range(num_bd - 1, -1, -1)), - [0], - ], - operand_output_aliases={0: 0}).results - return out[:3] - - -cuda_getrf = partial(_getrf_hlo, "cu", _cublas, _cusolver) -rocm_getrf = partial(_getrf_hlo, "hip", _hipblas, _hipsolver) - - -def _geqrf_hlo(platform, gpu_solver, dtype, a): - """QR decomposition.""" - a_type = ir.RankedTensorType(a.type) - dims = a_type.shape - assert len(dims) >= 2 - m, n = dims[-2:] - batch_dims = tuple(dims[:-2]) - num_bd = len(batch_dims) - batch = math.prod(batch_dims) - - lwork, opaque = gpu_solver.build_geqrf_descriptor( - np.dtype(dtype), batch, m, n) - - layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - i32_type = ir.IntegerType.get_signless(32) - out = custom_call( - f"{platform}solver_geqrf", - result_types=[ - a.type, - ir.RankedTensorType.get(batch_dims + (min(m, n),), a_type.element_type), - ir.RankedTensorType.get(batch_dims, i32_type), - ir.RankedTensorType.get([lwork], a_type.element_type), - ], - operands=[a], - backend_config=opaque, - operand_layouts=[layout], - result_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - tuple(range(num_bd - 1, -1, -1)), - [0], - ], - operand_output_aliases={0: 0}).results - return out[:3] - -cuda_geqrf = partial(_geqrf_hlo, "cu", _cusolver) -rocm_geqrf = partial(_geqrf_hlo, "hip", _hipsolver) - -def _geqrf_batched_hlo(platform, gpu_blas, dtype, a): - """Batched QR decomposition.""" - a_type = ir.RankedTensorType(a.type) - dims = a_type.shape - assert len(dims) >= 2 - m, n = dims[-2:] - batch_dims = tuple(dims[:-2]) - num_bd = len(batch_dims) - batch = math.prod(batch_dims) - - if not gpu_blas: - raise GpuLibNotLinkedError() - - lwork, opaque = gpu_blas.build_geqrf_batched_descriptor( - np.dtype(dtype), batch, m, n) - - layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - out = custom_call( - f"{platform}blas_geqrf_batched", - result_types=[ - a.type, - ir.RankedTensorType.get(batch_dims + (min(m, n),), a_type.element_type), - ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8)), - ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8)), - ], - operands=[a], - backend_config=opaque, - operand_layouts=[layout], - result_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - [0], - [0], - ], - operand_output_aliases={0: 0} - ).results - return out[:2] - -cuda_geqrf_batched = partial(_geqrf_batched_hlo, "cu", _cublas) -rocm_geqrf_batched = partial(_geqrf_batched_hlo, "hip", _hipblas) - - def _csrlsvqr_hlo(platform, gpu_solver, dtype, data, indices, indptr, b, tol, reorder): """Sparse solver via QR decomposition. CUDA only.""" @@ -256,124 +119,6 @@ def _csrlsvqr_hlo(platform, gpu_solver, dtype, data, cuda_csrlsvqr = partial(_csrlsvqr_hlo, "cu", _cusolver) -def _orgqr_hlo(platform, gpu_solver, dtype, a, tau): - """Product of elementary Householder reflections.""" - a_type = ir.RankedTensorType(a.type) - dims = a_type.shape - assert len(dims) >= 2 - m, n = dims[-2:] - batch_dims = tuple(dims[:-2]) - num_bd = len(batch_dims) - batch = math.prod(batch_dims) - - tau_dims = ir.RankedTensorType(tau.type).shape - assert tau_dims[:-1] == dims[:-2] - k = tau_dims[-1] - - lwork, opaque = gpu_solver.build_orgqr_descriptor( - np.dtype(dtype), batch, m, n, k) - - layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - i32_type = ir.IntegerType.get_signless(32) - out = custom_call( - f"{platform}solver_orgqr", - result_types=[ - a.type, - ir.RankedTensorType.get(batch_dims, i32_type), - ir.RankedTensorType.get([lwork], a_type.element_type), - ], - operands=[a, tau], - backend_config=opaque, - operand_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - ], - result_layouts=[ - layout, - tuple(range(num_bd - 1, -1, -1)), - [0], - ], - operand_output_aliases={0: 0}).results - return out[:2] - -cuda_orgqr = partial(_orgqr_hlo, "cu", _cusolver) -rocm_orgqr = partial(_orgqr_hlo, "hip", _hipsolver) - - -def _syevd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a, *, - a_shape_vals: tuple[DimensionSize, ...], lower=False): - """Symmetric (Hermitian) eigendecomposition.""" - a_type = ir.RankedTensorType(a.type) - assert len(a_shape_vals) >= 2 - m, n = a_shape_vals[-2:] - assert type(m) is int and type(n) is int and m == n, a_shape_vals - batch_dims_vals = a_shape_vals[:-2] - - num_bd = len(batch_dims_vals) - layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - - dynamic_batch_dims = any(type(d) != int for d in batch_dims_vals) - if dynamic_batch_dims: - batch_int = -1 # Signals to the kernel that the batch is an operand. - else: - batch_int = math.prod(batch_dims_vals) - - if have_jacobi_solver and n <= 32 and not dynamic_batch_dims: - # We cannot use syevj for dynamic shapes because the workspace size - # depends on the batch size. - kernel = f"{platform}solver_syevj" - lwork, opaque = gpu_solver.build_syevj_descriptor( - np.dtype(dtype), lower, batch_int, n) - else: - kernel = f"{platform}solver_syevd" - lwork, opaque = gpu_solver.build_syevd_descriptor( - np.dtype(dtype), lower, batch_int, n) - # TODO(Ruturaj4): Currently, hipsolverSsyevd sets lwork to 0 if n==0. - # Remove if this behavior changes in then new ROCm release. - if n > 0 or platform != "hip": - assert lwork > 0 - - if ir.ComplexType.isinstance(a_type.element_type): - eigvals_type = ir.ComplexType(a_type.element_type).element_type - else: - eigvals_type = a_type.element_type - - i32_type = ir.IntegerType.get_signless(32) - operands = [a] - operand_layouts = [layout] - if dynamic_batch_dims: - batch_size_val = hlo_s32(1) - for b_v in batch_dims_vals: - batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) - operands.append(batch_size_val) - operand_layouts.append(()) - - shape_type_pairs: Sequence[ShapeTypePair] = [ - (a_shape_vals, a_type.element_type), - (batch_dims_vals + (n,), eigvals_type), - (batch_dims_vals, i32_type), - ([lwork], a_type.element_type)] - result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) - out = custom_call( - kernel, - result_types=result_types, - operands=operands, - backend_config=opaque, - operand_layouts=operand_layouts, - result_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - tuple(range(num_bd - 1, -1, -1)), - [0], - ], - operand_output_aliases={0: 0}, - result_shapes=result_shapes).results - return out[:3] - -cuda_syevd = partial(_syevd_hlo, "cu", _cusolver, True) -rocm_syevd = partial(_syevd_hlo, "hip", _hipsolver, True) - - def _gesvd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a, full_matrices=True, compute_uv=True): """Singular value decomposition.""" diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 2e37e694b506..3c812d62cfae 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -277,6 +277,7 @@ def jax_multiplatform_test( "//jax:test_util", ] + deps + if_building_jaxlib([ "//jaxlib/cuda:gpu_only_test_deps", + "//jaxlib/rocm:gpu_only_test_deps", "//jax_plugins:gpu_plugin_only_test_deps", ]), data = data, diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py index e23cc0075139..9b3acf641db2 100644 --- a/jaxlib/lapack.py +++ b/jaxlib/lapack.py @@ -30,7 +30,7 @@ from .cpu._lapack import eig from .hlo_helpers import ( custom_call, hlo_u8, hlo_s32, - ensure_hlo_s32, hlo_add, hlo_min, + ensure_hlo_s32, hlo_add, DimensionSize, ShapeTypePair, mk_result_types_and_shapes, ) @@ -162,209 +162,6 @@ def trsm_hlo(dtype, alpha, a, b, ).results -# # ?getrf: LU decomposition - -def getrf_hlo(dtype, a: ir.Value, *, a_shape_vals: tuple[DimensionSize, ...]): - a_type = ir.RankedTensorType(a.type) - assert len(a_shape_vals) >= 2 - batch_dims_vals = a_shape_vals[:-2] - num_bd = len(a_shape_vals) - 2 - m, n = a_shape_vals[-2:] - fn = prepare_lapack_call(fn_base="getrf", dtype=dtype) - - layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - - i32_type = ir.IntegerType.get_signless(32) - shape_type_pairs: Sequence[ShapeTypePair] = [ - (a_shape_vals, a_type.element_type), - (batch_dims_vals + (hlo_min(m, n),), i32_type), - (batch_dims_vals, i32_type) - ] - result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) - - scalar_layout = [] - batch_size_val = hlo_s32(1) - for b_v in batch_dims_vals: - batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) - - return custom_call( - fn, - result_types=result_types, - operands=[batch_size_val, ensure_hlo_s32(m), ensure_hlo_s32(n), a], - operand_layouts=[scalar_layout] * 3 + [layout], - result_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - tuple(range(num_bd - 1, -1, -1)), - ], - operand_output_aliases={3: 0}, - result_shapes=result_shapes, - ).results - -# # ?geqrf: QR decomposition - - -def geqrf_hlo( - ctx, dtype, a: ir.Value, *, a_shape_vals: tuple[DimensionSize, ...] -): - a_type = ir.RankedTensorType(a.type) - assert len(a_shape_vals) >= 2 - m, n = a_shape_vals[-2:] - assert type(m) is int - assert type(n) is int - - batch_dims_vals = a_shape_vals[:-2] - num_bd = len(batch_dims_vals) - fn_base = prepare_lapack_call(fn_base="geqrf", dtype=dtype) - - layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - i32_type = ir.IntegerType.get_signless(32) - - if ctx.is_forward_compat(): - fn = fn_base - if dtype == np.float32: - lwork = _lapack.lapack_sgeqrf_workspace(m, n) - elif dtype == np.float64: - lwork = _lapack.lapack_dgeqrf_workspace(m, n) - elif dtype == np.complex64: - lwork = _lapack.lapack_cgeqrf_workspace(m, n) - elif dtype == np.complex128: - lwork = _lapack.lapack_zgeqrf_workspace(m, n) - else: - raise NotImplementedError(f"Unsupported dtype {dtype}") - - scalar_layout = [] - batch_size_val = hlo_s32(1) - for b_v in batch_dims_vals: - batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) - shape_type_pairs: Sequence[ShapeTypePair] = [ - (a_shape_vals, a_type.element_type), - (batch_dims_vals + (min(m, n),), a_type.element_type), - (batch_dims_vals, i32_type), - ([lwork], a_type.element_type), - ] - result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) - return custom_call( - fn, - result_types=result_types, - operands=[batch_size_val, hlo_s32(m), hlo_s32(n), hlo_s32(lwork), a], - operand_layouts=[scalar_layout] * 4 + [layout], - result_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - tuple(range(num_bd - 1, -1, -1)), - [0], - ], - operand_output_aliases={4: 0}, - result_shapes=result_shapes, - ).results[:3] - fn = fn_base + "_ffi" - shape_type_pairs: Sequence[ShapeTypePair] = [ - (a_shape_vals, a_type.element_type), - (batch_dims_vals + (min(m, n),), a_type.element_type), - ] - result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) - return custom_call( - fn, - result_types=result_types, - operands=[a], - operand_layouts=[layout], - result_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - ], - operand_output_aliases={0: 0}, - result_shapes=result_shapes, - backend_config={}, - api_version=4, - ).results - - -# # ?orgqr: product of elementary Householder reflectors: -def orgqr_hlo(ctx, dtype, a: ir.Value, tau, *, - a_shape_vals: tuple[DimensionSize, ...], - tau_shape_vals: tuple[DimensionSize, ...]): - fn_base = "un" if dtype == np.complex64 or dtype == np.complex128 else "or" - fn_base = prepare_lapack_call(fn_base=fn_base + "gqr", dtype=dtype) - a_type = ir.RankedTensorType(a.type) - dims = a_type.shape - dims_vals = a_shape_vals - assert len(dims) >= 2 - m, n = dims[-2:] - assert m != ir.ShapedType.get_dynamic_size() - assert n != ir.ShapedType.get_dynamic_size() - batch_dims_vals = dims_vals[:-2] - num_bd = len(batch_dims_vals) - k = tau_shape_vals[-1] - assert type(k) is int - layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - i32_type = ir.IntegerType.get_signless(32) - - if ctx.is_forward_compat(): - fn = fn_base - batch_size_val = hlo_s32(1) - for b_v in batch_dims_vals: - batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) - - if dtype == np.float32: - lwork = _lapack.lapack_sorgqr_workspace(m, n, k) - elif dtype == np.float64: - lwork = _lapack.lapack_dorgqr_workspace(m, n, k) - elif dtype == np.complex64: - lwork = _lapack.lapack_cungqr_workspace(m, n, k) - elif dtype == np.complex128: - lwork = _lapack.lapack_zungqr_workspace(m, n, k) - else: - raise NotImplementedError(f"Unsupported dtype {dtype}") - - scalar_layout = [] - shape_type_pairs: Sequence[ShapeTypePair] = [ - (a_shape_vals, a_type.element_type), - (batch_dims_vals, i32_type), - ([lwork], a_type.element_type), - ] - result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) - return custom_call( - fn, - result_types=result_types, - operands=[batch_size_val, hlo_s32(m), hlo_s32(n), hlo_s32(k), - hlo_s32(lwork), a, tau], - operand_layouts=[scalar_layout] * 5 + [ - layout, - tuple(range(num_bd, -1, -1)), - ], - result_layouts=[ - layout, - tuple(range(num_bd - 1, -1, -1)), - [0], - ], - operand_output_aliases={5: 0}, - result_shapes=result_shapes, - ).results[:2] - fn = fn_base + "_ffi" - shape_type_pairs: Sequence[ShapeTypePair] = [ - (a_shape_vals, a_type.element_type), - ] - result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) - return custom_call( - fn, - result_types=result_types, - operands=[ - a, tau - ], - operand_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - ], - result_layouts=[ - layout, - ], - operand_output_aliases={0: 0}, - result_shapes=result_shapes, - backend_config={}, - api_version=4, - ).results - # ?potrf: Cholesky decomposition @@ -543,120 +340,6 @@ def gesdd_hlo(ctx, dtype, a: ir.Value, *, full_matrices=True, compute_uv=True, ).results[1:] -# # syevd: Symmetric eigendecomposition - -def syevd_hlo(ctx, dtype, a: ir.Value, - a_shape_vals: tuple[DimensionSize, ...], - lower=False): - a_type = ir.RankedTensorType(a.type) - assert len(a_shape_vals) >= 2 - m, n = a_shape_vals[-2:] - # Non-batch dimensions must be static - assert type(m) is int and type(n) is int and m == n, a_shape_vals - - batch_dims_vals = a_shape_vals[:-2] - num_bd = len(a_shape_vals) - 2 - mode = _enum_to_char_attr(eig.ComputationMode.kComputeEigenvectors) - - i32_type = ir.IntegerType.get_signless(32) - workspace: list[ShapeTypePair] - layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - # Hermitian is for complex square matrices, symmetric otherwise. - fn_base = "he" if dtype == np.complex64 or dtype == np.complex128 else "sy" - fn_base = prepare_lapack_call(fn_base=fn_base + "evd", dtype=dtype) - if ctx.is_forward_compat(): - fn = fn_base - if dtype == np.float32: - eigvals_type = ir.F32Type.get() - workspace = [ - ([_lapack.syevd_work_size(n)], a_type.element_type), - ([_lapack.syevd_iwork_size(n)], i32_type), - ] - elif dtype == np.float64: - eigvals_type = ir.F64Type.get() - workspace = [ - ([_lapack.syevd_work_size(n)], a_type.element_type), - ([_lapack.syevd_iwork_size(n)], i32_type), - ] - elif dtype == np.complex64: - eigvals_type = ir.F32Type.get() - workspace = [ - ([_lapack.heevd_work_size(n)], a_type.element_type), - ([_lapack.heevd_rwork_size(n)], eigvals_type), - ([_lapack.syevd_iwork_size(n)], i32_type), - ] - elif dtype == np.complex128: - eigvals_type = ir.F64Type.get() - workspace = [ - ([_lapack.heevd_work_size(n)], a_type.element_type), - ([_lapack.heevd_rwork_size(n)], eigvals_type), - ([_lapack.syevd_iwork_size(n)], i32_type), - ] - else: - raise NotImplementedError(f"Unsupported dtype {dtype}") - - batch_size_val = hlo_s32(1) - for b_v in batch_dims_vals: - batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) - - scalar_layout = [] - shape_layout = [0] - workspace_layouts = [shape_layout] * len(workspace) - layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - - result_types, result_shapes = mk_result_types_and_shapes( - [(a_shape_vals, a_type.element_type), - (batch_dims_vals + (n,), eigvals_type), - (batch_dims_vals, i32_type)] + workspace - ) - - return custom_call( - fn, - result_types=result_types, - operands=[hlo_s32(1 if lower else 0), batch_size_val, ensure_hlo_s32(n), a], - operand_layouts=[scalar_layout] * 3 + [layout], - result_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - tuple(range(num_bd - 1, -1, -1)), - ] + workspace_layouts, - operand_output_aliases={3: 0}, - result_shapes=result_shapes, - ).results[:3] - fn = fn_base + "_ffi" - if dtype == np.float32 or dtype == np.complex64: - eigvals_type = ir.F32Type.get() - elif dtype == np.float64 or dtype == np.complex128: - eigvals_type = ir.F64Type.get() - else: - raise NotImplementedError(f"Unsupported dtype {dtype}") - - result_types, result_shapes = mk_result_types_and_shapes([ - (a_shape_vals, a_type.element_type), - (batch_dims_vals + (n,), eigvals_type), - (batch_dims_vals, i32_type), - ]) - - return custom_call( - fn, - result_types=result_types, - operands=[a], - operand_layouts=[layout], - result_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - tuple(range(num_bd - 1, -1, -1)), - ], - operand_output_aliases={0: 0}, - result_shapes=result_shapes, - backend_config={ - "uplo": _matrix_uplo_attr(lower=lower), - "mode": mode, - }, - api_version=4, - ).results - - # # geev: Nonsymmetric eigendecomposition (eig) def geev_hlo(ctx, dtype, input, *, @@ -879,8 +562,8 @@ def gees_hlo(dtype, a, *, jobvs=True, sort=False, select=None, # gehrd: Reduction of a non-symmetric square matrix to upper Hessenberg form. -def gehrd_hlo(dtype, a): - _lapack.initialize() +def gehrd_hlo(ctx, dtype, a): + fn_base = prepare_lapack_call(fn_base="gehrd", dtype=dtype) a_type = ir.RankedTensorType(a.type) dims = a_type.shape assert len(dims) >= 2 @@ -888,47 +571,68 @@ def gehrd_hlo(dtype, a): assert m == n, (m, n) batch_dims = tuple(dims[:-2]) num_bd = len(batch_dims) - b = 1 - for d in batch_dims: - b *= d + if ctx.is_forward_compat(): + fn = fn_base + b = 1 + for d in batch_dims: + b *= d - if dtype == np.float32: - fn = "lapack_sgehrd" - lwork = _lapack.lapack_sgehrd_workspace(n, n, 1, n) - elif dtype == np.float64: - fn = "lapack_dgehrd" - lwork = _lapack.lapack_dgehrd_workspace(n, n, 1, n) - elif dtype == np.complex64: - fn = "lapack_cgehrd" - lwork = _lapack.lapack_cgehrd_workspace(n, n, 1, n) - elif dtype == np.complex128: - fn = "lapack_zgehrd" - lwork = _lapack.lapack_zgehrd_workspace(n, n, 1, n) - else: - raise NotImplementedError(f"Unsupported dtype {dtype}") + if dtype == np.float32: + lwork = _lapack.lapack_sgehrd_workspace(n, n, 1, n) + elif dtype == np.float64: + lwork = _lapack.lapack_dgehrd_workspace(n, n, 1, n) + elif dtype == np.complex64: + lwork = _lapack.lapack_cgehrd_workspace(n, n, 1, n) + elif dtype == np.complex128: + lwork = _lapack.lapack_zgehrd_workspace(n, n, 1, n) + else: + raise NotImplementedError(f"Unsupported dtype {dtype}") + layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) + i32_type = ir.IntegerType.get_signless(32) + return custom_call( + fn, + result_types=[ + a.type, + ir.RankedTensorType.get(batch_dims + (n - 1,), a_type.element_type), + ir.RankedTensorType.get(batch_dims, i32_type), + ir.RankedTensorType.get([lwork], a_type.element_type), + ], + operands=[hlo_s32(n), hlo_s32(1), hlo_s32(n), hlo_s32(n), hlo_s32(b), + hlo_s32(lwork), a], + operand_layouts=[[]] * 6 + [layout], + result_layouts=[ + layout, + (num_bd,) + tuple(range(num_bd - 1, -1, -1)), + tuple(range(num_bd - 1, -1, -1)), + [0], + ], + operand_output_aliases={6: 0}, + ).results[:3] + fn = fn_base + "_ffi" layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) i32_type = ir.IntegerType.get_signless(32) - out = custom_call( + return custom_call( fn, result_types=[ - a.type, - ir.RankedTensorType.get(batch_dims + (n - 1,), a_type.element_type), - ir.RankedTensorType.get(batch_dims, i32_type), - ir.RankedTensorType.get([lwork], a_type.element_type), + a.type, + ir.RankedTensorType.get(batch_dims + (n - 1,), a_type.element_type), + ir.RankedTensorType.get(batch_dims, i32_type), ], - operands=[hlo_s32(n), hlo_s32(1), hlo_s32(n), hlo_s32(n), hlo_s32(b), - hlo_s32(lwork), a], - operand_layouts=[[]] * 6 + [layout], + operands=[a], + operand_layouts=[layout], result_layouts=[ - layout, - (num_bd,) + tuple(range(num_bd - 1, -1, -1)), - tuple(range(num_bd - 1, -1, -1)), - [0], + layout, + (num_bd,) + tuple(range(num_bd - 1, -1, -1)), + tuple(range(num_bd - 1, -1, -1)), ], - operand_output_aliases={6: 0}, + operand_output_aliases={0: 0}, + backend_config={ + "low": _lapack_int_attr(1), + "high": _lapack_int_attr(n), + }, + api_version=4, ).results - return out[:3] # sytrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form. diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index ffcc8d52cd05..e28c27129d16 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -38,18 +38,8 @@ class TPU_Attr traits = []> let mnemonic = mnemonic_; } -def TPU_Vreg : Type< - And<[IsVectorTypePred, - Or<[ - And<[ - CPred<"llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef{8, 128}">, - CPred<"llvm::cast<::mlir::VectorType>($_self).getElementType().getIntOrFloatBitWidth() == 32"> - ]>, - CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef{" - "8, 128, 32 / ::llvm::cast<::mlir::VectorType>($_self).getElementType().getIntOrFloatBitWidth()}">, - ]> - ]>, - "native-sized vreg", "::mlir::VectorType">; +// TODO(b/369418606): Find out the way to verify vreg size. +def TPU_Vreg : Type; class TPU_Type traits = []> : TypeDef { @@ -738,6 +728,8 @@ def InferMemRefLayoutPass : Pass<"tpu-infer-memref-layout", "::mlir::func::FuncO // If hardware_generation is not set, the default value of -1 will crash on // runOnOperation. Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">, + Option<"lane_count", "lane-count", "int", /*default=*/"128", "">, + Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">, Option<"tpu_tiling_flags", "tpu-tiling-flags", "::mlir::tpu::TpuTilingFlags", /*default=*/"::mlir::tpu::TpuTilingFlags{}", "">, ]; } diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index 00bd15b57153..e827faed3d0e 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -68,13 +68,15 @@ struct ApplyVectorLayoutContext { std::pair mightCommunicateBetweenChips(Operation* op); std::unique_ptr> createInferMemRefLayoutPass( - int hardware_generation = -1, const TpuTilingFlags &tpu_tiling_flags = {}); + int hardware_generation = -1, + std::array target_shape = {8, 128}, + const TpuTilingFlags &tpu_tiling_flags = {}); std::unique_ptr> createCanonicalizeMosaicPass( int hardware_generation = -1); std::unique_ptr> createInferVectorLayoutPass( - int lane_count = 128, int sublane_count = 8); + std::array target_shape = {8, 128}); std::unique_ptr> createApplyVectorLayoutPass( const ApplyVectorLayoutContext &ctx = ApplyVectorLayoutContext{}); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index f6e1c7918646..ae7923d1d3d7 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -164,7 +164,7 @@ FailureOr> getInternalScratch( FAILUREOR_ASSIGN_OR_RETURN( MemRefType scratch_ref_ty, inferMemref(MemRefType::get(shape, elem_ty), ctx.hardware_generation, - /*tpu_tiling_flags=*/{}, sublane_tiling)); + ctx.target_shape, /*tpu_tiling_flags=*/{}, sublane_tiling)); return builder.create(loc, scratch_ref_ty) .getResult(); } @@ -490,7 +490,7 @@ FailureOr appendConstant(RewriteContext &ctx, func::FuncOp func, MemRefType arg_type, inferMemref( MemRefType::get(value_ty.getShape(), value_ty.getElementType()), - ctx.hardware_generation, /*tpu_tiling_flags=*/{})); + ctx.hardware_generation, ctx.target_shape, /*tpu_tiling_flags=*/{})); const BlockArgument argument = entry_block.insertArgument( entry_block.getNumArguments() - 1, arg_type, UnknownLoc::get(mlir_ctx)); const FunctionType func_ty = func.getFunctionType(); @@ -1714,12 +1714,12 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, llvm::all_of(layouts_in, [&](const Layout &l) { return l.has_value(); })); TPU_ASSERT_OP(layouts_out.front().has_value()); auto matmul_op = cast(op); - auto transpose_lhs = matmul_op.getTransposeLhs(); - auto transpose_rhs = matmul_op.getTransposeRhs(); - auto &layout_lhs = *layouts_in[0]; - auto &layout_rhs = *layouts_in[1]; - auto &layout_acc = *layouts_in[2]; - auto layout_out = *layouts_out[0]; + const auto transpose_lhs = matmul_op.getTransposeLhs(); + const auto transpose_rhs = matmul_op.getTransposeRhs(); + const auto &layout_lhs = *layouts_in[0]; + const auto &layout_rhs = *layouts_in[1]; + const auto &layout_acc = *layouts_in[2]; + const auto &layout_out = *layouts_out[0]; if (transpose_lhs) { return op.emitOpError("Not implemented: Transposed LHS"); } @@ -1740,7 +1740,7 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, acc = tpu_matmul_op.getAcc(); res = tpu_matmul_op.getResult(); } else { - LOG(FATAL) << "Unexpected op type"; + return op.emitOpError("Expected a tpu::MatmulOp"); } for (const std::optional &layout_opt : layouts_in) { @@ -1755,7 +1755,7 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, } } if (acc.getType().getElementType().getIntOrFloatBitWidth() != 32) { - return op.emitOpError("Not implemented: Non-32-bit matmul result"); + return op.emitOpError("Not implemented: Non-32-bit matmul acc"); } const ArrayRef lhs_shape = lhs.getType().getShape(); const ArrayRef rhs_shape = rhs.getType().getShape(); @@ -1882,20 +1882,32 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, // At this point, all paddings on vregs are masked out. For now, we // append zero vregs to make LHS's second dim, both RHS's dims and ACC's // second dim to be a multiple of mxu_size. - if (ctx.mxu_shape[0] != ctx.mxu_shape[1]) { - return op.emitOpError( - "Not implemented: MXU contracting size and noncontracting size are " - "different"); - } - int64_t mxu_size = ctx.mxu_shape[0]; - CHECK_EQ(mxu_size % ctx.target_shape[0], 0); - CHECK_EQ(mxu_size % ctx.target_shape[1], 0); - auto mxu_row_vregs = mxu_size / (ctx.target_shape[0] * layout_rhs.packing()); - auto mxu_col_vregs = mxu_size / ctx.target_shape[1]; - int64_t target_lhs_col_vregs = llvm::alignTo(lhs_vregs.dim(1), mxu_col_vregs); - int64_t target_rhs_row_vregs = llvm::alignTo(rhs_vregs.dim(0), mxu_row_vregs); - int64_t target_rhs_col_vregs = llvm::alignTo(rhs_vregs.dim(1), mxu_col_vregs); - int64_t target_acc_col_vregs = llvm::alignTo(acc_vregs.dim(1), mxu_col_vregs); + auto mxu_contracting_size = ctx.mxu_shape[0]; + auto mxu_noncontracting_size = ctx.mxu_shape[1]; + auto rhs_row_size = mxu_contracting_size; + auto rhs_col_size = mxu_noncontracting_size; + if (transpose_rhs) { + rhs_row_size = mxu_noncontracting_size; + rhs_col_size = mxu_contracting_size; + } + CHECK_EQ(rhs_row_size % ctx.target_shape[1], 0); + CHECK_EQ(rhs_col_size % ctx.target_shape[1], 0); + + // Here, a single group corresponds to a single matmul invocation in unrolled + // code. The RHS group matches the MXU shape. + auto lhs_col_vregs_per_group = mxu_contracting_size / ctx.target_shape[1]; + auto rhs_row_vregs_per_group = + rhs_row_size / (ctx.target_shape[0] * layout_rhs.packing()); + auto rhs_col_vregs_per_group = rhs_col_size / ctx.target_shape[1]; + auto acc_col_vregs_per_group = mxu_noncontracting_size / ctx.target_shape[1]; + int64_t target_lhs_col_vregs = + llvm::alignTo(lhs_vregs.dim(1), lhs_col_vregs_per_group); + int64_t target_rhs_row_vregs = + llvm::alignTo(rhs_vregs.dim(0), rhs_row_vregs_per_group); + int64_t target_rhs_col_vregs = + llvm::alignTo(rhs_vregs.dim(1), rhs_col_vregs_per_group); + int64_t target_acc_col_vregs = + llvm::alignTo(acc_vregs.dim(1), acc_col_vregs_per_group); xla::Array target_lhs_vregs({lhs_vregs.dim(0), target_lhs_col_vregs}, lhs_zeros_vreg); @@ -1908,10 +1920,11 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, target_acc_vregs.UpdateSlice(acc_vregs, {0, 0}); // Now we can regroup vregs from target vregs. - const auto lhs_col_ty = VectorType::get({padded_lhs_rows, mxu_size}, - lhs.getType().getElementType()); - const auto acc_col_ty = VectorType::get({padded_lhs_rows, mxu_size}, - acc.getType().getElementType()); + const auto lhs_col_ty = VectorType::get( + {padded_lhs_rows, mxu_contracting_size}, lhs.getType().getElementType()); + const auto acc_col_ty = + VectorType::get({padded_lhs_rows, mxu_noncontracting_size}, + acc.getType().getElementType()); const ArrayAttr lhs_layout_attr = builder.getArrayAttr({builder.getAttr(layout_lhs)}); const ArrayAttr rhs_layout_attr = @@ -1919,40 +1932,39 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, const ArrayAttr acc_layout_attr = builder.getArrayAttr({builder.getAttr(layout_acc)}); - int64_t nk = llvm::divideCeil(lhs_shape[1], mxu_size); - CHECK_EQ(nk, target_lhs_vregs.dim(1) / mxu_col_vregs); + int64_t nk = llvm::divideCeil(lhs_shape[1], mxu_contracting_size); + CHECK_EQ(nk, target_lhs_vregs.dim(1) / lhs_col_vregs_per_group); SmallVector lhs_cols(nk); for (int64_t i = 0; i < nk; ++i) { const xla::Array col_vregs = target_lhs_vregs.Slice( - {0, i * mxu_col_vregs}, - {target_lhs_vregs.dim(0), (i + 1) * mxu_col_vregs}); + {0, i * lhs_col_vregs_per_group}, + {target_lhs_vregs.dim(0), (i + 1) * lhs_col_vregs_per_group}); lhs_cols[i] = builder.create( op.getLoc(), lhs_col_ty, XlaArrayToFlatArrayRef(col_vregs)); lhs_cols[i]->setAttr("out_layout", lhs_layout_attr); } - // Here, "tile" is used as in the context of the MXU shape (NOT as in the - // context of tiled layouts). - const auto rhs_tile_ty = - VectorType::get({mxu_size, mxu_size}, rhs.getType().getElementType()); - const int64_t rhs_vregs_per_tile = mxu_row_vregs * mxu_col_vregs; + const auto rhs_group_ty = VectorType::get({rhs_row_size, rhs_col_size}, + rhs.getType().getElementType()); + const int64_t rhs_vregs_per_group = + rhs_row_vregs_per_group * rhs_col_vregs_per_group; int64_t nj; if (transpose_rhs) { - nj = llvm::divideCeil(rhs_shape[0], mxu_size); - CHECK_EQ(nk, llvm::divideCeil(rhs_shape[1], mxu_size)); - CHECK_EQ(nk, target_rhs_vregs.dim(1) / mxu_col_vregs); - target_rhs_vregs.Reshape( - {nj, rhs_vregs_per_tile / mxu_col_vregs, nk, mxu_col_vregs}); + nj = llvm::divideCeil(rhs_shape[0], rhs_row_size); + CHECK_EQ(nk, llvm::divideCeil(rhs_shape[1], rhs_col_size)); + CHECK_EQ(nk, target_rhs_vregs.dim(1) / rhs_col_vregs_per_group); + target_rhs_vregs.Reshape({nj, rhs_vregs_per_group / rhs_col_vregs_per_group, + nk, rhs_col_vregs_per_group}); target_rhs_vregs.TransposeDimensions({2, 0, 1, 3}); - target_rhs_vregs.Reshape({nk, nj, rhs_vregs_per_tile}); + target_rhs_vregs.Reshape({nk, nj, rhs_vregs_per_group}); } else { - nj = llvm::divideCeil(rhs_shape[1], mxu_size); - CHECK_EQ(nk, llvm::divideCeil(rhs_shape[0], mxu_size)); - CHECK_EQ(nk, target_rhs_vregs.dim(0) / mxu_row_vregs); - target_rhs_vregs.Reshape( - {nk, rhs_vregs_per_tile / mxu_col_vregs, nj, mxu_col_vregs}); + nj = llvm::divideCeil(rhs_shape[1], rhs_col_size); + CHECK_EQ(nk, llvm::divideCeil(rhs_shape[0], rhs_row_size)); + CHECK_EQ(nk, target_rhs_vregs.dim(0) / rhs_row_vregs_per_group); + target_rhs_vregs.Reshape({nk, rhs_vregs_per_group / rhs_col_vregs_per_group, + nj, rhs_col_vregs_per_group}); target_rhs_vregs.TransposeDimensions({0, 2, 1, 3}); - target_rhs_vregs.Reshape({nk, nj, rhs_vregs_per_tile}); + target_rhs_vregs.Reshape({nk, nj, rhs_vregs_per_group}); } const tpu::ContractPrecisionAttr precision_attr = // May be null @@ -1960,28 +1972,29 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, for (int64_t j = 0; j < nj; ++j) { for (int64_t k = 0; k < nk; ++k) { // TODO(tlongeri): there should be a way to slice without copying - xla::Array rhs_tile = - target_rhs_vregs.Slice({k, j, 0}, {k + 1, j + 1, rhs_vregs_per_tile}); - auto rhs_rolled_tile = builder.create( - op.getLoc(), rhs_tile_ty, XlaArrayToFlatArrayRef(rhs_tile)); - rhs_rolled_tile->setAttr("out_layout", rhs_layout_attr); + xla::Array rhs_group = target_rhs_vregs.Slice( + {k, j, 0}, {k + 1, j + 1, rhs_vregs_per_group}); + auto rhs_rolled_group = builder.create( + op.getLoc(), rhs_group_ty, XlaArrayToFlatArrayRef(rhs_group)); + rhs_rolled_group->setAttr("out_layout", rhs_layout_attr); const xla::Array acc_col_vregs = target_acc_vregs.Slice( - {0, j * mxu_col_vregs}, - {target_acc_vregs.dim(0), (j + 1) * mxu_col_vregs}); + {0, j * acc_col_vregs_per_group}, + {target_acc_vregs.dim(0), (j + 1) * acc_col_vregs_per_group}); auto acc_col = builder.create( op.getLoc(), acc_col_ty, XlaArrayToFlatArrayRef(acc_col_vregs)); acc_col->setAttr("out_layout", acc_layout_attr); auto new_acc_col = builder.create( - op.getLoc(), acc_col_ty, lhs_cols[k], rhs_rolled_tile, acc_col, + op.getLoc(), acc_col_ty, lhs_cols[k], rhs_rolled_group, acc_col, transpose_lhs, transpose_rhs, precision_attr); auto new_acc_vregs = builder.create( op.getLoc(), TypeRange(ValueRange(XlaArrayToFlatArrayRef(acc_col_vregs))), new_acc_col); new_acc_vregs->setAttr("in_layout", acc_layout_attr); - updateSliceFromRange(target_acc_vregs, new_acc_vregs->getResults(), - {0, j * mxu_col_vregs}, - {target_acc_vregs.dim(0), (j + 1) * mxu_col_vregs}); + updateSliceFromRange( + target_acc_vregs, new_acc_vregs->getResults(), + {0, j * acc_col_vregs_per_group}, + {target_acc_vregs.dim(0), (j + 1) * acc_col_vregs_per_group}); } } op.replaceAllUsesWith( @@ -5808,8 +5821,8 @@ FailureOr>> changeTiling( if (try_replicate_rows && packing == 1 && *(vregs.dimensions().end() - 2) == 1 && src.offsets() == LayoutOffsets{0, 0} && - src.tiling() == std::array{1, 128} && - dst_tiling == std::array{8, 128}) { + src.tiling() == std::array{1, ctx.target_shape[1]} && + dst_tiling == ctx.target_shape) { xla::Array retiled(dst_tiles_shape); retiled.Each([&](absl::Span idx, Value *tile) { SmallVector src_idx(idx.begin(), idx.end()); @@ -5826,9 +5839,9 @@ FailureOr>> changeTiling( return std::pair(dst, std::move(retiled)); } // (8,128) -> (8 * packing,128) tiling change for packed type. - if (bitwidth < 32 && 32 % bitwidth == 0 && - src_tiling == std::array{8, 128} && - dst_tiling == std::array{8 * dst.packing(), 128}) { + if (bitwidth < 32 && 32 % bitwidth == 0 && src_tiling == ctx.target_shape && + dst_tiling == std::array{ctx.target_shape[0] * dst.packing(), + ctx.target_shape[1]}) { // Note: for int4, retiling with scratch is always faster. if (bitwidth != 4 || !has_enough_scratch) { xla::Array retiled(dst_tiles_shape); @@ -5870,8 +5883,8 @@ FailureOr>> changeTiling( // match corresponding elements without shifting. It's just that // the tiles are not adjacent (no contiguous vreg slice). if (bitwidth < 32 && 32 % bitwidth == 0 && - src_tiling == std::array{1, 128 * packing} && - dst_tiling == std::array{packing, 128}) { + src_tiling == std::array{1, ctx.target_shape[1] * packing} && + dst_tiling == std::array{packing, ctx.target_shape[1]}) { // To illustrate, consider a 2 x 16 16-bit shape laid out in vregs of // 4 sublanes and 2 lanes (this is convenient for to keep the example small // yet non-trivial) with (1, 4) tiling. We will relayout to (2, 2) tiling. diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index e70e01dfbce7..9f2a8ed73a44 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -332,6 +332,24 @@ LogicalResult canonicalize_extract(int hardware_generation, Operation &raw_op) { return success(); } +LogicalResult canonicalize_select(int hardware_generation, Operation &raw_op) { + auto op = dyn_cast(raw_op); + if (!isa(op.getType()) || + isa(op.getCondition().getType())) { + return success(); + } + // Canonicalize `i1 ? v1 : v2` -> `broadcast(i1) ? v1 : v2`. + ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation()); + auto cond_ty = VectorType::get(cast(op.getType()).getShape(), + op.getCondition().getType()); + auto cond = builder.create(cond_ty, op.getCondition()); + auto new_op = builder.create( + op.getLoc(), cond, op.getTrueValue(), op.getFalseValue()); + op.replaceAllUsesWith(new_op.getResult()); + op.erase(); + return success(); +} + using canonicalize_rule_type = std::function; @@ -341,7 +359,8 @@ const llvm::StringMap &rules() { {vector::ContractionOp::getOperationName(), canonicalize_contraction}, {vector::ContractionOp::getOperationName(), canonicalize_extract}, {vector::MultiDimReductionOp::getOperationName(), - canonicalize_multi_dim_reduction}}; + canonicalize_multi_dim_reduction}, + {arith::SelectOp::getOperationName(), canonicalize_select}}; return *rules; } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc index 6563d6533b1f..541393fc2758 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc @@ -1,6 +1,7 @@ #include "jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h" #include +#include #include #include #include @@ -33,16 +34,17 @@ namespace mlir::tpu { #define GEN_PASS_DEF_INFERMEMREFLAYOUTPASS #include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" - -// Returns the number of 128-element groups in a tile. +// Returns the number of lanes (usually 128) groups in a tile. // // Arguments: -// num_128s: A number of 128-element groups in the full operand. +// num_lanes: A number of lanes in the full operand. // hardware_generation: An integer indicating the target TPU generation. +// sublane_count: The number of sublanes. // tpu_tiling_flags: A struct of flags indicating which large tiling modes are // enabled by XLA for memrefs. // bitwidth: The bitwidth of the element type of the operand. -int getTilingFactor(const int num_128s, const int hardware_generation, +int getTilingFactor(const int num_lanes, const int hardware_generation, + const int64_t sublane_count, const TpuTilingFlags &tpu_tiling_flags, const int8_t bitwidth) { CHECK(llvm::isPowerOf2_32(bitwidth)); @@ -50,29 +52,29 @@ int getTilingFactor(const int num_128s, const int hardware_generation, CHECK_LE(bitwidth, 32); const int packing = 32 / bitwidth; const int min_tiling = (1 + (hardware_generation < 4)) * packing; - const int max_normal_tiling = 8; + const int max_normal_tiling = sublane_count; const int large_tiling = [&] { if (bitwidth == 4 && tpu_tiling_flags.use_x4_large_second_minor) { - return 64; + return sublane_count * 8; } if (bitwidth == 8 && tpu_tiling_flags.use_x8_large_second_minor) { - return 32; + return sublane_count * 4; } if (bitwidth == 16 && tpu_tiling_flags.use_x16_large_second_minor) { - return 16; + return sublane_count * 2; } - return 8; + return sublane_count; }(); // Use large tiling if our operand is tall enough to fit at least one full // tile. - if (large_tiling <= num_128s) { + if (large_tiling <= num_lanes) { return large_tiling; } int tiling = min_tiling; - while (tiling < std::min(num_128s, max_normal_tiling)) { + while (tiling < std::min(num_lanes, max_normal_tiling)) { tiling *= 2; } return tiling; @@ -80,6 +82,7 @@ int getTilingFactor(const int num_128s, const int hardware_generation, FailureOr inferLayout(MemRefType memref_ty, const int hardware_generation, + std::array target_shape, const TpuTilingFlags &tpu_tiling_flags, int64_t leading_tile_rows = 0) { if (auto tiled_layout_attr = @@ -100,12 +103,14 @@ FailureOr inferLayout(MemRefType memref_ty, "Invalid element type for memref"); } const int8_t bitwidth = memref_ty.getElementTypeBitWidth(); + const auto [sublane_count, lane_count] = target_shape; // Infer the layout if (memref_ty.getRank() == 1) { const int64_t leading_tile = - getTilingFactor(llvm::divideCeil(memref_ty.getShape().back(), 128), - hardware_generation, tpu_tiling_flags, bitwidth) * - 128; + getTilingFactor( + llvm::divideCeil(memref_ty.getShape().back(), lane_count), + hardware_generation, sublane_count, tpu_tiling_flags, bitwidth) * + lane_count; SmallVector tiles{xla::Tile({leading_tile})}; if (bitwidth != 32) { if (!llvm::has_single_bit(bitwidth) || bitwidth > 32) { @@ -113,7 +118,7 @@ FailureOr inferLayout(MemRefType memref_ty, "Unsupported bitwidth: ") << bitwidth; } - tiles.append({xla::Tile({128}), xla::Tile({32 / bitwidth, 1})}); + tiles.append({xla::Tile({lane_count}), xla::Tile({32 / bitwidth, 1})}); } return TiledLayoutAttr::get(memref_ty.getContext(), tiles, {1}); } @@ -122,10 +127,11 @@ FailureOr inferLayout(MemRefType memref_ty, const ArrayRef shape = memref_ty.getShape(); const int64_t second_minor = shape[shape.size() - 2]; if (leading_tile_rows == 0) { - leading_tile_rows = getTilingFactor(second_minor, hardware_generation, - tpu_tiling_flags, bitwidth); + leading_tile_rows = + getTilingFactor(second_minor, hardware_generation, sublane_count, + tpu_tiling_flags, bitwidth); } - SmallVector tiles{xla::Tile({leading_tile_rows, 128})}; + SmallVector tiles{xla::Tile({leading_tile_rows, lane_count})}; if (bitwidth != 32) { if (!llvm::has_single_bit(bitwidth) || bitwidth > 32) { return emitError(UnknownLoc::get(memref_ty.getContext()), @@ -134,7 +140,8 @@ FailureOr inferLayout(MemRefType memref_ty, } tiles.push_back(xla::Tile({32 / bitwidth, 1})); } - auto tile_strides = ComputeTileStrides(memref_ty, {leading_tile_rows, 128}); + auto tile_strides = + ComputeTileStrides(memref_ty, {leading_tile_rows, lane_count}); return TiledLayoutAttr::get(memref_ty.getContext(), tiles, tile_strides); } return emitError(UnknownLoc::get(memref_ty.getContext()), @@ -167,6 +174,7 @@ LogicalResult checkTiles(MLIRContext *mlir_ctx, FailureOr inferMemref(MemRefType memref, const int hardware_generation, + std::array target_shape, const TpuTilingFlags &tpu_tiling_flags, int64_t leading_tile_rows) { if (isa(memref.getElementType())) { @@ -188,9 +196,10 @@ FailureOr inferMemref(MemRefType memref, tpu::MemorySpaceAttr::get(memref.getContext(), MemorySpace::vmem); const Attribute memory_space = memref.getMemorySpace() == nullptr ? vmem : memref.getMemorySpace(); - FAILUREOR_ASSIGN_OR_RETURN(const TiledLayoutAttr layout, - inferLayout(memref, hardware_generation, - tpu_tiling_flags, leading_tile_rows)); + FAILUREOR_ASSIGN_OR_RETURN( + const TiledLayoutAttr layout, + inferLayout(memref, hardware_generation, target_shape, tpu_tiling_flags, + leading_tile_rows)); const ArrayRef tiles = layout.getTiles(); if (failed(checkTiles(memref.getContext(), tiles))) { @@ -212,13 +221,14 @@ FailureOr inferMemref(MemRefType memref, } LogicalResult inferOp(Operation &op, const int hardware_generation, + std::array target_shape, const TpuTilingFlags &tpu_tiling_flags) { if (auto alloca_op = dyn_cast(op)) { TypedValue arg = alloca_op.getResult(); const MemRefType memref_ty = alloca_op.getResult().getType(); - FAILUREOR_ASSIGN_OR_RETURN( - const MemRefType new_memref_ty, - inferMemref(memref_ty, hardware_generation, tpu_tiling_flags)); + FAILUREOR_ASSIGN_OR_RETURN(const MemRefType new_memref_ty, + inferMemref(memref_ty, hardware_generation, + target_shape, tpu_tiling_flags)); alloca_op.getResult().setType(new_memref_ty); if (memref_ty != new_memref_ty) { OpBuilder builder(alloca_op->getContext()); @@ -233,9 +243,9 @@ LogicalResult inferOp(Operation &op, const int hardware_generation, } else if (auto alloca_op = dyn_cast(op)) { TypedValue arg = alloca_op.getResult(); const MemRefType memref_ty = alloca_op.getResult().getType(); - FAILUREOR_ASSIGN_OR_RETURN( - const MemRefType new_memref_ty, - inferMemref(memref_ty, hardware_generation, tpu_tiling_flags)); + FAILUREOR_ASSIGN_OR_RETURN(const MemRefType new_memref_ty, + inferMemref(memref_ty, hardware_generation, + target_shape, tpu_tiling_flags)); alloca_op.getResult().setType(new_memref_ty); if (memref_ty != new_memref_ty) { OpBuilder builder(alloca_op->getContext()); @@ -251,7 +261,8 @@ LogicalResult inferOp(Operation &op, const int hardware_generation, for (Region ®ion : op.getRegions()) { for (Block &block : region) { for (Operation& op : block) { - if (failed(inferOp(op, hardware_generation, tpu_tiling_flags))) { + if (failed(inferOp(op, hardware_generation, target_shape, + tpu_tiling_flags))) { return failure(); } } @@ -261,6 +272,7 @@ LogicalResult inferOp(Operation &op, const int hardware_generation, } LogicalResult inferFunc(func::FuncOp f, const int hardware_generation, + std::array target_shape, const TpuTilingFlags &tpu_tiling_flags) { if (!f.getBody().hasOneBlock()) { return f.emitOpError("Functions should only have a single block"); @@ -285,8 +297,8 @@ LogicalResult inferFunc(func::FuncOp f, const int hardware_generation, FAILUREOR_ASSIGN_OR_RETURN( const MemRefType new_memref_ty, - inferMemref(memref_ty, hardware_generation, tpu_tiling_flags, - leading_tile_rows)); + inferMemref(memref_ty, hardware_generation, target_shape, + tpu_tiling_flags, leading_tile_rows)); arg.setType(new_memref_ty); new_arg_types.push_back(arg.getType()); if (memref_ty != new_memref_ty) { @@ -305,30 +317,8 @@ LogicalResult inferFunc(func::FuncOp f, const int hardware_generation, f.setFunctionType( builder.getAttr(new_arg_types, f.getResultTypes())); for (Operation &op : entry.getOperations()) { - if (failed(inferOp(op, hardware_generation, tpu_tiling_flags))) { - return failure(); - } - } - return success(); -} - -// Infers the layout and memory space attributes of function memref arguments. -// -// In the future we should require those annotations from Mosaic users, but it's -// best to keep them internal for as long as they are under development. -// -// Arguments: -// module: The MLIR module on which to perform the inference. -// hardware_generation: The TPU hardware generation to target. -LogicalResult inferModule(ModuleOp module, const int hardware_generation, - const TpuTilingFlags &tpu_tiling_flags) { - // TODO(apaszke): Do layout assignment for scoped allocations too. - for (Operation &op : *module.getBody()) { - auto f = dyn_cast(op); - if (f == nullptr) { - return module.emitOpError("Expected only FuncOps but found ") << op; - } - if (failed(inferFunc(f, hardware_generation, tpu_tiling_flags))) { + if (failed( + inferOp(op, hardware_generation, target_shape, tpu_tiling_flags))) { return failure(); } } @@ -338,8 +328,11 @@ LogicalResult inferModule(ModuleOp module, const int hardware_generation, struct InferMemRefLayoutPass : public impl::InferMemRefLayoutPassBase { InferMemRefLayoutPass(int hardware_generation_, + std::array target_shape_, const TpuTilingFlags &tpu_tiling_flags_) { hardware_generation = hardware_generation_; + sublane_count = target_shape_[0]; + lane_count = target_shape_[1]; tpu_tiling_flags = tpu_tiling_flags_; } void runOnOperation() override { @@ -349,7 +342,8 @@ struct InferMemRefLayoutPass return; } func::FuncOp func = getOperation(); - if (failed(inferFunc(func, hardware_generation, tpu_tiling_flags))) { + if (failed(inferFunc(func, hardware_generation, {sublane_count, lane_count}, + tpu_tiling_flags))) { signalPassFailure(); return; } @@ -357,9 +351,10 @@ struct InferMemRefLayoutPass }; std::unique_ptr> createInferMemRefLayoutPass( - int hardware_generation, const TpuTilingFlags &tpu_tiling_flags_) { - return std::make_unique(hardware_generation, - tpu_tiling_flags_); + int hardware_generation, std::array target_shape, + const TpuTilingFlags &tpu_tiling_flags_) { + return std::make_unique( + hardware_generation, target_shape, tpu_tiling_flags_); } } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h index c869ef40bf46..ed2a34793536 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h @@ -1,7 +1,9 @@ #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_MEMREF_LAYOUT_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_MEMREF_LAYOUT_H_ -#include +#include +#include +#include #include "mlir/IR/BuiltinTypes.h" #include "mlir/Support/LogicalResult.h" @@ -10,6 +12,7 @@ namespace mlir::tpu { FailureOr inferMemref(MemRefType memref, int hardware_generation, + std::array target_shape, const TpuTilingFlags& tpu_tiling_flags, int64_t leading_tile_rows = 0); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 2894b0797e7b..50edf0833227 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -21,9 +21,7 @@ limitations under the License. #include #include #include -#include #include -#include #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVectorExtras.h" @@ -34,7 +32,6 @@ limitations under the License. #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/IR/AffineExpr.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Value.h" @@ -42,6 +39,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/types/span.h" #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" #include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" #include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" @@ -49,6 +47,7 @@ limitations under the License. #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/include/mlir/IR/OpDefinition.h" #include "mlir/include/mlir/IR/Visitors.h" +#include "mlir/include/mlir/Pass/Pass.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "xla/layout.h" @@ -144,11 +143,9 @@ class VectorLayoutInferer { if (!isa(any_op)) { const SmallVector layouts_in = getLayoutFromOperands(&any_op); for (const Layout &layout : layouts_in) { - if (layout && layout->offsets()[1].has_value() && - layout->offsets()[1].value() > layout->tiling()[1]) { - return any_op.emitOpError( - "Not implemented: Inferring from input offsets outside of the " - "first tile"); + if (layout && + layout->offsets()[1].value_or(0) >= layout->tiling()[1]) { + force_first_tile_offsets_ = true; } } } @@ -349,6 +346,7 @@ class VectorLayoutInferer { } CHECK(any_op.getNumResults() == 0 || any_op.hasAttr("out_layout")); CHECK(any_op.getNumOperands() == 0 || any_op.hasAttr("in_layout")); + force_first_tile_offsets_ = false; } return match_terminator(block.getTerminator()); } @@ -856,7 +854,7 @@ class VectorLayoutInferer { auto shape = vty.getShape().take_back(2); if (shape[0] % major_multiple.value_or(tiling[0]) != 0 || shape[1] % minor_multiple.value_or(tiling[1]) != 0) { - op->emitOpError("Matmul operand") + op->emitOpError("Matmul operand ") << operand_name << " must have a shape divisible by (" << major_multiple.value_or(tiling[0]) << ", " << minor_multiple.value_or(tiling[1]) << "), but got: (" << shape[0] @@ -1940,7 +1938,14 @@ class VectorLayoutInferer { auto result_index = op_result.getResultNumber(); auto out_attrs = op->getAttrOfType("out_layout").getValue(); CHECK(out_attrs.size() > result_index); - return cast(out_attrs[result_index]).getLayout(); + auto layout = cast(out_attrs[result_index]).getLayout(); + if (force_first_tile_offsets_ && + layout->offsets()[1].value_or(0) >= layout->tiling()[1]) { + // Force the out-of-first-tile offset to be zero. + layout = VectorLayout(layout->bitwidth(), {layout->offsets()[0], 0}, + layout->tiling(), layout->implicit_dim()); + } + return layout; } SmallVector getLayoutFromOperands(Operation *op) { @@ -2024,6 +2029,10 @@ class VectorLayoutInferer { std::array target_shape_; std::array default_tiling_; + // TODO(b/342235360): Deprecate force_first_tile_offsets_ once we fully + // remove the restriction that offsets must fall within the first tile. + bool force_first_tile_offsets_ = false; + // Address alignment requirement, counted in 32-bit increments. static constexpr int64_t kVmemAlignment32 = 128; // TODO(apaszke): This is not really native on newer generations of TPUs. @@ -2033,9 +2042,9 @@ class VectorLayoutInferer { struct InferVectorLayoutPass : public impl::InferVectorLayoutPassBase { - InferVectorLayoutPass(int lane_count, int sublane_count) { - this->sublane_count = sublane_count; - this->lane_count = lane_count; + InferVectorLayoutPass(std::array target_shape) { + this->sublane_count = target_shape[0]; + this->lane_count = target_shape[1]; } void runOnOperation() override { func::FuncOp func = getOperation(); @@ -2049,8 +2058,8 @@ struct InferVectorLayoutPass } // namespace std::unique_ptr> createInferVectorLayoutPass( - int lane_count, int sublane_count) { - return std::make_unique(lane_count, sublane_count); + std::array target_shape) { + return std::make_unique(target_shape); } } // namespace mlir::tpu diff --git a/jaxlib/rocm_plugin_extension.cc b/jaxlib/rocm_plugin_extension.cc index 0100b37b22e9..c6855879e8be 100644 --- a/jaxlib/rocm_plugin_extension.cc +++ b/jaxlib/rocm_plugin_extension.cc @@ -14,7 +14,9 @@ limitations under the License. ==============================================================================*/ #include +#include #include +#include #include #include "nanobind/nanobind.h" @@ -34,7 +36,8 @@ namespace nb = nanobind; namespace xla { namespace { -absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name, +absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, + const char* fn_name_c_str, size_t fn_name_size, nb::object fn, int api_version, XLA_FFI_Handler_Traits traits) { if (c_api->extension_start == nullptr) { @@ -59,8 +62,8 @@ absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name, PJRT_Gpu_Register_Custom_Call_Args args; args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE; - args.function_name = fn_name.c_str(); - args.function_name_size = nb::len(fn_name); + args.function_name = fn_name_c_str; + args.function_name_size = fn_name_size; #if PJRT_API_GPU_EXTENSION_VERSION >= 1 args.api_version = api_version; @@ -179,12 +182,23 @@ NB_MODULE(rocm_plugin_extension, m) { tsl::ImportNumpy(); m.def( "register_custom_call_target", - [](nb::capsule c_api, nb::str fn_name, nb::object fn, + [](nb::capsule c_api, nb::object fn_name_py, nb::capsule fn, nb::str xla_platform_name, int api_version, XLA_FFI_Handler_Traits traits) { + const char* fn_name_c_str; + size_t fn_name_size; + nb::str fn_name_bn_str; + if (nb::try_cast(fn_name_py, fn_name_bn_str)) { + fn_name_c_str = fn_name_bn_str.c_str(); + fn_name_size = nb::len(fn_name_bn_str); + } else{ + nb::bytes bytes = nb::cast(fn_name_py); + fn_name_c_str = bytes.c_str(); + fn_name_size = bytes.size(); + } xla::ThrowIfError(RegisterCustomCallTarget( - static_cast(c_api.data()), fn_name, std::move(fn), - api_version, traits)); + static_cast(c_api.data()), fn_name_c_str, + fn_name_size, std::move(fn), api_version, traits)); }, nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"), nb::arg("xla_platform_name"), nb::arg("api_version") = 0, diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 3c40c2d11fb5..cbbce31f1be5 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -104,7 +104,6 @@ def patch_copy_mlir_import(src_file, dst_dir): "ifrt_proxy.pyi", "jax_jit.pyi", "ops.pyi", - "outfeed_receiver.pyi", "pmap_lib.pyi", "profiler.pyi", "pytree.pyi", diff --git a/pyproject.toml b/pyproject.toml index 3423783e2407..3b10f4f3607b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ module = [ "jraph.*", "libtpu.*", "matplotlib.*", + "nvidia.*", "numpy.*", "opt_einsum.*", "optax.*", @@ -72,30 +73,6 @@ doctest_optionflags = [ ] addopts = "--doctest-glob='*.rst' --ignore='examples/ffi'" -[tool.pylint.master] -extension-pkg-whitelist = "numpy" - -[tool.pylint."messages control"] -disable = [ - "missing-docstring", - "too-many-locals", - "invalid-name", - "redefined-outer-name", - "redefined-builtin", - "protected-name", - "no-else-return", - "fixme", - "protected-access", - "too-many-arguments", - "blacklisted-name", - "too-few-public-methods", - "unnecessary-lambda" -] -enable = "c-extension-no-member" - -[tool.pylint.format] -indent-string=" " - [tool.ruff] preview = true exclude = [ diff --git a/setup.py b/setup.py index 762b5ad7a281..8f558ccf0434 100644 --- a/setup.py +++ b/setup.py @@ -19,10 +19,10 @@ project_name = 'jax' -_current_jaxlib_version = '0.4.33' +_current_jaxlib_version = '0.4.34' # The following should be updated after each new jaxlib release. -_latest_jaxlib_version_on_pypi = '0.4.33' -_libtpu_version = '0.1.dev20240916' +_latest_jaxlib_version_on_pypi = '0.4.34' +_libtpu_version = '0.1.dev20241002' def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location( diff --git a/tests/BUILD b/tests/BUILD index df9a28236e6a..615437ce4164 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -34,6 +34,7 @@ jax_generate_backend_suites() jax_multiplatform_test( name = "api_test", srcs = ["api_test.py"], + enable_configs = ["tpu_v3_2x2"], shard_count = 10, ) @@ -70,6 +71,9 @@ jax_multiplatform_test( "cpu", "gpu", ], + enable_configs = [ + "gpu_2gpu", + ], tags = ["multiaccelerator"], deps = py_deps("tensorflow_core"), ) @@ -82,9 +86,13 @@ jax_multiplatform_test( }, ) -jax_multiplatform_test( +jax_py_test( name = "config_test", srcs = ["config_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ] + py_deps("absl/testing"), ) jax_multiplatform_test( @@ -96,11 +104,6 @@ jax_multiplatform_test( }, ) -jax_multiplatform_test( - name = "custom_object_test", - srcs = ["custom_object_test.py"], -) - jax_multiplatform_test( name = "debug_nans_test", srcs = ["debug_nans_test.py"], @@ -215,6 +218,18 @@ jax_py_test( jax_multiplatform_test( name = "memories_test", srcs = ["memories_test.py"], + enable_configs = [ + "cpu", + "gpu_2gpu", + "tpu_v3_2x2", + "tpu_v4_2x2", + "tpu_v5p_2x2", + "tpu_v5e_4x2", + "cpu_shardy", + "gpu_2gpu_shardy", + "tpu_v3_2x2_shardy", + "tpu_v5e_4x2_shardy", + ], shard_count = { "tpu": 5, }, @@ -235,6 +250,8 @@ jax_multiplatform_test( "gpu_2gpu_shardy", "tpu_v3_2x2_shardy", "tpu_v4_2x2_shardy", + "tpu_v3_2x2", + "gpu_2gpu", ], shard_count = { "cpu": 5, @@ -259,6 +276,11 @@ jax_multiplatform_test( jax_multiplatform_test( name = "shard_alike_test", srcs = ["shard_alike_test.py"], + enable_configs = [ + "tpu_v3_2x2", + "tpu_v5e_4x2", + "tpu_v4_2x2", + ], deps = [ "//jax:experimental", ], @@ -299,6 +321,9 @@ jax_multiplatform_test( backend_tags = { "tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit. }, + enable_configs = [ + "tpu_v3_2x2", + ], tags = ["multiaccelerator"], deps = [ "//jax:experimental", @@ -331,7 +356,6 @@ jax_multiplatform_test( name = "infeed_test", srcs = ["infeed_test.py"], deps = [ - "//jax:experimental_host_callback", ], ) @@ -397,11 +421,6 @@ jax_multiplatform_test( jax_multiplatform_test( name = "custom_linear_solve_test", srcs = ["custom_linear_solve_test.py"], - shard_count = { - "cpu": 10, - "gpu": 10, - "tpu": 10, - }, ) jax_multiplatform_test( @@ -651,6 +670,10 @@ jax_py_test( jax_multiplatform_test( name = "multibackend_test", srcs = ["multibackend_test.py"], + enable_configs = [ + "tpu_v3_2x2", + "gpu_2gpu", + ], ) jax_multiplatform_test( @@ -700,6 +723,10 @@ jax_multiplatform_test( "requires-mem:16g", # Under tsan on 2x2 this test exceeds the default 12G memory limit. ], }, + enable_configs = [ + "gpu_v100", + "tpu_v3_2x2", + ], shard_count = { "cpu": 30, "gpu": 30, @@ -743,13 +770,6 @@ jax_multiplatform_test( jax_multiplatform_test( name = "pytorch_interoperability_test", srcs = ["pytorch_interoperability_test.py"], - # The following cases are disabled because they time out in Google's CI, mostly because the - # CUDA kernels in Torch take a very long time to compile. - disable_configs = [ - "gpu_p100", # Pytorch P100 build times out in Google's CI. - "gpu_a100", # Pytorch A100 build times out in Google's CI. - "gpu_h100", # Pytorch H100 build times out in Google's CI. - ], enable_backends = [ "cpu", "gpu", @@ -858,7 +878,7 @@ jax_multiplatform_test( "nomsan", ], # Times out on TPU with asan/tsan/msan. }, - shard_count = 4, + shard_count = 12, ) jax_multiplatform_test( @@ -1037,6 +1057,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "checkify_test", srcs = ["checkify_test.py"], + enable_configs = ["tpu_v3_2x2"], shard_count = { "gpu": 2, "tpu": 4, @@ -1148,49 +1169,6 @@ jax_multiplatform_test( deps = ["//jax:ode"], ) -jax_multiplatform_test( - name = "host_callback_outfeed_test", - srcs = ["host_callback_test.py"], - args = ["--jax_host_callback_outfeed=true"], - shard_count = { - "tpu": 5, - }, - tags = [ - "noasan", # Times out. - ], - deps = [ - "//jax:experimental", - "//jax:experimental_host_callback", - "//jax:ode", - ], -) - -jax_multiplatform_test( - name = "host_callback_test", - srcs = ["host_callback_test.py"], - args = ["--jax_host_callback_outfeed=false"], - main = "host_callback_test.py", - shard_count = { - "gpu": 5, - }, - tags = ["noasan"], # Times out - deps = [ - "//jax:experimental", - "//jax:experimental_host_callback", - "//jax:ode", - ], -) - -jax_multiplatform_test( - name = "host_callback_to_tf_test", - srcs = ["host_callback_to_tf_test.py"], - tags = ["noasan"], # Linking TF causes a linker OOM. - deps = [ - "//jax:experimental_host_callback", - "//jax:ode", - ] + py_deps("tensorflow_core"), -) - jax_multiplatform_test( name = "key_reuse_test", srcs = ["key_reuse_test.py"], @@ -1237,8 +1215,11 @@ jax_multiplatform_test( "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, enable_configs = [ - "gpu_h100", "cpu", + "gpu_h100", + "tpu_v2_1x1", + "tpu_v3_2x2", + "tpu_v4_2x2", ], tags = ["multiaccelerator"], ) @@ -1247,8 +1228,11 @@ jax_multiplatform_test( name = "debugging_primitives_test", srcs = ["debugging_primitives_test.py"], enable_configs = [ - "gpu_h100", "cpu", + "gpu_h100", + "tpu_v2_1x1", + "tpu_v3_2x2", + "tpu_v4_2x2", ], ) @@ -1258,6 +1242,11 @@ jax_multiplatform_test( backend_tags = { "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, + enable_configs = [ + "tpu_v2_1x1", + "tpu_v3_2x2", + "tpu_v4_2x2", + ], tags = ["multiaccelerator"], deps = [ "//jax:experimental", @@ -1268,8 +1257,11 @@ jax_multiplatform_test( name = "debugger_test", srcs = ["debugger_test.py"], enable_configs = [ - "gpu_h100", "cpu", + "gpu_h100", + "tpu_v2_1x1", + "tpu_v3_2x2", + "tpu_v4_2x2", ], ) @@ -1445,10 +1437,8 @@ jax_multiplatform_test( "tpu": 20, }, tags = [ - "noasan", # Times out, TODO(b/314760446): test failures on Sapphire Rapids. + "noasan", # Times out "nodebug", # Times out. - "nomsan", # TODO(b/314760446): test failures on Sapphire Rapids. - "notsan", # TODO(b/314760446): test failures on Sapphire Rapids. ], deps = [ "//jax:internal_test_harnesses", @@ -1502,15 +1492,11 @@ jax_py_test( jax_multiplatform_test( name = "cudnn_fusion_test", srcs = ["cudnn_fusion_test.py"], - enable_backends = ["gpu"], + enable_backends = [], enable_configs = [ - "gpu_a100", "gpu_h100", ], - tags = [ - "multiaccelerator", - "notap", # TODO(phawkins): this test fails in our internal CI. - ], + tags = ["multiaccelerator"], ) exports_files( diff --git a/tests/api_test.py b/tests/api_test.py index c73d5960f123..d0a711f4a617 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -450,6 +450,28 @@ def test_jit_donate_argnames_kwargs_static_argnums(self): self.assertDeleted(d) self.assertDeleted(e) + def test_device_put_aliasing(self): + arr = jax.device_put(np.arange(8), jax.devices()[0]) + out = jax.device_put(arr, may_alias=True, donate=False) + self.assertEqual(id(arr), id(out)) + + out = jax.device_put(arr, may_alias=False, donate=False) + self.assertNotEqual(id(arr), id(out)) + + with self.assertRaisesRegex( + ValueError, "may_alias and donate cannot be True at the same time."): + jax.device_put(arr, may_alias=True, donate=True) + + out = jax.device_put(arr, + jax.sharding.SingleDeviceSharding(jax.devices()[0]), + may_alias=True, donate=False) + self.assertEqual(id(arr), id(out)) + + out = jax.device_put(arr, + jax.sharding.SingleDeviceSharding(jax.devices()[0]), + may_alias=False, donate=False) + self.assertNotEqual(id(arr), id(out)) + @parameterized.named_parameters( ("argnums", "donate_argnums", 0), ("argnames", "donate_argnames", 'x'), @@ -10906,6 +10928,8 @@ def test_call_wrapped_second_phase_cleanup(self): class EnvironmentInfoTest(jtu.JaxTestCase): @parameterized.parameters([True, False]) def test_print_environment_info(self, return_string): + # Flush stdout buffer before checking. + sys.stdout.flush() with jtu.capture_stdout() as stdout: result = jax.print_environment_info(return_string=return_string) if return_string: diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 24387a767659..2ddd5870268d 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -912,6 +912,25 @@ def f(x): jax.jit(checkify.checkify(f))(0) # Does not crash bc of leaked tracer. + @parameterized.parameters(True, False) + def test_remat(self, jit): + # basic test from https://github.com/jax-ml/jax/issues/23867 + def fn(x: jax.Array): + checkify.check(jnp.all(x > 0), "x must be positive") + return x + 1 + + fn = jax.remat(fn) + if jit: + fn = jax.jit(fn) + fn = checkify.checkify(fn) + err, y = fn(jnp.array([1, 2, 3])) + self.assertIsNone(err.get()) + self.assertAllClose(y, jnp.array([2, 3, 4]), check_dtypes=False) + + err, _ = fn(jnp.array([0, 2, 3])) + self.assertIsNotNone(err.get()) + self.assertStartsWith(err.get(), "x must be positive") + @jtu.with_config(jax_check_tracer_leaks=True) class AssertPrimitiveTests(jtu.JaxTestCase): diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index 75c52822a223..e5222814fb02 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -40,7 +40,6 @@ from jax._src import test_util as jtu from jax._src import xla_bridge from jax._src.compilation_cache_interface import CacheInterface -from jax._src.lib import xla_client from jax.experimental.pjit import pjit from jax.sharding import PartitionSpec as P import numpy as np @@ -177,15 +176,11 @@ def test_put_executable(self): executable_retrieved, compile_time_retrieved = cc.get_executable_and_time( key, compile_options, backend) inputs_to_executable = ( - np.array(1, dtype=np.int32), - np.array(2, dtype=np.int32), - ) - expected = xla_client.execute_with_python_values( - executable, inputs_to_executable, backend - ) - actual = xla_client.execute_with_python_values( - executable_retrieved, inputs_to_executable, backend + jnp.array(1, dtype=np.int32), + jnp.array(2, dtype=np.int32), ) + expected = executable.execute(inputs_to_executable) + actual = executable_retrieved.execute(inputs_to_executable) self.assertEqual(expected, actual) self.assertEqual(FAKE_COMPILE_TIME, compile_time_retrieved) diff --git a/tests/cudnn_fusion_test.py b/tests/cudnn_fusion_test.py index e70ba12361a2..151cb72be8dc 100644 --- a/tests/cudnn_fusion_test.py +++ b/tests/cudnn_fusion_test.py @@ -26,8 +26,8 @@ class CudnnFusionTest(jtu.JaxTestCase): def setUp(self): if (not jtu.test_device_matches(["cuda"]) or - not jtu.is_cuda_compute_capability_at_least("8.0")): - self.skipTest("Only works on >= sm80 GPUs") + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("Only works on >= sm90 GPUs") super().setUp() @parameterized.parameters(["", "pmap"]) @@ -58,11 +58,13 @@ def comp1(x, y, z): self.assertIn('custom_call_target="__cudnn$fusion"', hlo) self.assertIn("called_computations=", hlo) - hlo_after_opt = lowered.compile().as_text() + compiled = lowered.compile({"xla_gpu_cublas_fallback": False}) + hlo_after_opt = compiled.as_text() + self.assertIn("kind=kCustom", hlo_after_opt) self.assertIn("plan_id", hlo_after_opt) - self.assertAllClose(jitted(x, y, z), fn(x, y, z)) + self.assertAllClose(compiled(x, y, z), fn(x, y, z)) if __name__ == '__main__': diff --git a/tests/custom_object_test.py b/tests/custom_object_test.py deleted file mode 100644 index 4b1182e16b5a..000000000000 --- a/tests/custom_object_test.py +++ /dev/null @@ -1,348 +0,0 @@ -# Copyright 2020 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest -import math -import unittest - -import numpy as np - -import jax -import jax.numpy as jnp -from jax import jit, lax, make_jaxpr -from jax.interpreters import mlir -from jax.interpreters import xla - -from jax._src import core -from jax._src import dtypes -from jax._src import test_util as jtu -from jax._src import xla_bridge -from jax._src.lib.mlir import ir -from jax._src.lib import xla_client - -xc = xla_client -xb = xla_bridge - -jax.config.parse_flags_with_absl() - -# TODO(jakevdp): use a setup/teardown method to populate and unpopulate all the -# dictionaries associated with the following objects. - -# Define a sparse array data structure. The important feature here is that -# it is a jaxpr object that is backed by two device buffers. -class SparseArray: - """Simple sparse COO array data structure.""" - def __init__(self, aval, data, indices): - self.aval = aval - self.shape = aval.shape - self.data = data - self.indices = indices - - @property - def index_dtype(self): - return self.indices.dtype - - @property - def dtype(self): - return self.data.dtype - - @property - def nnz(self): - return self.data.shape[0] - - def __repr__(self): - return repr([(tuple(ind), d) for ind, d in zip(self.indices, self.data)]) - - -class AbstractSparseArray(core.ShapedArray): - __slots__ = ['index_dtype', 'nnz', 'data_aval', 'indices_aval'] - - def __init__(self, shape, dtype, index_dtype, nnz, weak_type=False): - super().__init__(shape, dtypes.canonicalize_dtype(dtype)) - self.index_dtype = index_dtype - self.nnz = nnz - self.data_aval = core.ShapedArray( - (nnz,), dtypes.canonicalize_dtype(dtype), weak_type) - self.indices_aval = core.ShapedArray( - (nnz, len(shape)), dtypes.canonicalize_dtype(index_dtype)) - - def update(self, shape=None, dtype=None, index_dtype=None, nnz=None, - weak_type=None): - if shape is None: - shape = self.shape - if dtype is None: - dtype = self.dtype - if index_dtype is None: - index_dtype = self.dtype - if nnz is None: - nnz = self.nnz - if weak_type is None: - weak_type = self.weak_type - return AbstractSparseArray(shape, dtype, index_dtype, nnz, weak_type) - - def strip_weak_type(self): - return self - - @core.aval_property - def data(self): - return sp_data_p.bind(self) - - @core.aval_property - def indices(self): - return sp_indices_p.bind(self) - -class ConcreteSparseArray(AbstractSparseArray): - pass - - -def sparse_array_shape_handler(a): - return ( - xc.Shape.array_shape(a.data_aval.dtype, a.data_aval.shape), - xc.Shape.array_shape(a.indices_aval.dtype, a.indices_aval.shape), - ) - - -core.pytype_aval_mappings[SparseArray] = lambda x: x.aval -core.raise_to_shaped_mappings[AbstractSparseArray] = lambda aval, _: aval -xla.pytype_aval_mappings[SparseArray] = lambda x: x.aval -xla.canonicalize_dtype_handlers[SparseArray] = lambda x: x - -def sparse_array_mlir_type_handler(a): - return ( - ir.RankedTensorType.get( - a.data_aval.shape, mlir.dtype_to_ir_type(a.data_aval.dtype)), - ir.RankedTensorType.get( - a.indices_aval.shape, mlir.dtype_to_ir_type(a.indices_aval.dtype)), - ) - -mlir.ir_type_handlers[AbstractSparseArray] = sparse_array_mlir_type_handler - -sp_indices_p = core.Primitive('sp_indices') - -@sp_indices_p.def_impl -def _sp_indices_impl(mat): - return mat.indices - -@sp_indices_p.def_abstract_eval -def _sp_indices_abstract_eval(mat): - return mat.indices_aval - -# Note: cannot use lower_fun to define attribute access primitives -# because it leads to infinite recursion. - -def _sp_indices_hlo_lowering(ctx, data_and_indices): - return [data_and_indices[1]] - -mlir.register_lowering(sp_indices_p, _sp_indices_hlo_lowering) - -sp_data_p = core.Primitive('sp_data') - -@sp_data_p.def_impl -def _sp_data_impl(mat): - return mat.data - -@sp_data_p.def_abstract_eval -def _sp_data_abstract_eval(mat): - return mat.data_aval - -# Note: cannot use lower_fun to define attribute access primitives -# because it leads to infinite recursion. - -def _sp_data_hlo_lowering(ctx, data_and_indices): - return [data_and_indices[0]] - -mlir.register_lowering(sp_data_p, _sp_data_hlo_lowering) - -def identity(x): - return identity_p.bind(x) - -identity_p = core.Primitive('identity') - -@identity_p.def_impl -def _identity_impl(mat): - return mat - -@identity_p.def_abstract_eval -def _identity_abstract_eval(mat): - return AbstractSparseArray(mat.shape, mat.dtype, mat.index_dtype, mat.nnz) - -mlir.register_lowering( - identity_p, mlir.lower_fun(_identity_impl, multiple_results=False)) - -def split(x): - return split_p.bind(x) - -split_p = core.Primitive('split') -split_p.multiple_results = True - -@split_p.def_impl -def _split_impl(mat): - return mat, mat - -@split_p.def_abstract_eval -def _split_abstract_eval(mat): - m = AbstractSparseArray(mat.shape, mat.dtype, mat.index_dtype, mat.nnz) - return m, m - -mlir.register_lowering( - split_p, mlir.lower_fun(_split_impl, multiple_results=True)) - -def make_sparse_array(rng, shape, dtype, nnz=0.2): - mat = rng(shape, dtype) - size = math.prod(shape) - if 0 < nnz < 1: - nnz = nnz * size - nnz = int(nnz) - if nnz == 0: - mat = np.zeros_like(mat) - elif nnz < size: - # TODO(jakevdp): do we care about duplicates? - cutoff = np.sort(mat.ravel())[nnz] - mat[mat >= cutoff] = 0 - nz = (mat != 0) - data = jnp.array(mat[nz]) - indices = jnp.array(np.where(nz)).T - aval = AbstractSparseArray(shape, data.dtype, indices.dtype, len(indices)) - return SparseArray(aval, data, indices) - -def matvec(mat, v): - v = jnp.asarray(v) - assert v.ndim == 1 - assert len(mat.shape) == 2 - assert v.shape[0] == mat.shape[1] - rows = mat.indices[:, 0] - cols = mat.indices[:, 1] - dv = mat.data * v[cols] - return jnp.zeros(mat.shape[0], dtype=dv.dtype).at[rows].add(dv) - - -class Empty: - def __init__(self, aval): - self.aval = aval - -class AbstractEmpty(core.AbstractValue): - - def join(self, other): - assert isinstance(other, self.__class__), other - return self - - def __hash__(self): - return hash(()) - - def __eq__(self, other): - return isinstance(other, AbstractEmpty) - -class ConcreteEmpty(AbstractEmpty): - pass - - -core.pytype_aval_mappings[Empty] = lambda x: ConcreteEmpty() -core.raise_to_shaped_mappings[AbstractEmpty] = lambda aval, _: aval -xla.pytype_aval_mappings[Empty] = lambda x: AbstractEmpty() -xla.canonicalize_dtype_handlers[Empty] = lambda x: x - - -@unittest.skip("Test does not work with jax.Array") -class CustomObjectTest(jtu.JaxTestCase): - - @jtu.sample_product( - primitive=[True, False], - compile=[True, False], - ) - def testSparseIdentity(self, compile, primitive): - f = identity if primitive else (lambda x: x) - f = jit(f) if compile else f - rng = jtu.rand_default(self.rng()) - M = make_sparse_array(rng, (10,), jnp.float32) - M2 = f(M) - - jaxpr = make_jaxpr(f)(M).jaxpr - core.check_jaxpr(jaxpr) - - self.assertEqual(M.dtype, M2.dtype) - self.assertEqual(M.index_dtype, M2.index_dtype) - self.assertAllClose(M.data, M2.data) - self.assertAllClose(M.indices, M2.indices) - - @jtu.sample_product( - compile=[True, False], - ) - def testSparseSplit(self, compile): - f = jit(split) if compile else split - rng = jtu.rand_default(self.rng()) - M = make_sparse_array(rng, (10,), jnp.float32) - M2, M3 = f(M) - - jaxpr = make_jaxpr(f)(M).jaxpr - core.check_jaxpr(jaxpr) - - for MM in M2, M3: - self.assertEqual(M.dtype, MM.dtype) - self.assertEqual(M.index_dtype, MM.index_dtype) - self.assertArraysEqual(M.data, MM.data) - self.assertArraysEqual(M.indices, MM.indices) - - @jtu.sample_product( - primitive=[True, False], - compile=[True, False], - ) - def testSparseLaxLoop(self, compile, primitive): - rng = jtu.rand_default(self.rng()) - f = identity if primitive else (lambda x: x) - f = jit(f) if compile else f - body_fun = lambda _, A: f(A) - M = make_sparse_array(rng, (10,), jnp.float32) - lax.fori_loop(0, 10, body_fun, M) - - @jtu.sample_product(attr=["data", "indices"]) - def testSparseAttrAccess(self, attr): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [make_sparse_array(rng, (10,), jnp.float32)] - f = lambda x: getattr(x, attr) - self._CompileAndCheck(f, args_maker) - - @jtu.sample_product( - shape=[(3, 3), (2, 6), (6, 2)], - dtype=jtu.dtypes.floating, - ) - def testSparseMatvec(self, shape, dtype): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [make_sparse_array(rng, shape, dtype), rng(shape[-1:], dtype)] - self._CompileAndCheck(matvec, args_maker) - - def testLowerToNothing(self): - empty = Empty(AbstractEmpty()) - jaxpr = make_jaxpr(jit(lambda e: e))(empty).jaxpr - core.check_jaxpr(jaxpr) - - # cannot return a unit, because CompileAndCheck assumes array output. - testfunc = lambda e: None - args_maker = lambda: [empty] - self._CompileAndCheck(testfunc, args_maker) - - def testConstantHandler(self): - def make_const_array(): - data = np.arange(3.0) - indices = np.arange(3)[:, None] - shape = (5,) - aval = AbstractSparseArray(shape, data.dtype, indices.dtype, len(indices)) - return SparseArray(aval, data, indices) - out1 = make_const_array() - out2 = jit(make_const_array)() - self.assertArraysEqual(out1.data, out2.data) - self.assertArraysEqual(out1.indices, out2.indices) - - -if __name__ == '__main__': - absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 103357ac18ac..7daa20cb159e 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -43,6 +43,7 @@ from jax._src.internal_test_util.export_back_compat_test_data import cpu_schur_lapack_gees from jax._src.internal_test_util.export_back_compat_test_data import cpu_svd_lapack_gesdd from jax._src.internal_test_util.export_back_compat_test_data import cpu_triangular_solve_blas_trsm +from jax._src.internal_test_util.export_back_compat_test_data import cpu_hessenberg_lapack_gehrd from jax._src.internal_test_util.export_back_compat_test_data import cuda_threefry2x32 from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_pivots_to_permutation from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_cusolver_getrf @@ -118,6 +119,7 @@ def test_custom_call_coverage(self): cpu_eigh_lapack_syev.data_2024_08_19, cpu_lu_lapack_getrf.data_2024_05_31, cpu_svd_lapack_gesdd.data_2024_08_13, + cpu_hessenberg_lapack_gehrd.data_2024_08_31, ] # Add here all the testdatas that should cover the targets guaranteed # stable @@ -131,13 +133,14 @@ def test_custom_call_coverage(self): cpu_lu_lapack_getrf.data_2023_06_14, cuda_lu_pivots_to_permutation.data_2024_08_08, cuda_lu_cusolver_getrf.data_2024_08_19, - cuda_qr_cusolver_geqrf.data_2023_03_18, - cuda_eigh_cusolver_syev.data_2023_03_17, + cuda_qr_cusolver_geqrf.data_2024_09_26, + cuda_eigh_cusolver_syev.data_2024_09_30, rocm_qr_hipsolver_geqrf.data_2024_08_05, rocm_eigh_hipsolver_syev.data_2024_08_05, cpu_schur_lapack_gees.data_2023_07_16, cpu_svd_lapack_gesdd.data_2023_06_19, cpu_triangular_solve_blas_trsm.data_2023_07_16, + cpu_hessenberg_lapack_gehrd.data_2024_08_30, tpu_Eigh.data, tpu_Lu.data_2023_03_21, tpu_Qr.data_2023_03_17, tpu_Sharding.data_2023_03_16, tpu_ApproxTopK.data_2023_04_17, tpu_ApproxTopK.data_2023_05_16, @@ -160,6 +163,9 @@ def test_custom_call_coverage(self): "tf.call_tf_function", # tested in jax2tf/tests/back_compat_tf_test.py "tpu_custom_call", # tested separately "__gpu$xla.gpu.triton", # tested in pallas/export_back_compat_pallas_test.py + # The following require ROCm to test + "hip_lu_pivots_to_permutation", "hipsolver_getrf_ffi", + "hipsolver_geqrf_ffi", "hipsolver_orgqr_ffi", "hipsolver_syevd_ffi", }) not_covered = targets_to_cover.difference(covered_targets) self.assertEmpty(not_covered, @@ -304,16 +310,19 @@ def test_cpu_eigh_lapack_syevd(self, dtype_name="f32"): size = 8 operand = CompatTest.eigh_input((size, size), dtype) func = lambda: CompatTest.eigh_harness((8, 8), dtype) - data = self.load_testdata(cpu_eigh_lapack_syev.data_2023_03_17[dtype_name]) rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name] atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name] + + info = cpu_eigh_lapack_syev.data_2024_08_19[dtype_name] + data = self.load_testdata(cpu_eigh_lapack_syev.data_2024_08_19[dtype_name]) self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=partial(self.check_eigh_results, operand)) - # FFI Kernel test - with config.export_ignore_forward_compatibility(True): - data = self.load_testdata(cpu_eigh_lapack_syev.data_2024_08_19[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=partial(self.check_eigh_results, operand)) + + # Legacy custom call test + data = self.load_testdata(cpu_eigh_lapack_syev.data_2023_03_17[dtype_name]) + self.run_one_test(func, data, rtol=rtol, atol=atol, + check_results=partial(self.check_eigh_results, operand), + expect_current_custom_calls=info["custom_call_targets"]) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}_{variant}", @@ -321,17 +330,19 @@ def test_cpu_eigh_lapack_syevd(self, dtype_name="f32"): for dtype_name in ("f32", "f64") # We use different custom calls for sizes <= 32 for variant in ["syevj", "syevd"]) - def test_gpu_eigh_solver_syev(self, dtype_name="f32", variant="syevj"): + def test_gpu_eigh_solver_syev_legacy(self, dtype_name="f32", variant="syevj"): if not config.enable_x64.value and dtype_name == "f64": self.skipTest("Test disabled for x32 mode") - if jtu.test_device_matches(["cuda"]): + if jtu.test_device_matches(["rocm"]): + data = self.load_testdata(rocm_eigh_hipsolver_syev.data_2024_08_05[f"{dtype_name}_{variant}"]) + prefix = "hip" + elif jtu.test_device_matches(["cuda"]): if _is_required_cusolver_version_satisfied(11600): # The underlying problem is that this test assumes the workspace size can be # queried from an older version of cuSOLVER and then be used in a newer one. self.skipTest("Newer cuSOLVER expects a larger workspace than was serialized") data = self.load_testdata(cuda_eigh_cusolver_syev.data_2023_03_17[f"{dtype_name}_{variant}"]) - elif jtu.test_device_matches(["rocm"]): - data = self.load_testdata(rocm_eigh_hipsolver_syev.data_2024_08_05[f"{dtype_name}_{variant}"]) + prefix = "cu" else: self.skipTest("Unsupported platform") # For lax.linalg.eigh @@ -341,6 +352,26 @@ def test_gpu_eigh_solver_syev(self, dtype_name="f32", variant="syevj"): atol = dict(f32=1e-2, f64=1e-10)[dtype_name] operand = CompatTest.eigh_input((size, size), dtype) func = lambda: CompatTest.eigh_harness((size, size), dtype) + self.run_one_test(func, data, rtol=rtol, atol=atol, + check_results=partial(self.check_eigh_results, operand), + expect_current_custom_calls=[f"{prefix}solver_syevd_ffi"]) + + @parameterized.named_parameters( + dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) + for dtype_name in ("f32", "f64", "c64", "c128")) + def test_gpu_eigh_solver_syev(self, dtype_name="f32"): + if not jtu.test_device_matches(["cuda"]): + self.skipTest("Unsupported platform") + if not config.enable_x64.value and dtype_name in ["f64", "c128"]: + self.skipTest("Test disabled for x32 mode") + dtype = dict(f32=np.float32, f64=np.float64, + c64=np.complex64, c128=np.complex128)[dtype_name] + size = 4 + rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name] + atol = dict(f32=1e-2, f64=1e-10, c64=1e-2, c128=1e-10)[dtype_name] + operand = CompatTest.eigh_input((size, size), dtype) + data = self.load_testdata(cuda_eigh_cusolver_syev.data_2024_09_30[dtype_name]) + func = lambda: CompatTest.eigh_harness((size, size), dtype) self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=partial(self.check_eigh_results, operand)) @@ -379,10 +410,8 @@ def test_cuda_lu_lapack_getrf(self, dtype_name:str): c64=np.complex64, c128=np.complex128)[dtype_name] shape = (3, 4) func = lambda: CompatTest.lu_harness(shape, dtype) - # TODO(b/360788062): Clean up after the compatibility period. - with config.export_ignore_forward_compatibility(True): - data = self.load_testdata(cuda_lu_cusolver_getrf.data_2024_08_19[dtype_name]) - self.run_one_test(func, data) + data = self.load_testdata(cuda_lu_cusolver_getrf.data_2024_08_19[dtype_name]) + self.run_one_test(func, data) @staticmethod def qr_harness(shape, dtype): @@ -394,41 +423,59 @@ def qr_harness(shape, dtype): dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) def test_cpu_qr_lapack_geqrf(self, dtype_name="f32"): - # For lax.linalg.qr if not config.enable_x64.value and dtype_name in ["f64", "c128"]: self.skipTest("Test disabled for x32 mode") - + rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name] dtype = dict(f32=np.float32, f64=np.float64, c64=np.complex64, c128=np.complex128)[dtype_name] func = lambda: CompatTest.qr_harness((3, 3), dtype) - data = self.load_testdata(cpu_qr_lapack_geqrf.data_2023_03_17[dtype_name]) - rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name] + + info = cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name] + data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol) - with config.export_ignore_forward_compatibility(True): - # FFI Kernel test - data = self.load_testdata( - cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name] - ) - self.run_one_test(func, data, rtol=rtol) + # TODO(b/369826500): Remove legacy custom call test after mid March 2025. + data = self.load_testdata(cpu_qr_lapack_geqrf.data_2023_03_17[dtype_name]) + self.run_one_test(func, data, rtol=rtol, + expect_current_custom_calls=info["custom_call_targets"]) + + # TODO(b/369826500): Remove legacy custom call test after mid March 2025. @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}_{batched}", dtype_name=dtype_name, batched=batched) for dtype_name in ("f32",) # For batched qr we use cublas_geqrf_batched/hipblas_geqrf_batched. for batched in ("batched", "unbatched")) - def test_gpu_qr_solver_geqrf(self, dtype_name="f32", batched="unbatched"): - if jtu.test_device_matches(["cuda"]): - data = self.load_testdata(cuda_qr_cusolver_geqrf.data_2023_03_18[batched]) - elif jtu.test_device_matches(["rocm"]): + def test_gpu_qr_solver_geqrf_legacy(self, dtype_name, batched): + if jtu.test_device_matches(["rocm"]): data = self.load_testdata(rocm_qr_hipsolver_geqrf.data_2024_08_05[batched]) + prefix = "hip" + elif jtu.test_device_matches(["cuda"]): + data = self.load_testdata(cuda_qr_cusolver_geqrf.data_2023_03_18[batched]) + prefix = "cu" else: self.skipTest("Unsupported platform") - # For lax.linalg.qr - dtype = dict(f32=np.float32, f64=np.float64)[dtype_name] - rtol = dict(f32=1e-3, f64=1e-5)[dtype_name] + dtype = dict(f32=np.float32)[dtype_name] + rtol = dict(f32=1e-3)[dtype_name] shape = dict(batched=(2, 3, 3), unbatched=(3, 3))[batched] func = lambda: CompatTest.qr_harness(shape, dtype) + self.run_one_test(func, data, rtol=rtol, expect_current_custom_calls=[ + f"{prefix}solver_geqrf_ffi", f"{prefix}solver_orgqr_ffi"]) + + @parameterized.named_parameters( + dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) + for dtype_name in ("f32", "f64", "c64", "c128")) + def test_gpu_qr_solver_geqrf(self, dtype_name="f32"): + if not jtu.test_device_matches(["cuda"]): + self.skipTest("Unsupported platform") + if not config.enable_x64.value and dtype_name in ["f64", "c128"]: + self.skipTest("Test disabled for x32 mode") + dtype = dict(f32=np.float32, f64=np.float64, + c64=np.complex64, c128=np.complex128)[dtype_name] + rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name] + shape = (2, 3, 3) + func = lambda: CompatTest.qr_harness(shape, dtype) + data = self.load_testdata(cuda_qr_cusolver_geqrf.data_2024_09_26[dtype_name]) self.run_one_test(func, data, rtol=rtol) def test_tpu_Qr(self): @@ -480,19 +527,22 @@ def test_cpu_lu_lapack_getrf(self, dtype_name:str): c64=np.complex64, c128=np.complex128)[dtype_name] shape = (3, 3) func = lambda: CompatTest.lu_harness(shape, dtype) - data = self.load_testdata(cpu_lu_lapack_getrf.data_2023_06_14[dtype_name]) operand = np.reshape(np.arange(math.prod(shape), dtype=dtype), shape) rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name] atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name] + info = cpu_lu_lapack_getrf.data_2024_05_31[dtype_name] + data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=partial(self.check_lu_results, operand, dtype=dtype)) - with config.export_ignore_forward_compatibility(True): - # FFI Kernel test - data = self.load_testdata(cpu_lu_lapack_getrf.data_2024_05_31[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=partial(self.check_lu_results, operand, - dtype=dtype)) + + # TODO(b/357034884): Remove legacy custom call test after mid March 2025. + legacy_data = self.load_testdata( + cpu_lu_lapack_getrf.data_2023_06_14[dtype_name]) + self.run_one_test(func, legacy_data, rtol=rtol, atol=atol, + check_results=partial(self.check_lu_results, operand, + dtype=dtype), + expect_current_custom_calls=info["custom_call_targets"]) def check_svd_results(self, input, res_run, res_exp, rtol=None, atol=None): @@ -656,6 +706,36 @@ def check_triangular_solve_results(res_run, res_expected, *, rtol, atol): self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=check_triangular_solve_results) + @parameterized.named_parameters( + dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) + for dtype_name in ("f32", "f64", "c64", "c128")) + @jax.default_matmul_precision("float32") + def test_cpu_hessenberg_lapack_gehrd(self, dtype_name="f32"): + if not config.enable_x64.value and dtype_name in ["f64", "c128"]: + self.skipTest("Test disabled for x32 mode") + + dtype = dict(f32=np.float32, f64=np.float64, + c64=np.complex64, c128=np.complex128)[dtype_name] + shape = (2, 4, 4) + input_data = jtu.rand_default(self.rng())(shape, dtype) + # del input_data # Input is in the testdata, here for readability + def func(): + return lax.linalg.hessenberg(input_data) + + rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name] + atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name] + + data = self.load_testdata( + cpu_hessenberg_lapack_gehrd.data_2024_08_30[dtype_name] + ) + self.run_one_test(func, data, rtol=rtol, atol=atol) + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata( + cpu_hessenberg_lapack_gehrd.data_2024_08_31[dtype_name] + ) + self.run_one_test(func, data, rtol=rtol, atol=atol) + def test_approx_top_k(self): def func(): x = np.array([3.0, 1.0, 4.0, 2.0, 5.0, 6.0, 7.0]) diff --git a/tests/extend_test.py b/tests/extend_test.py index fff3314a7656..805ad937bc02 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -13,6 +13,7 @@ # limitations under the License. import os +import unittest import numpy as np from absl.testing import absltest @@ -178,6 +179,55 @@ def fun(): self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type)) self.assertTrue(hlo.TokenType.isinstance(op.results[0].type)) + def testEffectsHlo(self): + # The target name must exist on the current platform, but we don't actually + # need to call it with the correct syntax, because we're only checking the + # compiled HLO. + if jtu.test_device_matches(["cpu"]): + target_name = "lapack_sgetrf_ffi" + elif jtu.test_device_matches(["rocm"]): + target_name = "hipsolver_getrf_ffi" + elif jtu.test_device_matches(["cuda", "gpu"]): + target_name = "cusolver_getrf_ffi" + else: + raise unittest.SkipTest("Unsupported device") + def fun(): + jex.ffi.ffi_call(target_name, (), has_side_effect=True) + hlo = jax.jit(fun).lower() + self.assertIn(target_name, hlo.as_text()) + self.assertIn("has_side_effect = true", hlo.as_text()) + self.assertIn(target_name, hlo.compile().as_text()) + + def testJvpError(self): + def fun(x): + return jex.ffi.ffi_call("test_ffi", x, x, non_hashable_arg={"a": 1}) + with self.assertRaisesRegex( + ValueError, "The FFI call to `.+` cannot be differentiated."): + jax.jvp(fun, (0.5,), (0.5,)) + + def testNonHashableAttributes(self): + def fun(x): + return jex.ffi.ffi_call("test_ffi", x, x, non_hashable_arg={"a": 1}) + + self.assertIn("HashableDict", str(jax.make_jaxpr(fun)(jnp.ones(5)))) + hlo = jax.jit(fun).lower(jnp.ones(5)).as_text() + self.assertIn("non_hashable_arg = {a = 1", hlo) + + # If non-hashable arguments aren't handled properly, this will raise a + # TypeError. We make sure it doesn't. + with self.assertRaises(Exception) as manager: + fun(jnp.ones(5)) + self.assertNotIsInstance(manager.exception, TypeError) + + def fun(x): + return jex.ffi.ffi_call("test_ffi", x, x, non_hashable_arg=np.arange(3)) + self.assertIn("HashableArray", str(jax.make_jaxpr(fun)(jnp.ones(5)))) + hlo = jax.jit(fun).lower(jnp.ones(5)).as_text() + self.assertIn("non_hashable_arg = array", hlo) + with self.assertRaises(Exception) as manager: + fun(jnp.ones(5)) + self.assertNotIsInstance(manager.exception, TypeError) + @jtu.sample_product( shape=[(1,), (4,), (5,)], dtype=(np.int32,), @@ -195,10 +245,11 @@ def testFfiCall(self, shape, dtype): @jtu.sample_product( shape=[(1,), (4,), (5,)], dtype=(np.int32,), - vectorized=(False, True), + vmap_method=("broadcast", "broadcast_fullrank", "sequential", + "legacy_vectorized"), ) @jtu.run_on_devices("gpu") - def testFfiCallBatching(self, shape, dtype, vectorized): + def testFfiCallBatching(self, shape, dtype, vmap_method): shape = (10,) + shape pivots_size = shape[-1] permutation_size = 2 * pivots_size @@ -206,15 +257,29 @@ def testFfiCallBatching(self, shape, dtype, vectorized): pivots = jnp.broadcast_to(pivots, shape) expected = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size) actual = jax.vmap(lambda x: ffi_call_lu_pivots_to_permutation( - x, permutation_size, vectorized=vectorized))(pivots) + x, permutation_size, vmap_method=vmap_method))(pivots) self.assertArraysEqual(actual, expected) + @jtu.run_on_devices("gpu") + def testVectorizedDeprecation(self): + pivots_size = 4 + shape = (10, pivots_size) + permutation_size = 2 * pivots_size + pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, + dtype=np.int32) + pivots = jnp.broadcast_to(pivots, shape) + with self.assertWarns(DeprecationWarning): + ffi_call_lu_pivots_to_permutation(pivots, permutation_size, vectorized=True) + with self.assertWarns(DeprecationWarning): + jax.vmap( + lambda x: ffi_call_lu_pivots_to_permutation(x, permutation_size))(pivots) + # TODO(dfm): For now this test uses the `cu_lu_pivots_to_permutation` # custom call target because that's the only one in jaxlib that uses the # new FFI interface. Once more are available, consider using something that # can be run on multiple platforms. -def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, vectorized=True): +def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, **kwargs): return jex.ffi.ffi_call( "cu_lu_pivots_to_permutation", jax.ShapeDtypeStruct( @@ -222,9 +287,7 @@ def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, vectorized=True) dtype=pivots.dtype, ), pivots, - # TODO(b/358275922): Remove this after jaxlib v0.4.32 is released. - permutation_size=np.int32(permutation_size), - vectorized=vectorized, + **kwargs, ) diff --git a/tests/fft_test.py b/tests/fft_test.py index a87b7b66e150..e64fa4db1277 100644 --- a/tests/fft_test.py +++ b/tests/fft_test.py @@ -26,6 +26,7 @@ from jax._src import config from jax._src import dtypes from jax._src import test_util as jtu +from jax._src.lib import version as jaxlib_version from jax._src.numpy.util import promote_dtypes_complex from jax._src.numpy.fft import _fft_norm @@ -161,7 +162,7 @@ def testFftn(self, inverse, real, shape, dtype, axes, s, norm): # Numpy promotes to complex128 aggressively. self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False, tol=1e-4) - self._CompileAndCheck(jnp_fn, args_maker) + self._CompileAndCheck(jnp_fn, args_maker, atol={np.complex64: 2e-6}) # Test gradient for differentiable types. if (config.enable_x64.value and dtype in (float_dtypes if real and not inverse else inexact_dtypes)): @@ -477,5 +478,16 @@ def testFftnormOverflow(self, norm, func_name, dtype): np_norm = np.reciprocal(np_norm) self.assertArraysAllClose(jax_norm, np_norm, rtol=3e-8, check_dtypes=False) + def testFftNormalizationPrecision(self): + # reported in https://github.com/jax-ml/jax/issues/23827 + if not config.enable_x64.value: + raise self.skipTest("requires jax_enable_x64=true") + if jaxlib_version <= (0, 4, 33): + raise self.skipTest("requires jaxlib version > 0.4.33") + n = 31 + a = np.ones((n, 15), dtype="complex128") + self.assertArraysAllClose( + jnp.fft.ifft(a, n=n, axis=1), np.fft.ifft(a, n=n, axis=1), atol=1e-14) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py deleted file mode 100644 index 837d205fbbed..000000000000 --- a/tests/host_callback_test.py +++ /dev/null @@ -1,3089 +0,0 @@ -# Copyright 2020 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import contextlib -from collections.abc import Callable, Sequence -from functools import partial -import itertools -import logging -import os -import re -import threading -import time -import unittest -from unittest import skip, SkipTest - -from absl.testing import absltest - -import jax -from jax import ad_checkpoint -from jax import dtypes -from jax import lax -from jax import numpy as jnp -from jax.experimental import host_callback as hcb -from jax.experimental import pjit -from jax.sharding import PartitionSpec as P -from jax._src import core -from jax._src import xla_bridge -from jax._src import test_util as jtu -from jax._src.lib import xla_client - -from jax.experimental.host_callback import _deprecated_id_print as hcb_id_print - -xops = xla_client.ops - -import numpy as np - -jax.config.parse_flags_with_absl() - - -class _TestingOutputStream: - """Use as `output_stream` for tests.""" - - def __init__(self): - self._output = [] - self._test_method_name = None - - def write(self, what: str) -> None: - logging.info(f"output_stream[{self._test_method_name}]: {what}") - self._output.append(what) - - @property - def output(self): - return "".join(self._output) - - @property - def output_sorted_by_device(self): - # Assume that the output is a sequence of strings including metadata - # and data, with metadata containing `device: xxx` - by_device = [] # each element is a pair (device, str_list) - for s in self._output: - m = re.match(r".*device: (\S+)", s) - if m: - by_device.append((m.group(1), [])) - assert by_device, f"output does not include 'device:': {self._output}" - by_device[-1][1].append(s) - - sorted_by_device = sorted(by_device, key=lambda x: x[0]) - return "\n".join(itertools.chain(*[s[1] for s in sorted_by_device])) - - def __str__(self): - return "TestingOutputStream" - - def reset(self): - self._output = [] - - -testing_stream = _TestingOutputStream() - - -def fun1(a): - """Function used for several `id_tap` tests.""" - y = hcb_id_print(a * 2., what="a * 2", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - y = hcb_id_print(y * 3., what="y * 3", output_stream=testing_stream, result=y, - callback_flavor=hcb.CallbackFlavor.DEBUG) - return y ** 2 # Some computation to make the gradient interesting - - -def fun1_equiv(a): # Numerical equivalent of fun1 - return (a * 2.) ** 2 - - -def maybe_print(do_print: bool, - arg, - what: str, - tap_with_device: bool | None = False, - device_index: int = 0): - """Conditionally print on testing_string""" - if do_print: - return hcb_id_print( - arg, - what=what, - output_stream=testing_stream, - tap_with_device=tap_with_device, - device_index=device_index) - else: - return arg - - -def local_devices(): - # Tests require using not more than 2 devices. - return jax.local_devices()[:2] - - -ignore_jit_of_pmap_warning = partial( - jtu.ignore_warning, message=".*jit-of-pmap.*") - - -def assertMultiLineStrippedEqual(tst: jtu.JaxTestCase, - expected: str, what: str): - """A variant that preprocesses the string to eliminate non-determinism in - floating point values, and several uninteresting id_tap primitive params. - """ - - # Sometimes we get floating points in the output; we round them - def repl_floats(match_group): - matched = match_group.group(0) - if matched == ".": return matched - x = np.around(float(matched), decimals=2) - return f"{x:.2f}" - - what = re.sub(r"\-?\d+\.[\-\def]*", repl_floats, what) - what = re.sub(r"output_stream=[^\]\n,]*,?", "", what) - what = re.sub(r"threshold=[^\]\n,]*,?", "", what) - what = re.sub(r"bwd=[^\]\n]*", "", what) - what = re.sub(r"out_trees=[^\]\n]*", "", what) - what = re.sub(r"fwd_jaxpr_thunk=[^\]\n]*", "", what) - what = re.sub(r"jvp_jaxpr_thunk=[^\]\n]*", "", what) - # Empty lines - what = re.sub(r"^\s*\n", "", what, flags=re.MULTILINE) - - def repl_func(match_group): - matched = match_group.group(3) - if "function _print_consumer" in matched: - return match_group.group(1) + "=_print" - else: - return match_group.group(1) + "=..." - - what = re.sub(r"((tap_func_)|(callback))=([^\]\n,]*),?", repl_func, what) - tst.assertMultiLineStrippedEqual(expected, what) - - -def helper_set_hlo_dump(): - flags_str = os.getenv("XLA_FLAGS", "") - import shutil - dump_dir = "/tmp/xla_dump" - os.environ["XLA_FLAGS"] = f"{flags_str} --xla_dump_to={dump_dir}" - if os.path.isdir(dump_dir): - logging.warning("Deleting old XLA dump directory %s", dump_dir) - shutil.rmtree(dump_dir) - logging.warning("Setting XLA dump directory %s", dump_dir) - # Clear any cached backends so new CPU backend will pick up the env var. - xla_bridge.get_backend.cache_clear() - - -def helper_print_optimized_hlo(fun, *args): - backend = xla_bridge.get_backend(platform=jtu.device_under_test()) - c = jax.jit(fun, backend=backend.platform).lower(*args) - logging.info(re.sub(r", metadata.*", "", c.compile().as_text())) - - -def helper_log_ir(name, - f_jax, - *args, - num_partitions=None, - strip_metadata=False): - logging.info(f"Jaxpr[{name}]: {jax.make_jaxpr(f_jax)(*args)}") - jax_comp = f_jax.lower(*args) - logging.info(f"HLO[{name}]: {jax_comp.compiler_ir(dialect='hlo').as_hlo_text()}") - jax_optimized_hlo = jax_comp.compile().as_text() - if strip_metadata: - jax_optimized_hlo = re.sub(r", metadata.*", "", jax_optimized_hlo) - logging.info(f"Optimized HLO[{name}]: {jax_optimized_hlo}") - - -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) - -def tearDownModule(): - _exit_stack.close() - - -def assertMultiDeviceOutputEqual(tst: jtu.JaxTestCase, - expected_2CPUs: str): - """Check that the multi-device output is equal to the expected. - - The tests run with 2 devices if available, otherwise 1 device. - We adjust the expected output here for 1 device. - - Args: - expected_2CPUs: the expected output for 2 CPUs. If there is only - one device, this is trimmed to the first device. If the current - device_under_test is not a CPU, then we change the names - """ - expected = expected_2CPUs - if len(local_devices()) == 1: - start_device_1 = expected.find('device: cpu:1') - if start_device_1 >= 0: - expected = expected[0:start_device_1] - - def replace_device_name(m) -> str: - return str(local_devices()[int(m.group(1))]) - - expected = re.sub(r'cpu:(\d+)', replace_device_name, expected) - what = testing_stream.output_sorted_by_device - return assertMultiLineStrippedEqual(tst, expected, what) - - -class HostCallbackImportsTest(jtu.JaxTestCase): - @jtu.ignore_warning( - category=DeprecationWarning, - message="The host_callback APIs are deprecated") - def test_deprecated_imports(self): - if hasattr(hcb, "id_print"): - id_print = hcb.id_print - self.assertIs(id_print, hcb_id_print) - -class HostCallbackTapTest(jtu.JaxTestCase): - - def setUp(self): - # skipping here skips teardown, so do this before super().setUp(). - if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1: - raise SkipTest("host_callback broken on multi-GPU platforms (#6447)") - if xla_bridge.using_pjrt_c_api(): - raise SkipTest("host_callback not implemented in PJRT C API") - super().setUp() - self.enter_context(jtu.ignore_warning( - category=DeprecationWarning, message="The host_callback APIs are deprecated")) - self.enter_context(jtu.ignore_warning( - category=DeprecationWarning, message="backend and device argument")) - testing_stream.reset() - testing_stream._test_method_name = self._testMethodName - self.old_flags = os.getenv("XLA_FLAGS", "") - - def tearDown(self) -> None: - if os.getenv("XLA_FLAGS") != self.old_flags: - os.environ["XLA_FLAGS"] = self.old_flags - xla_bridge.get_backend.cache_clear() - hcb.barrier_wait("HostCallbackTapTest.tearDown") - super().tearDown() - - def supported_only_in_legacy_mode(self): - if not hcb._HOST_CALLBACK_LEGACY.value: - self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False") - - def test_tap_eval(self): - self.assertAllClose((5. * 2.) ** 2, fun1(5.)) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - what: a * 2 - 10.00 - what: y * 3 - 30.00""", testing_stream.output) - - def test_tap_with_tuple_results(self): - def func2(x): - x1, y1 = hcb_id_print((x * 2., x * 3.), output_stream=testing_stream) - return x1 + y1 - - self.assertEqual(3. * (2. + 3.), func2(3.)) - hcb.barrier_wait() - - assertMultiLineStrippedEqual(self, """ - ( 6.00 9.00 )""", testing_stream.output) - - def test_tap_with_dict_results(self): - def func2(x): - res = hcb_id_print(dict(a=x * 2., b=x * 3.), output_stream=testing_stream) - return res["a"] + res["b"] - - self.assertEqual(3. * (2. + 3.), func2(3.)) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - { a=6.00 b=9.00 }""", testing_stream.output) - - def test_tap_with_result(self): - def func2(x): - x1 = hcb_id_print((x * 2., x * 3.), result=x * 4., - output_stream=testing_stream) - return x1 - - self.assertEqual(3. * 4., func2(3.)) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - ( 6.00 9.00 )""", testing_stream.output) - - def test_tap_with_result_no_arg(self): - def tap_func(arg, transforms): - testing_stream.write(f"called tap_func with {arg}") - - def func2(x): - x1 = hcb.id_tap(tap_func, None, result=x) - return x1 - - self.assertEqual(3., func2(3.)) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, "called tap_func with None", - testing_stream.output) - - def test_tap_result_unused(self): - def tap_func(arg, transforms): - testing_stream.write(f"called tap_func with {arg}") - def func2(x): - hcb.id_tap(tap_func, None) - return x - - self.assertEqual(3., func2(3.)) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, "called tap_func with None", - testing_stream.output) - - def test_tap_with_device(self): - self.supported_only_in_legacy_mode() - def func2(x): - x1 = hcb_id_print((x * 2., x * 3.), result=x * 4., - output_stream=testing_stream, - tap_with_device=True) - return x1 - - self.assertEqual(3. * 4., func2(3.)) - hcb.barrier_wait() - assertMultiDeviceOutputEqual(self, """ - device: cpu:0 - ( 6.00 9.00 )""") - - def test_tap_eval_exception(self): - if not hcb._HOST_CALLBACK_OUTFEED.value: - raise SkipTest("TODO: implement error handling for customcall") - - # Simulate a tap error - def tap_err(*args, **kwargs): - raise ValueError("Some user message") - - def func(x): - x1 = hcb_id_print(x + 1, what="x1", output_stream=testing_stream) - x2 = hcb.id_tap(tap_err, x1 + 1) - x3 = hcb_id_print(x2 + 1, what="x3", output_stream=testing_stream) - return x3 - - if hcb._HOST_CALLBACK_LEGACY.value: - ctx = self.assertRaisesRegex( - hcb.CallbackException, - re.compile("There were exceptions during callback processing. Last one was:.*" - "ValueError: Some user message", re.DOTALL)) - else: - ctx = self.assertRaisesRegex(Exception, "Some user message") - - with ctx: - func(0) - hcb.barrier_wait() - - if hcb._HOST_CALLBACK_LEGACY.value: - # We should have received everything before the error - assertMultiLineStrippedEqual(self, """ - what: x1 - 1 - what: x3 - 3""", testing_stream.output) - else: - # We should have received everything before the error - assertMultiLineStrippedEqual(self, """ - what: x1 - 1""", testing_stream.output) - - def test_tap_empty(self): - """Tap empty arrays.""" - hcb_id_print((), output_stream=testing_stream) - hcb_id_print((1., np.ones((2, 0))), what="second", output_stream=testing_stream) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - ( ) - what: second - ( 1.00 [] )""", testing_stream.output) - - def test_tap_jit_simple(self): - jit_fun1 = jax.jit(lambda x: 3. * hcb_id_print( - 2. * x, what="here", output_stream=testing_stream)) - self.assertAllClose(6. * 5., jit_fun1(5.)) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - what: here - 10.00""", testing_stream.output) - - def test_tap_jit_no_invars(self): - def func(): # jitted function does not take arguments - return hcb_id_print(42, output_stream=testing_stream) - - self.assertAllClose(42, jax.jit(func)()) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - 42""", testing_stream.output) - - def test_tap_jit_multiple_invars(self): - def func(x1, x2): - return hcb_id_print(x1 + x2, output_stream=testing_stream) - - self.assertAllClose(42, jax.jit(func)(40, 2)) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - 42""", testing_stream.output) - - def test_tap_jit_constant(self): - def func(x): - return hcb_id_print(42, result=x, output_stream=testing_stream) - - self.assertAllClose(5, jax.jit(func)(5)) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - 42""", testing_stream.output) - - def test_tap_jit_sequence1(self): - def func(x): - x1 = hcb_id_print(x, where="1", output_stream=testing_stream) - return hcb_id_print(x1 + 1, where="2", output_stream=testing_stream) - - logging.info("%s: %s", self._testMethodName, - jax.make_jaxpr(func)(1)) - logging.info( - "%s: %s", - self._testMethodName, - jax.jit(func) - .trace(1) - .lower(lowering_platforms=(jtu.device_under_test(),)).as_text("hlo")) - self.assertEqual(2, jax.jit(func)(1)) - hcb.barrier_wait() - - assertMultiLineStrippedEqual(self, """ - where: 1 - 1 - where: 2 - 2""", testing_stream.output) - - def test_tap_jit2(self): - """A sequence of JIT.""" - - def func(x): - x1 = hcb_id_print(x, where="1", output_stream=testing_stream) - x2 = hcb_id_print(x1 + 1, where="2", output_stream=testing_stream) - return x2 - - self.assertEqual(2, jax.jit(func)(1)) - self.assertEqual(11, jax.jit(func)(10)) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - where: 1 - 1 - where: 2 - 2 - where: 1 - 10 - where: 2 - 11""", testing_stream.output) - - def test_tap_jit_result_unused(self): - """We can id_print even if we don't use the result.""" - - def func(x): - hcb_id_print(x, where="1", output_stream=testing_stream) - hcb_id_print(x + 1, where="2", output_stream=testing_stream) - return x + 1 - - self.assertEqual(2, jax.jit(func)(1)) - self.assertEqual(11, jax.jit(func)(10)) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - where: 1 - 1 - where: 2 - 2 - where: 1 - 10 - where: 2 - 11""", testing_stream.output) - - def test_tap_jit_nested(self): - def func(x): - x1 = hcb_id_print(x, where="1", output_stream=testing_stream) - - def func_nested(x): - x2 = hcb_id_print(x + 1, where="nested", output_stream=testing_stream) - return x2 - - x3 = jax.jit(func_nested)(x1) - return hcb_id_print(x3 + 1, where="3", output_stream=testing_stream) - - self.assertEqual(3, jax.jit(func)(1)) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - where: 1 - 1 - where: nested - 2 - where: 3 - 3""", testing_stream.output) - - def test_tap_jit_devices(self): - """Running on multiple devices.""" - self.supported_only_in_legacy_mode() - logging.info("%s: has devices %s", self._testMethodName, local_devices()) - - def func(x, device_id): - x1 = hcb_id_print(x, dev=str(device_id), output_stream=testing_stream) - x2 = hcb_id_print(x1 + 1, dev=str(device_id), output_stream=testing_stream) - return x2 - - for d in local_devices(): - self.assertEqual(112, jax.jit(func, device=d, static_argnums=1)(111, d.id)) - hcb.barrier_wait() - logging.info("%s: found output %s", self._testMethodName, - testing_stream.output) - self.assertEqual( - len(local_devices()), len(re.findall(r"111", testing_stream.output))) - self.assertEqual( - len(local_devices()), len(re.findall(r"112", testing_stream.output))) - - @jtu.sample_product(with_jit=[True, False]) - def test_tap_pytree(self, with_jit=False): - def func(x, what=""): - """Returns some pytrees depending on x""" - if what == "pair_1_x": - return (1, x) - elif what == "pair_x_2x": - return (x, 2 * x) - elif what == "dict": - return dict(a=2 * x, b=3 * x) - else: - assert False - - tap_count = 0 - - def tap_func(a, _, *, what=""): - nonlocal tap_count - tap_count += 1 - self.assertEqual(func(5, what), a) - - transform = jax.jit if with_jit else lambda f: f - for what in ("pair_1_x", "pair_x_2x", "dict"): - transformed = transform( - lambda x: hcb.id_tap( - partial(tap_func, what=what), - func(x, what), - result=func(x * 2, what)) - )(5) - self.assertEqual(func(10, what), transformed) - hcb.barrier_wait() # Wait for receivers to be done - self.assertEqual(3, tap_count) - - @jtu.sample_product(concurrent=[True, False]) - def test_tap_multiple(self, concurrent=False): - """Call id_tap multiple times, concurrently or in sequence. """ - if concurrent and jtu.test_device_matches(["cpu", "gpu"]): - # TODO(necula): if there is device side concurrency, outfeeds from - # different computations can be interleaved. For example, it seems that - # on GPU if multiple host threads run a jit computation, the multiple - # computations are interleaved on the GPU. This can result in the outfeed - # trains being interleaved, which will trigger an error. - # The solution is to fix on GPU the receiving logic so that we can outfeed - # the train as one tuple, and receive it one piece as a time. Then the - # trains should be atomic. - # See also b/160692602. - raise SkipTest("concurrent id_tap not supported on CPU, GPU") - - received = set() - count = 5 - - def pause_tap(idx, _): - received.add(int(idx)) - logging.info("Starting do_tap %s. Sleeping 1sec ...", idx) - time.sleep(0.3) - logging.info("Finish do_tap %s", idx) - - def do_tap(idx): - jax.jit(lambda idx: hcb.id_tap(pause_tap, idx))(idx) - - if concurrent: - threads = [ - threading.Thread( - name=f"enqueue_tap_{idx}", target=do_tap, args=(idx,)) - for idx in range(count) - ] - [t.start() for t in threads] - [t.join() for t in threads] - else: - for idx in range(count): - do_tap(idx) - - hcb.barrier_wait() - self.assertEqual(received, set(range(count))) - - # TODO(necula): see comment for test_multiple_tap. Here we disable also - # on TPU, because the barrier_wait runs on all devices, including on the CPU - # where it would run into concurrency problems. - @skip("Concurrency not supported") - def test_tap_multiple_barriers(self): - """Call barrier_wait concurrently.""" - - def pause_tap(*args, **kwargs): - logging.info("pause_tap waiting") - time.sleep(0.3) - logging.info("pause_tap done") - - def long_run(x): - return hcb.id_tap(pause_tap, x) - - jax.jit(long_run)(5.) - - def try_barrier(idx): - logging.info("Starting test barrier %s", idx) - hcb.barrier_wait() - logging.info("Finished test barrier %s", idx) - - threads = [ - threading.Thread( - name=f"barrier_{idx}", target=try_barrier, args=(idx,)) - for idx in range(3) - ] - [t.start() for t in threads] - [t.join() for t in threads] - - @jtu.sample_product(with_jit=[True, False]) - def test_tap_cond(self, with_jit=False): - """A conditional""" - - def func(x): - x1 = hcb_id_print(x, where="1", output_stream=testing_stream) - x2 = hcb_id_print(x1 + 1, where="2", output_stream=testing_stream) - - x4 = lax.cond(x % 2 == 0, - lambda x: hcb_id_print(x, where="cond_t", - output_stream=testing_stream), - lambda x: hcb_id_print(-1, where="cond_f", result=x, - output_stream=testing_stream), - x2 + 1) - x5 = hcb_id_print(x4 + 1, where="end", output_stream=testing_stream) - return x5 - - transform = jax.jit if with_jit else lambda f: f - self.assertEqual(4, transform(func)(1)) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - where: 1 - 1 - where: 2 - 2 - where: cond_f - -1 - where: end - 4""", testing_stream.output) - - @jtu.sample_product(with_jit=[True, False]) - def test_tap_while_cond(self, with_jit=False): - def func(x): - x1 = hcb_id_print(x, where="1", output_stream=testing_stream) - x2 = hcb_id_print(x1 + 1, where="2", output_stream=testing_stream) - - def body(x): - x3 = hcb_id_print(x, where="w_b_1", output_stream=testing_stream) - x4 = lax.cond(x % 2 == 0, - lambda x: hcb_id_print(x, where="w_b_t", - output_stream=testing_stream), - lambda x: hcb_id_print(-1, where="w_b_f", - result=x, output_stream=testing_stream), - x3 + 1) - return hcb_id_print(x4, where="w_b_2", output_stream=testing_stream) - - x10 = lax.while_loop(lambda x: x <= 3, body, x2) - res = hcb_id_print(x10, where="end", output_stream=testing_stream) - return res - - transform = jax.jit if with_jit else lambda f: f - self.assertEqual(4, transform(func)(1)) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - where: 1 - 1 - where: 2 - 2 - where: w_b_1 - 2 - where: w_b_t - 3 - where: w_b_2 - 3 - where: w_b_1 - 3 - where: w_b_f - -1 - where: w_b_2 - 4 - where: end - 4""", testing_stream.output) - - def test_tap_jit_while_pred_tap(self): - """While with printing in the conditional.""" - - def func(x): - x1 = hcb_id_print(x, where="1") - x10 = lax.while_loop(lambda x: hcb_id_print(x < 3, - where="w_p", - output_stream=testing_stream), - lambda x: hcb_id_print(x + 1, where="w_b", - output_stream=testing_stream), - x1) - res = hcb_id_print(x10, where="3", output_stream=testing_stream) - return res - - self.assertEqual(3, jax.jit(func)(1)) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, - """ - where: w_p - True - where: w_b - 2 - where: w_p - True - where: w_b - 3 - where: w_p - False - where: 3 - 3""", testing_stream.output) - - @jtu.sample_product(with_jit=[True, False]) - def test_tap_scan_cond(self, with_jit=True): - def func(x): - x1 = hcb_id_print(x, where="1", output_stream=testing_stream) - x2 = hcb_id_print(x1 + 1, where="2", output_stream=testing_stream) - - def body(c, x): - x3 = hcb_id_print(x, where="s_1", output_stream=testing_stream) - x4 = lax.cond(x % 2 == 0, - lambda x: hcb_id_print(x, where="s_t", output_stream=testing_stream), - lambda x: hcb_id_print(-1, where="s_f", result=x, output_stream=testing_stream), - x3 + 1) - return (c, hcb_id_print(x4, where="s_2", output_stream=testing_stream)) - - _, x10 = lax.scan(body, x2, jnp.arange(3)) - res = hcb_id_print(x10, where="10", output_stream=testing_stream) - return res - - if with_jit: - func = jax.jit(func) - res = func(1) - self.assertAllClose(jnp.arange(1, 4), res) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - where: 1 - 1 - where: 2 - 2 - where: s_1 - 0 - where: s_t - 1 - where: s_2 - 1 - where: s_1 - 1 - where: s_f - -1 - where: s_2 - 2 - where: s_1 - 2 - where: s_t - 3 - where: s_2 - 3 - where: 10 - [1 2 3]""", testing_stream.output) - testing_stream.reset() - - @jtu.sample_product( - nr_args=[1, 2], - shape=[(), (2,), (2, 3), (2, 3, 4)], - dtype=jtu.dtypes.all, - ) - def test_tap_jit_dtypes(self, nr_args=2, dtype=jnp.int16, shape=(2,)): - if dtype in (jnp.complex64, jnp.complex128, jnp.bool_): - raise SkipTest(f"host_callback not implemented for {dtype}.") - if dtype == np.bool_: - args = [self.rng().choice(a=[True, False], size=shape)] - else: - args = [jnp.arange(np.prod(shape), dtype=dtype).reshape(shape)] - if nr_args > 1: - args = args * nr_args - jit_fun1 = jax.jit(lambda xs: hcb_id_print( - xs, - a_new_test="************", - testcase_name=f"{shape=}_{dtype=}_{nr_args=}")) - - res = jit_fun1(args) - self.assertAllClose(args, res, check_dtypes=True) - - def test_tap_jit_large(self): - arg = jnp.arange(10000, dtype=jnp.int32).reshape((10, 10, 5, -1)) - jax.jit(hcb_id_print)(arg) - - def test_tap_jit_several_together(self): - arg = jnp.arange(50, dtype=jnp.int32).reshape((10, 5)) - jax.jit(lambda x, y: hcb_id_print((x, y, x * 2)))(arg, jnp.ones(100, dtype=jnp.int32)) - - def test_tap_jit_interleaving(self): - # Several jit's without data dependencies; they may interfere - count = 0 # Count tap invocations - nr_arrays = 5 - - def tap_func(arg, _): - nonlocal count - assert len(arg) == nr_arrays - count += 1 - - # This is the function that we'll run multiple times - def func(x, count): - for i in range(count): - x = hcb.id_tap(tap_func, [x + i for i in range(nr_arrays)])[-1] - return x - - x = jnp.array(1, dtype=np.int32) - res = 0 - for _ in range(10): - # No dependencies between the jit invocations - res += jax.jit(lambda x: func(x, 10))(x) - hcb.barrier_wait() - self.assertEqual(100, count) - - def test_tap_jit_tap_exception(self): - if not hcb._HOST_CALLBACK_OUTFEED.value: - raise SkipTest("TODO: implement error handling for customcall") - # Simulate a tap error - def tap_err(*args, **kwargs): - raise NotImplementedError - - def func(x): - x1 = hcb_id_print(x + 1, what="x1", output_stream=testing_stream) - x2 = hcb.id_tap(tap_err, x1 + 1) - x3 = hcb_id_print(x2 + 1, what="x3", output_stream=testing_stream) - return x3 - - if hcb._HOST_CALLBACK_LEGACY.value: - res = jax.jit(func)(0) # No error yet - with self.assertRaises(hcb.CallbackException): - hcb.barrier_wait() - - # Even though the receiver thread raised, the main thread should still - # return 3. - self.assertEqual(3, res) - # We should have received all others - assertMultiLineStrippedEqual(self, """ - what: x1 - 1 - what: x3 - 3""", testing_stream.output) - else: - with self.assertRaisesRegex(Exception, "NotImplementedError"): - res = jax.jit(func)(0) - hcb.barrier_wait() - - def test_tap_while(self): - """Executing while, even without JIT uses compiled code""" - y = jnp.ones(5) # captured const - - def func(x): - return lax.while_loop( - lambda c: c[1] < 5, - lambda c: (y, hcb_id_print(c[1], output_stream=testing_stream) + 1), - (x, 1)) - - func(y) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - 1 - 2 - 3 - 4""", testing_stream.output) - - def test_tap_jvp(self): - jvp_fun1 = lambda x, xt: jax.jvp(fun1, (x,), (xt,)) - res_primals, res_tangents = jvp_fun1(jnp.float32(5.), jnp.float32(0.1)) - self.assertAllClose(100., res_primals, check_dtypes=False) - self.assertAllClose(4., res_tangents, check_dtypes=False) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - what: a * 2 - 10.00 - what: y * 3 - 30.00""", testing_stream.output) - - def test_tap_grad_primal_unused(self): - # The output of id_print is not needed for backwards pass - def func(x): - return 2. * hcb_id_print(x * 3., what="x * 3", - output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - - grad_func = jax.grad(func) - arg = jnp.float32(5.) - jaxpr = str(jax.make_jaxpr(grad_func)(arg)) - # making the Jaxpr does not print anything - hcb.barrier_wait() - - if hcb._HOST_CALLBACK_LEGACY.value: - treedef = jax.tree.structure(arg) - assertMultiLineStrippedEqual( - self, f""" - {{ lambda ; a:f32[]. let - b:f32[] = mul a 3.00 - c:f32[] = outside_call[ - arg_treedef={treedef} - callback=... - device_index=0 - identity=True - ] b - _:f32[] = mul 2.00 c - d:f32[] = mul 2.00 1.00 - e:f32[] = mul d 3.00 - in (e,) }}""", jaxpr) - assertMultiLineStrippedEqual(self, "", testing_stream.output) - testing_stream.reset() - - res_grad = grad_func(arg) - hcb.barrier_wait() - - self.assertAllClose(6., res_grad, check_dtypes=False) - assertMultiLineStrippedEqual(self, """ - what: x * 3 - 15.00""", testing_stream.output) - - def test_tap_grad_simple(self): - def func(x): - y = hcb_id_print(x * 2., what="x * 2", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - return x * hcb_id_print(y * 3., what="y * 3", - output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - - grad_func = jax.grad(func) - - res_grad = grad_func(jnp.float32(5.)) - self.assertAllClose(2. * 5. * 6., res_grad, check_dtypes=False) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - what: x * 2 - 10.00 - what: y * 3 - 30.00""", testing_stream.output) - - def test_tap_grad_grad(self): - def func(x): - y = hcb_id_print(x * 2., what="x * 2", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - return x * (y * 3.) - - grad_func = jax.grad(jax.grad(func)) - # making the Jaxpr does not print anything - _ = jax.make_jaxpr(grad_func)(5.) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, "", testing_stream.output) - - res_grad = grad_func(jnp.float32(5.)) - - self.assertAllClose(12., res_grad, check_dtypes=False) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - what: x * 2 - 10.00""", testing_stream.output) - - def test_tap_grad_pytree(self): - def func(x): - x4, x5 = hcb_id_print((x * 2., x * 3.), what="pair", - result=(x * 4., x * 5.), - output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - return x4 + 2. * x5 - - x = jnp.float32(5.) - grad_func = jax.grad(func) - print(jax.make_jaxpr(grad_func)(x)) - res_grad = grad_func(x) - self.assertAllClose(14., res_grad, check_dtypes=False) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - what: pair - ( 10.00 15.00 )""", testing_stream.output) - - def test_tap_jvp_float0(self): - def f(x, yint): - x, yint = hcb.id_tap(lambda arg, _: arg, (x, yint), - callback_flavor=hcb.CallbackFlavor.DEBUG) - return x * yint - - res = jax.jvp(f, (2., 3), (0.2, np.zeros((), dtypes.float0))) - self.assertAllClose((6., 0.6), res) - - def test_tap_grad_float0(self): - - def func(x, yint): - x, yint = hcb_id_print((x, yint), what="pair", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - return x * yint.astype(x.dtype) - - grad_func = jax.grad(func) - - res_grad = grad_func(jnp.float32(5.), jnp.int32(2)) - self.assertAllClose(2., res_grad, check_dtypes=False) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - what: pair - ( 5.00 2 )""", testing_stream.output) - - def test_tap_grad_float0_result(self): - # https://github.com/jax-ml/jax/issues/7340 - # x is a Tuple[f32[2], s32[3]] - x = (np.array([.7, .8], dtype=np.float32), - np.array([11, 12, 13], dtype=np.int32)) - def f_jax(x): - x = hcb_id_print(x, result=x, output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) # result= is important - return (3. * x[0], x[1]) - - def f_jax_vjp(x): - res, pullback = jax.vjp(f_jax, x) - g, = pullback((np.ones(x[0].shape, dtype=x[0].dtype), - np.zeros(x[1].shape, dtype=dtypes.float0))) - return g - - g = f_jax_vjp(x) - self.assertAllClose(np.array([3., 3.], dtype=np.float32), g[0]) - self.assertEqual(dtypes.float0, g[1].dtype) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - ( [0.70 0.80] [11 12 13] )""", testing_stream.output) - - def test_tap_higher_order_grad_float0_result(self): - # https://github.com/jax-ml/jax/issues/7340 - # x is a Tuple[f32[2], s32[3]] - x = (np.array([.7, .8], dtype=np.float32), - np.array([11, 12, 13], dtype=np.int32)) - def f_jax(x): - x = hcb_id_print(x, result=x, output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) # result= is important - return (jnp.sin(x[0]), x[1]) - - def wrap_vjp(f, args, res_f_of_args): - # Given a function "f" and "args" return the f_vjp and args_vjp - def make_ct(res): - res_dtype = np.result_type(res) - if res_dtype == dtypes.float0: - return res - ct_dtype = core.primal_dtype_to_tangent_dtype(res_dtype) - return np.ones(np.shape(res), dtype=ct_dtype) - cts = jax.tree.map(make_ct, res_f_of_args) - def f_vjp(args, cts): - res, pullback = jax.vjp(f, *args) - return pullback(cts) - return (f_vjp, (args, cts)) - - res = f_jax(x) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - ( [0.70 0.80] [11 12 13] )""", testing_stream.output) - testing_stream.reset() - - # 1st order - f_jax_vjp1, args_vjp1 = wrap_vjp(f_jax, (x,), res) - res_vjp1 = f_jax_vjp1(*args_vjp1) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - ( [0.70 0.80] [11 12 13] )""", testing_stream.output) - testing_stream.reset() - - # 2nd order - f_jax_vjp2, args_vjp2 = wrap_vjp(f_jax_vjp1, args_vjp1, res_vjp1) - res_vjp2 = f_jax_vjp2(*args_vjp2) - - # 3rd order - f_jax_vjp3, args_vjp3 = wrap_vjp(f_jax_vjp2, args_vjp2, res_vjp2) - _ = f_jax_vjp3(*args_vjp3) - - def test_tap_vmap(self): - vmap_fun1 = jax.vmap(fun1) - vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)]) - vmap_fun1(vargs) - hcb.barrier_wait() - if hcb._HOST_CALLBACK_LEGACY.value: - assertMultiLineStrippedEqual(self, """ - transforms: [('batch', {'batch_dims': (0,)})] what: a * 2 - [ 8.00 10.00] - transforms: [('batch', {'batch_dims': (0,)})] what: y * 3 - [24.00 30.00]""", testing_stream.output) - else: - assertMultiLineStrippedEqual(self, """ - what: a * 2 - 8.00 - what: a * 2 - 10.00 - what: y * 3 - 24.00 - what: y * 3 - 30.00 - """, testing_stream.output) - - def test_tap_vmap_not_batched(self): - x = 3. - - def func(y): - # x is not mapped, y is mapped - _, y = hcb_id_print((x, y), output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - return x + y - - vmap_func = jax.vmap(func) - vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)]) - _ = vmap_func(vargs) - hcb.barrier_wait() - if hcb._HOST_CALLBACK_LEGACY.value: - assertMultiLineStrippedEqual(self, """ - transforms: [('batch', {'batch_dims': (None, 0)})] - ( 3.00 [4.00 5.00] )""", testing_stream.output) - else: - assertMultiLineStrippedEqual(self, """ - ( 3.00 4.00 ) - ( 3.00 5.00 ) - """, testing_stream.output) - - def test_tap_vmap_vmap(self): - # A 2D tensor with x[i, j] = i + j using 2 vmap - def sum(x, y): - return hcb_id_print(x + y, output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - - def sum_rows(xv, y): - return jax.vmap(sum, in_axes=(0, None))(xv, y) - - def sum_all(xv, yv): - return jax.vmap(sum_rows, in_axes=(None, 0))(xv, yv) - - xv = jnp.arange(5, dtype=np.int32) - yv = jnp.arange(3, dtype=np.int32) - # assertMultiLineStrippedEqual(self, "", str(jax.make_jaxpr(sum_all)(xv, yv))) - _ = sum_all(xv, yv) - hcb.barrier_wait() - if hcb._HOST_CALLBACK_LEGACY.value: - assertMultiLineStrippedEqual(self, """ - transforms: [('batch', {'batch_dims': (0,)}), ('batch', {'batch_dims': (0,)})] - [[0 1 2 3 4] - [1 2 3 4 5] - [2 3 4 5 6]]""", testing_stream.output) - else: - assertMultiLineStrippedEqual(self, """ - 0 - 1 - 2 - 1 - 2 - 3 - 2 - 3 - 4 - 3 - 4 - 5 - 4 - 5 - 6 - """, testing_stream.output) - - def test_tap_vmap_while(self): - """Vmap of while.""" - - def func(x): - # like max(x, 2) - x1 = hcb_id_print(x, where="before:x", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - x2 = lax.while_loop( - lambda x: x < 2, lambda x: hcb_id_print( - x + 1, where="body:x+1", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG), x1) - res = hcb_id_print(x2, where="after:x", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - return res - - inputs = np.arange(5, dtype=np.int32) - self.assertAllClose( - np.array([2, 2, 2, 3, 4]), - jax.jit(jax.vmap(func))(inputs), - check_dtypes=False) - hcb.barrier_wait() - if hcb._HOST_CALLBACK_LEGACY.value: - assertMultiLineStrippedEqual( - self, """ - transforms: [('batch', {'batch_dims': (0,)})] where: before:x - [0 1 2 3 4] - transforms: [('batch', {'batch_dims': (0,)})] where: body:x+1 - [1 2 3 4 5] - transforms: [('batch', {'batch_dims': (0,)})] where: body:x+1 - [2 3 3 4 5] - transforms: [('batch', {'batch_dims': (0,)})] where: after:x - [2 2 2 3 4]""", testing_stream.output) - else: - pass # order of vmaps is not guaranteed - - def test_tap_vmap_while_tap_cond(self): - """Vmap of while, with a tap in the conditional.""" - - def func(x): - # like max(x, 2) - x1 = hcb_id_print(x, where="1", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - x2 = lax.while_loop(lambda x: hcb_id_print(x < 2, where="w_c", - output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG), - lambda x: hcb_id_print(x + 1, where="w_b", - output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG), - x1) - res = hcb_id_print(x2, where="3", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - return res - - inputs = np.arange(5, dtype=np.int32) - res = jax.jit(jax.vmap(func))(inputs) - hcb.barrier_wait() - self.assertAllClose(np.array([2, 2, 2, 3, 4]), res, check_dtypes=False) - if hcb._HOST_CALLBACK_LEGACY.value: - assertMultiLineStrippedEqual(self, """ - transforms: [('batch', {'batch_dims': (0,)})] where: 1 - [0 1 2 3 4] - transforms: [('batch', {'batch_dims': (0,)})] where: w_c - [ True True False False False] - transforms: [('batch', {'batch_dims': (0,)})] where: w_b - [1 2 3 4 5] - transforms: [('batch', {'batch_dims': (0,)})] where: w_c - [ True False False False False] - transforms: [('batch', {'batch_dims': (0,)})] where: w_b - [2 3 3 4 5] - transforms: [('batch', {'batch_dims': (0,)})] where: w_c - [False False False False False] - transforms: [('batch', {'batch_dims': (0,)})] where: 3 - [2 2 2 3 4]""", testing_stream.output) - else: - pass # order of vmap is not guaranteed - - def test_tap_transforms_doc(self): - # Examples from the documentation - def power3(x): - y = x * x - # Print both 'x' and 'x^2'. Must pack as a tuple. - hcb_id_print((x, y), what="x,x^2", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - return y * x - - print(f"impl = {power3(3.)}") - hcb.barrier_wait() - if hcb._HOST_CALLBACK_LEGACY.value: - expected = """ - what: x,x^2 - ( 3. 9. )""" - else: - expected = """ - what: x,x^2 - ( 3.0 9.0 )""" - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - testing_stream.reset() - - print(f"jvp = {jax.jvp(power3, (3.,), (0.1,))}") - hcb.barrier_wait() - if hcb._HOST_CALLBACK_LEGACY.value: - expected = """ - what: x,x^2 - ( 3. 9. )""" - else: - expected = """ - what: x,x^2 - ( 3.0 9.0 )""" - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - testing_stream.reset() - - @jax.custom_jvp - def print_tangents(arg): - return None - - @print_tangents.defjvp - def print_tangents_jvp(primals, tangents): - arg_dot, = tangents - hcb_id_print(arg_dot, what="tangents", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - return primals, tangents - - def power3_with_tangents(x): - y = x * x - # Print both 'x' and 'x^2'. Must pack as a tuple. - hcb_id_print((x, y), what="x,x^2", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - print_tangents((x, y)) - return y * x - - print(f"jvp = {jax.jvp(power3_with_tangents, (3.,), (0.1,))}") - hcb.barrier_wait() - if hcb._HOST_CALLBACK_LEGACY.value: - expected = """ - what: x,x^2 - ( 3. 9. ) - what: tangents - ( 0.1 0.6 )""" - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - - testing_stream.reset() - - print(f"grad = {jax.grad(power3)(3.)}") - hcb.barrier_wait() - # Only the primals by default - if hcb._HOST_CALLBACK_LEGACY.value: - expected = """ - what: x,x^2 - ( 3. 9. )""" - else: - expected = """ - what: x,x^2 - ( 3.0 9.0 )""" - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - testing_stream.reset() - - @jax.custom_vjp - def print_cotangents(arg): - # Must return the argument for which we want the cotangent. - return arg - - # f_fwd: a -> (b, residual) - def print_cotangents_fwd(arg): - return print_cotangents(arg), None - # f_bwd: (residual, CT b) -> [CT a] - def print_cotangents_bwd(residual, ct_b): - hcb_id_print(ct_b, what="cotangents", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - return ct_b, - - print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd) - - def power3_with_cotangents(x): - y = x * x - # Print both 'x' and 'x^2'. Must pack as a tuple. - hcb_id_print((x, y), what="x,x^2", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - # Must use the output of print_cotangents - (x1, y1) = print_cotangents((x, y)) - return y1 * x1 - - print(f"grad = {jax.grad(power3_with_cotangents)(3.)}") - hcb.barrier_wait() - if hcb._HOST_CALLBACK_LEGACY.value: - expected = """ - what: x,x^2 - ( 3. 9. ) - what: cotangents - ( 9. 3. )""" - else: - expected = """ - what: x,x^2 - ( 3.0 9.0 ) - what: cotangents - ( 9.0 3.0 )""" - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - testing_stream.reset() - - # TODO: grad of grad - - print(f"vmap = {jax.vmap(power3)(np.array([2., 3.]))}") - hcb.barrier_wait() - if hcb._HOST_CALLBACK_LEGACY.value: - expected = """ - transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 - ( [2. 3.] [4. 9.] )""" - else: - expected = """ - what: x,x^2 - ( 2.0 4.0 ) - what: x,x^2 - ( 3.0 9.0 ) - """ - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - testing_stream.reset() - - print(f"vmap o grad {jax.vmap(jax.grad(power3))(np.array([2., 3.]))}") - hcb.barrier_wait() - if hcb._HOST_CALLBACK_LEGACY.value: - expected = """ - transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 - ( [2. 3.] [4. 9.] )""" - else: - expected = """ - what: x,x^2 - ( 2.0 4.0 ) - what: x,x^2 - ( 3.0 9.0 ) - """ - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - testing_stream.reset() - - print(f"vmap o grad {jax.vmap(jax.grad(power3_with_cotangents))(np.array([2., 3.]))}") - hcb.barrier_wait() - if hcb._HOST_CALLBACK_LEGACY.value: - expected = """ - transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 - ( [2. 3.] [4. 9.] ) - transforms: [('batch', {'batch_dims': (0, 0)})] what: cotangents - ( [4. 9.] [2. 3.] )""" - else: - expected = """ - what: x,x^2 - ( 2.0 4.0 ) - what: x,x^2 - ( 3.0 9.0 ) - what: cotangents - ( 4.0 2.0 ) - what: cotangents - ( 9.0 3.0 ) - """ - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - testing_stream.reset() - - print(f"grad o remat = {jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.)}") - hcb.barrier_wait() - if hcb._HOST_CALLBACK_LEGACY.value: - expected = """ - what: x,x^2 - ( 3. 9. ) - what: x,x^2 - ( 27. 729. ) - what: x,x^2 - ( 3. 9. )""" - else: - expected = """ - what: x,x^2 - ( 3.0 9.0 ) - what: x,x^2 - ( 27.0 729.0 ) - what: x,x^2 - ( 3.0 9.0 ) - """ - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - testing_stream.reset() - - def test_tap_pmap(self): - self.supported_only_in_legacy_mode() - if len(local_devices()) < 2: - raise SkipTest("test requires at least 2 devices") - - def power3(x): - y = x * x - # Print both 'x' and 'x^2'. Must pack as a tuple. - _, y = hcb_id_print((x, y), - what="x,x^2", - output_stream=testing_stream, - tap_with_device=True) - return y * x - - pmap_power3 = jax.pmap(power3, devices=local_devices()) - xv = np.array([3, 4], dtype=np.int32) - res = pmap_power3(xv) - hcb.barrier_wait() - self.assertAllClose(xv * xv * xv, res, check_dtypes=False) - # Assertion text is for 2 devices (also works for 1 device) - assertMultiDeviceOutputEqual( - self, """ - device: cpu:0 what: x,x^2 - ( 3 9 ) - device: cpu:1 what: x,x^2 - ( 4 16 )""") - - def test_tap_pmap_vmap(self): - self.supported_only_in_legacy_mode() - # A matrix M[ij] = i * 10 + j - nr_devices = len(local_devices()) - shape = (nr_devices, 3) - matrix = np.fromfunction(lambda i, j: 10. * i + j, shape, - dtype=np.int32) - - def fun1(x, do_print=False): # x: i32 - return maybe_print(do_print, x * 2, "x * 2", tap_with_device=True) - - pmap_vmap_fun1 = jax.pmap( - jax.vmap(partial(fun1, do_print=True)), devices=local_devices()) - - res = pmap_vmap_fun1(matrix) - hcb.barrier_wait() - expected_res = jax.pmap( - jax.vmap(partial(fun1, do_print=False)), devices=local_devices())( - matrix) - self.assertAllClose(expected_res, res, check_dtypes=False) - # Assertion text is for 2 devices (also works for 1 device) - assertMultiDeviceOutputEqual(self, """ - device: cpu:0 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2 - [0.00 2.00 4.00] - device: cpu:1 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2 - [20.00 22.00 24.00]""") - - def test_tap_pmap_pmap_vmap(self): - # A matrix M[ijk] = i * 100 + j * 10 + k - self.supported_only_in_legacy_mode() - nr_devices = len(local_devices()) - if nr_devices % 2 != 0: - raise SkipTest("test works only on even number of devices") - - shape = (2, nr_devices // 2, 3) - matrix = np.fromfunction(lambda i, j, k: 100. * i + 10. * j + k, shape, - dtype=np.float32) - - def fun1(x, do_print=False): # x: f32 - y = maybe_print(do_print, x * 2., "x * 2", tap_with_device=True) - return y ** 2 - - pmap_fun1 = jax.pmap( - jax.pmap(jax.vmap(partial(fun1, do_print=True))), - devices=local_devices()) - res = pmap_fun1(matrix) - hcb.barrier_wait() - expected_res = jax.pmap( - jax.pmap(jax.vmap(partial(fun1, do_print=False))), - devices=local_devices())( - matrix) - self.assertAllClose(expected_res, res, check_dtypes=False) - # Assertion text is for 2 devices (also works for 1 device) - assertMultiDeviceOutputEqual(self, """ - device: cpu:0 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2 - [0.00 2.00 4.00] - device: cpu:1 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2 - [200.00 202.00 204.00]""") - - @ignore_jit_of_pmap_warning() - def test_tap_pmap_pmap_extra(self): - """pmap of a pmap surrounded by extra code.""" - # A matrix M[ij] = i * 10 + j - self.supported_only_in_legacy_mode() - nr_devices = len(local_devices()) - if nr_devices != 2: - raise SkipTest("test works only on 2 devices") - shape = (2, 1, 3) - matrix = np.fromfunction(lambda i, j, k: 100. * i + 10. * j + k, shape, - dtype=np.float32) - - def fun(xv, do_print=False): - # This will be printed on all devices, with shape [1, 3] - xv = maybe_print(do_print, xv + 1., "before", tap_with_device=True) - res = jax.pmap(lambda x: maybe_print(do_print, x * 2., "inside", tap_with_device=True))(xv) - # This will be printed on all devices, with shape [1, 3] - return maybe_print(do_print, res + 1., "after", tap_with_device=True) - - res = jax.pmap(partial(fun, do_print=True))(matrix) - self.assertAllClose(fun(matrix, do_print=False), res, check_dtypes=False) - hcb.barrier_wait() - # Assertion text is for 2 devices (also works for 1 device) - assertMultiDeviceOutputEqual(self, """ - device: cpu:0 what: before - [[1.00 2.00 3.00]] - device: cpu:0 what: inside - [2.00 4.00 6.00] - device: cpu:0 what: after - [[3.00 5.00 7.00]] - device: cpu:1 what: before - [[101.00 102.00 103.00]] - device: cpu:1 what: inside - [202.00 204.00 206.00] - device: cpu:1 what: after - [[203.00 205.00 207.00]]""") - - def test_tap_jvp_pmap_vmap(self): - self.supported_only_in_legacy_mode() - # A matrix M[ijk] = i * 100 + j * 10 * k - nr_devices = len(local_devices()) - shape = (nr_devices, 2, 3) - matrix = np.fromfunction(lambda i, j, k: 100. * i + 10. * j + k, shape, - dtype=np.float32) - - def fun(xv, do_print=False): - # x: f32[3] - return jax.jvp(jax.pmap(jax.vmap(lambda x: maybe_print(do_print, x * 2., "x * 2", tap_with_device=True))), - (xv,), (.1 * jnp.ones_like(xv),)) - - res = fun(matrix, do_print=True) - hcb.barrier_wait() - expected_res = fun(matrix, do_print=False) - self.assertAllClose(expected_res, res, check_dtypes=False) - # Assertion text is for 2 devices (also works for 1 device) - # Device 0 will get to execute jax.jvp(jax.vmap(...)) for matrix[0, :, :] - assertMultiDeviceOutputEqual(self, """ - device: cpu:0 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2 - [[ 0.00 2.00 4.00] - [20.00 22.00 24.00]] - device: cpu:1 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2 - [[200.00 202.00 204.00] - [220.00 222.00 224.00]]""") - - def test_tap_vmap_pmap(self): - self.supported_only_in_legacy_mode() - # A matrix M[ijk] = i * 100 + j * 10 * k - nr_devices = len(local_devices()) - shape = (2, nr_devices, 3) - matrix = np.fromfunction(lambda i, j, k: 100. * i + 10. * j + k, shape, - dtype=np.float32) - - def fun(xv, do_print=False): - # x: f32[3] - return jax.vmap(jax.pmap(lambda x: maybe_print(do_print, x * 2., "x * 2", tap_with_device=True)))(xv) - - res = fun(matrix, do_print=True) - hcb.barrier_wait() - expected_res = fun(matrix, do_print=False) - self.assertAllClose(expected_res, res, check_dtypes=False) - # Assertion text is for 2 devices (also works for 1 device) - # Device 0 will get to execute jax.jvp(jax.vmap(...)) for matrix[:, 0, :] - assertMultiDeviceOutputEqual(self, """ - device: cpu:0 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2 - [[ 0.00 2.00 4.00] - [200.00 202.00 204.00]] - device: cpu:1 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2 - [[ 20.00 22.00 24.00] - [220.00 222.00 224.00]]""") - - @ignore_jit_of_pmap_warning() - def test_tap_jit_pmap_extra(self): - """jit of a pmap surrounded by extra code.""" - self.supported_only_in_legacy_mode() - # A matrix M[ij] = i * 10 + j - nr_devices = len(local_devices()) - assert nr_devices in (1, 2) - shape = (nr_devices, 3) - matrix = np.fromfunction(lambda i, j: 10. * i + j, shape, - dtype=np.float32) - - def fun(xv, do_print=False): - # This will be printed on all devices with shape (nr_devices, 3) - xv = maybe_print(do_print, xv + 1., "before", tap_with_device=True) - res = jax.pmap(lambda x: maybe_print(do_print, x * 2., "inside", tap_with_device=True))(xv) - # This will be printed on all devices with shape (nr_devices, 3) - return maybe_print(do_print, res + 1., "after", tap_with_device=True) - - res = jax.jit(partial(fun, do_print=True))(matrix) - self.assertAllClose(fun(matrix, do_print=False), res, check_dtypes=False) - hcb.barrier_wait() - if len(local_devices()) == 2: - assertMultiDeviceOutputEqual(self, """ - device: cpu:0 what: before - [[ 1.00 2.00 3.00] - [11.00 12.00 13.00]] - device: cpu:0 what: inside - [2.00 4.00 6.00] - device: cpu:0 what: after - [[ 3.00 5.00 7.00] - [23.00 25.00 27.00]] - device: cpu:1 what: before - [[ 1.00 2.00 3.00] - [11.00 12.00 13.00]] - device: cpu:1 what: inside - [22.00 24.00 26.00] - device: cpu:1 what: after - [[ 3.00 5.00 7.00] - [23.00 25.00 27.00]]""") - else: - assert len(local_devices()) == 1 - assertMultiDeviceOutputEqual(self, """ - device: cpu:0 what: before - [[1.00 2.00 3.00]] - device: cpu:0 what: inside - [2.00 4.00 6.00] - device: cpu:0 what: after - [[3.00 5.00 7.00]]""") - - @unittest.skip("cond of pmap does not work in JAX. Issue #5178.") - def test_tap_cond_pmap(self): - # A matrix M[ij] = i * 10 + j - nr_devices = len(local_devices()) - shape = (nr_devices, 3) - matrix = np.fromfunction(lambda i, j: 10. * i + j, shape, - dtype=np.float32) - - def fun1(x, do_print=False): - return maybe_print(do_print, x * 2., "x * 2") - - def fun2(cond, xv, do_print=False): - return lax.cond(cond, jax.pmap(partial(fun1, do_print=do_print)), - lambda xv: xv, xv) - - res = fun2(True, matrix) - self.assertAllClose(fun2(True, matrix, do_print=False), res, check_dtypes=False) - hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - TBD""", testing_stream.output) - - @jtu.sample_product(device_index=[0, 1]) - def test_tap_pjit(self, device_index=0): - self.supported_only_in_legacy_mode() - if (device_index != 0 and - not hcb._HOST_CALLBACK_OUTFEED.value and - jtu.test_device_matches(["cpu"])): - # See comment in host_callback.py. - raise SkipTest("device_index works only with outfeed on CPU") - - devices = np.array(local_devices()) - nr_devices = len(devices) - if nr_devices < 2: - raise SkipTest("test requires at least 2 devices") - - logging.info(f"test_tap_pjit is running on devices {devices}.") - # x: i32[D, 3] = [[0, 1, 2], [10, 11, 12], ...] - # y: i32[3, 4] - x = jnp.arange(100, dtype=jnp.int32).reshape((10, 10))[:nr_devices, :3] - y = jnp.ones((3, 4), np.int32) - - @partial(jax.named_call, name="fun1") # for xprof debugging - def fun1(x): - z = jnp.dot(x, y) - return hcb_id_print(z, what="z", - output_stream=testing_stream, - tap_with_device=True, device_index=device_index) - - pjit_fun1 = pjit.pjit(fun1, in_shardings=(P("d"),), out_shardings=P("d")) - - with jax.sharding.Mesh(devices, ["d"]): - # Print the internal IR - helper_log_ir( - f"{self._testMethodName}.pjit", - pjit_fun1, - x, - num_partitions=nr_devices) - res = pjit_fun1(x) - - self.assertAllClose(jnp.dot(x, y), res) - hcb.barrier_wait("before check") - - # Assertion text is for 2 devices (also works for 1 device) - # Note that a single call is made. - assertMultiDeviceOutputEqual( - self, f""" - device: cpu:{device_index} what: z - [[ 3 3 3 3] - [33 33 33 33]]""") - - def test_tap_scan_custom_jvp(self): - """custom JVP, inside scan. - This exercises the custom_jvp_call_jaxpr primitives.""" - self.supported_only_in_legacy_mode() - @jax.custom_jvp - def f(x): - return x * hcb_id_print(x, output_stream=testing_stream, what="x") - - @f.defjvp - def f_jvp(primals, tangents): - x, = primals - x_dot, = tangents - primal_out = f(x) - tangent_out = 3. * x * hcb_id_print(x_dot, output_stream=testing_stream, what="x_dot") - return primal_out, tangent_out - - def g(x): - # Sum f(x_i) - return lax.scan(lambda carry, inp: (carry + f(inp), 0.), - np.full(x.shape[1:], 0.), # Like x w/o leading dim - x)[0] - - arg = np.full((2,), 0.7) - self.assertAllClose(0.7 * 0.7 * 2, g(arg)) - hcb.barrier_wait() - self.assertMultiLineStrippedEqual(""" - what: x - 0.7 - what: x - 0.7""", testing_stream.output) - testing_stream.reset() - - self.assertAllClose(np.array([2.1, 2.1]), jax.grad(g)(arg), check_dtypes=False) - hcb.barrier_wait() - self.assertMultiLineStrippedEqual(""" - what: x - 0.7 - what: x - 0.7 - transforms: ['transpose'] what: x_dot - 2.1 - transforms: ['transpose'] what: x_dot - 2.1""", testing_stream.output) - - def test_tap_scan_custom_vjp(self): - """custom VJP, inside scan. - This exercises the custom_vjp_call_jaxpr primitives.""" - self.supported_only_in_legacy_mode() - @jax.custom_vjp - def f(x): - return x * hcb_id_print(x, output_stream=testing_stream, what="x") - - # f_fwd: a -> (b, residual) - def f_fwd(x): - return f(x), 3. * x - - # f_bwd: (residual, CT b) -> [CT a] - def f_bwd(residual, ct_b): - return residual * hcb_id_print(ct_b, output_stream=testing_stream, what="ct_b"), - - f.defvjp(f_fwd, f_bwd) - - def g(x): - # Sum f(x_i) - return lax.scan(lambda carry, inp: (carry + f(inp), 0.), - np.full(x.shape[1:], 0.), # Like x w/o leading dim - x)[0] - - arg = np.full((2,), 0.7) - - self.assertAllClose(0.7 * 0.7 * 2, g(arg)) - hcb.barrier_wait() - self.assertMultiLineStrippedEqual(""" - what: x - 0.7 - what: x - 0.7""", testing_stream.output) - testing_stream.reset() - - self.assertAllClose(np.array([2.1, 2.1]), jax.grad(g)(arg), check_dtypes=False) - hcb.barrier_wait() - self.assertMultiLineStrippedEqual(""" - what: x - 0.7 - what: x - 0.7 - what: ct_b - 1. - what: ct_b - 1.""", testing_stream.output) - - def test_tap_callback_delay(self): - hcb.callback_extra = lambda dev: time.sleep(1) - - def func(x): - for i in range(5): - x = hcb_id_print(x * i, what="x times i") - return x - - jax.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3))) - - def test_tap_callback_delay_barrier(self): - hcb.callback_extra = lambda dev: time.sleep(2) - - def func(x): - for i in range(1, 4): - x = hcb_id_print(x * i, what=f"x times {i}", output_stream=testing_stream) - return x - - jax.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3))) - # Wait for the results - hcb.barrier_wait("first") - expected = """ - what: x times 1 - [[0. 1. 2.] - [3. 4. 5.]] - what: x times 2 - [[ 0. 2. 4.] - [ 6. 8. 10.]] - what: x times 3 - [[ 0. 6. 12.] - [18. 24. 30.]]""" - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - testing_stream.reset() - # Call again - jax.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3))) - hcb.barrier_wait("second") - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - - def test_tap_error_bad_consumer_id(self): - """Try to use reserved consumer ID 0. - - Check that we get the proper error from the runtime.""" - if not hcb._use_outfeed(jtu.device_under_test()): - raise SkipTest("test works only for outfeed") - comp = xla_client.XlaBuilder(self._testMethodName) - token = hcb.xops.CreateToken(comp) - hcb._initialize_outfeed_receiver() # Needed if this is the sole test - with self.assertRaisesRegex(RuntimeError, - "Consumer ID cannot be a reserved value: 0"): - hcb._callback_handler_data.receiver.add_outfeed( - comp, token, 0, - [xops.Constant(comp, np.zeros((2, 3), dtype=np.float32))], 0) - - def test_tap_error_different_shapes(self): - """Try to register different shapes for the same consumer ID.""" - if not hcb._use_outfeed(jtu.device_under_test()): - raise SkipTest("test works only for outfeed") - comp = xla_client.XlaBuilder(self._testMethodName) - token = hcb.xops.CreateToken(comp) - hcb._initialize_outfeed_receiver() # Needed if this is the sole test - hcb._callback_handler_data.receiver.add_outfeed( - comp, token, 123, - [xops.Constant(comp, np.zeros((2, 3), dtype=np.float32))], 0) - with self.assertRaisesRegex( - RuntimeError, ".*does not match previous shape .*\n?element_type.*"): - hcb._callback_handler_data.receiver.add_outfeed( - comp, token, 123, - [xops.Constant(comp, np.zeros((2, 3), dtype=np.int32))], 0) - with self.assertRaisesRegex( - RuntimeError, ".*does not match previous shape .*\n?element_type.*"): - hcb._callback_handler_data.receiver.add_outfeed( - comp, token, 123, - [xops.Constant(comp, np.zeros((2,), dtype=np.float32))], 0) - - def test_tap_id_tap_removed_kwargs(self): - def func(x, transforms, y): - pass - - with self.assertRaisesRegex(TypeError, r"Support for \*\*kwargs in ``id_tap``"): - hcb.id_tap(func, 1, y=2) - - def test_tap_id_tap_random_key(self): - # See https://github.com/jax-ml/jax/issues/13949 - with jax.enable_custom_prng(): - @jax.jit - def f(x): - def tap(tap_x, _): pass - return hcb.id_tap(tap, x, result=x) - f(jax.random.PRNGKey(123)) - - def test_tap_odeint(self): - # TODO: find a smaller repro for bug #4015 - # Seems to be xla_call(scan(xla_call)), all under grad. - from jax.experimental.ode import odeint - - def f(x, t, k): - x = hcb_id_print(x, callback_flavor=hcb.CallbackFlavor.DEBUG) - return -k * x - - def loss(k=1.0): - t = jnp.linspace(0, 0.001, num=2) - xs = odeint(f, 1.0, t, k) - return xs[-1] - - jax.grad(loss)(1.0) # should not fail - - def test_tap_remat_0(self): - def f(i, k): - x = hcb_id_print(k + i, output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - return k * x - - def loss(k): - return lax.fori_loop(0, 2, jax.remat(f), k) - - print(loss(3)) - hcb.barrier_wait() - expected = """ - 3 - 10""" - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - - @jtu.sample_product( - use_result=[True, False], - grad_func=["grad", "value_and_grad"], - use_remat=["old", "new", "none"], - ) - def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="new"): - self.supported_only_in_legacy_mode() - if use_remat == "old": raise SkipTest() - - def f(x): - id_print_result = hcb_id_print(x, output_stream=testing_stream) - if use_result: - x = id_print_result - return 3. * x - grad_f = jax.grad if grad_func == "grad" else jax.value_and_grad - if use_remat == "old": - trans_f = jax.remat(f) - elif use_remat == "new": - trans_f = ad_checkpoint.checkpoint(f) - else: - assert use_remat == "none" - trans_f = f - print(jax.make_jaxpr(grad_f(trans_f))(2.)) - grad_f(trans_f)(2.) - - hcb.barrier_wait() - - if use_remat == "none": - # GOOD: whether or not we use_result, we get the same callback. - expected = "2." - else: # use_remat - if use_result: - expected = """ - 2. - 2.""" - else: - if use_remat == "old": - # TODO: we should see two callbacks - expected = "" - else: - # Good: we see two callbacks, whether or not we use the result. - expected = """ - 2. - 2.""" - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - - def test_tap_named_call(self): - def tap_scalar(init, do_print=False): - @partial(jax.named_call, name="step") - def step(acc, step_nr): - acc = acc + step_nr - maybe_print(do_print, step_nr, what="step_nr") - return acc, None - - return lax.scan(step, init, np.arange(2)) - - self.assertAllClose(tap_scalar(3, do_print=False), tap_scalar(3, do_print=True)) - hcb.barrier_wait() - expected = """ - what: step_nr - 0 - what: step_nr - 1""" - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - - -class HostCallbackCallTest(jtu.JaxTestCase): - """Tests for hcb.call""" - - def setUp(self): - # skipping here skips teardown, so do this before super().setUp(). - if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1: - raise SkipTest("host_callback broken on multi-GPU platforms (#6447)") - if xla_bridge.using_pjrt_c_api(): - raise SkipTest("host_callback not implemented in PJRT C API") - super().setUp() - self.enter_context(jtu.ignore_warning( - category=DeprecationWarning, message="The host_callback APIs are deprecated")) - self.enter_context(jtu.ignore_warning( - category=DeprecationWarning, message="backend and device argument")) - - testing_stream.reset() - testing_stream._test_method_name = self._testMethodName - - def tearDown(self) -> None: - hcb.barrier_wait("HostCallbackCallTest.tearDown") - super().tearDown() - - def supported_only_in_legacy_mode(self): - if not hcb._HOST_CALLBACK_LEGACY.value: - self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False") - - def call_log_testing_stream(self, func, arg, *, result_shape, name=""): - """Call `func` and log inputs and outputs to the testing stream""" - - def call_log(arg): - def val2str(v): - return np.array2string(np.array(arg)) - testing_stream.write(f"Call {name}({val2str(arg)})\n") - res = func(arg) - testing_stream.write(f" = {val2str(res)}\n") - return res - return hcb.call(call_log, arg, result_shape=result_shape) - - def test_call_simple(self): - - def f_outside(x): - return 2 * x - - def fun(x): - y = hcb.call(f_outside, x + 1, result_shape=x) - return 3 * (1 + y) - - arg = np.arange(24, dtype=np.int32).reshape((2, 3, 4)) - self.assertAllClose(3 * (1 + 2 * (arg + 1)), fun(arg)) - - def test_primitive_compilation(self): - - def f_outside(x): - return 2 * x - - def fun(x): - return hcb.call(f_outside, x, result_shape=x) - - arg = np.arange(24, dtype=np.int32).reshape((2, 3, 4)) - with jtu.count_primitive_compiles() as count: - for _ in range(3): - self.assertAllClose(2 * arg, fun(arg)) - r = jax.make_jaxpr(fun)(arg) - self.assertEqual(count[0], 1) - - @jtu.sample_product( - dtype=[dtype for dtype in jtu.dtypes.all if dtype != np.bool_], - ) - def test_call_types(self, dtype=np.float64): - - def f_outside(x): - # Use x + x to ensure that the result type is the same - return x + x - - def fun(x): - return hcb.call(f_outside, x + x, result_shape=x) - - arg = np.arange(24, dtype=dtype).reshape((2, 3, 4)) - self.assertAllClose(arg + arg + arg + arg, fun(arg), check_dtypes=True) - - def test_call_types_bool(self, dtype=np.float64): - - def f_outside(x): - return np.invert(x) - - def fun(x): - return hcb.call(f_outside, x, result_shape=x) - - arg = self.rng().choice(a=[True, False], size=(2, 3, 4)) - self.assertAllClose(np.invert(arg), fun(arg)) - - def test_call_tuples(self): - - def f_outside(args): - x, y = args - return y, x # Swap the tuple - - def fun(x): - xy = hcb.call(f_outside, (x, x + 1), result_shape=(x, x)) - return 2 * xy[0] + 3 * xy[1] - - arg = np.arange(24, dtype=np.int32).reshape((2, 3, 4)) - self.assertAllClose(2 * (arg + 1) + 3 * arg, fun(arg)) - - def test_call_no_arg(self): - """Call with no arguments.""" - result = np.ones((2,), dtype=np.float32) - def f_outside(in_tuple): - assert len(in_tuple) == 0 - return result - def fun(x): - return x + hcb.call(f_outside, (), - result_shape=jax.ShapeDtypeStruct(result.shape, result.dtype)) - self.assertAllClose(2. + result, fun(2.)) - - def test_call_empty_arg(self): - """Call with empty array.""" - result = np.full((2,), 3., dtype=np.float32) - def f_outside(x0): # x0: f32[2, 0] - return result - x0 = np.ones((2, 0), dtype=np.float32) - def fun(x): - return x + hcb.call(f_outside, x0, - result_shape=jax.ShapeDtypeStruct(result.shape, result.dtype)) - self.assertAllClose(2. + result, fun(2.)) - - def test_call_empty_arg_inside_pytree(self): - """Call taking tuple with an empty array and a non-empty one.""" - x0 = np.ones((2, 0), dtype=np.float32) - x1 = np.full((2,), 3., dtype=np.float32) - result = x1 - def f_outside(in_tuple): # x0: f32[2, 0] x1: f32[2] - return in_tuple[1] - - def fun(x): - res = hcb.call(f_outside, (x0, x1), - result_shape=jax.ShapeDtypeStruct(result.shape, result.dtype)) - return x + res - self.assertAllClose(2. + result, fun(2.)) - - def test_call_empty_result(self): - """Call returning empty array.""" - result_shape = (2, 0) - def f_outside(_): - return np.ones(result_shape, dtype=np.float32) - def fun(x): - return x + hcb.call(f_outside, 1., - result_shape=jax.ShapeDtypeStruct(result_shape, np.float32)) - self.assertAllClose(f_outside(0.), fun(2.)) - - def test_call_empty_result_inside_pytree(self): - """Call returning a tuple with an empty array and a non-empty one.""" - result_shape_0 = (2, 0) - result_shape_2 = (0,) - def f_outside(_): - return (np.ones(result_shape_0, dtype=np.float32), - np.ones((1,), dtype=np.float32), - np.ones(result_shape_2, dtype=np.float32)) - def fun(x): - res = hcb.call(f_outside, 1., - result_shape=(jax.ShapeDtypeStruct(result_shape_0, np.float32), - jax.ShapeDtypeStruct((1,), np.float32), - jax.ShapeDtypeStruct(result_shape_2, np.float32))) - self.assertEqual(result_shape_0, res[0].shape) - self.assertEqual(result_shape_2, res[2].shape) - return x + res[1] - self.assertAllClose(2 + np.ones((1,), dtype=np.float32), fun(2.)) - - def test_call_empty_result_all_pytree(self): - """Call returning a tuple of empty arrays.""" - result_shape = (2, 0) - def f_outside(_): - return (np.ones(result_shape, dtype=np.float32), - np.ones(result_shape, dtype=np.float32)) - def fun(x): - res = hcb.call(f_outside, 1., - result_shape=(jax.ShapeDtypeStruct(result_shape, np.float32), - jax.ShapeDtypeStruct(result_shape, np.float32))) - return x + res[0] + res[1] - self.assertAllClose(np.ones(result_shape, dtype=np.float32), - fun(2.)) - - def test_call_no_result(self): - def f_outside(arg): - self.call_log_testing_stream(lambda x: None, arg, - result_shape=None, - name="outside") - return arg - - self.assertAllClose((3., 4.), f_outside((3., 4.))) - hcb.barrier_wait() - expected = """ - Call outside([3. 4.]) - = [3. 4.]""" - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - - def test_call_cond(self): - def f_outside(args): - x, y = args - return x * y.astype(np.float32) - - def loop(x, use_outside=True): - def body(i, acc): - return lax.cond(i % 2 == 1, - lambda _: (hcb.call(f_outside, (acc, i), - result_shape=acc) - if use_outside else f_outside((acc, i))), - lambda _: acc, - None) - - return lax.fori_loop(0, 18, body, x) - - res_inside = loop(np.float32(1.2), use_outside=False) - self.assertAllClose(res_inside, jax.jit(loop)(np.float32(1.2))) - - def test_call_jit_scan_call(self): - def f_outside(x): - return x - - def loop(x, use_outside=True): - def body(carry, i): - if use_outside: - return carry + hcb.call(f_outside, i, - result_shape=i), None - else: - return carry + i, None - - return lax.scan(body, 0, x) - - x = np.arange(5, dtype=np.int32) - - res_outside = jax.jit(partial(loop, use_outside=True))(x) - self.assertAllClose(res_outside, loop(x, use_outside=False)) - - def test_call_doc_example1(self): - """Examples from the documentation: simplest, call a function""" - - def host_eig(x): - return np.linalg.eigvals(x) - - shape = (2, 5, 4, 4) - - m = np.ones(shape, dtype=np.float32) - - def fun(m): - eig_m = hcb.call(host_eig, m, - result_shape=jax.ShapeDtypeStruct(m.shape[:-1], m.dtype)) - return eig_m - - expected_res = np.linalg.eigvals(m) - self.assertAllClose(expected_res, fun(m)) - @jtu.skip_on_devices("gpu") - def test_call_doc_example_hlo(self): - """Examples from the documentation: simplest, call a function.""" - - def fun1(m): - return jnp.sin(hcb.call(lambda x: np.cos, - jnp.cos(m), - result_shape=m)) - - m = np.ones((2,), np.float32) - helper_print_optimized_hlo(fun1, m) - - def fun2(m): - x = hcb.call(lambda x: None, 2, result_shape=()) - return x - - m = np.ones((2,), np.float32) - helper_print_optimized_hlo(fun2, m) - - def test_call_with_device(self): - self.supported_only_in_legacy_mode() - def callback_func(x, device=None): - testing_stream.write(f"device: {device}\n Called with {x}") - return x - - def func(x): - return hcb.call(callback_func, x, - result_shape=x, - call_with_device=True) - - self.assertEqual(3., func(3.)) - assertMultiDeviceOutputEqual(self, """ - device: cpu:0 - Called with 3.00""") - - def test_call_pmap(self): - self.supported_only_in_legacy_mode() - # Works for 1 or 2 devices - def callback_func(x, device=None): - testing_stream.write(f"device: {device}\n Called with {x}") - return x * np.array(3, np.int32) - - def fun(x): # x: i32 - return hcb.call(callback_func, x * 2, - result_shape=x, - call_with_device=True) - - xv = jnp.arange(len(local_devices()), dtype=jnp.int32) - res = jax.pmap(fun)(xv) - self.assertAllClose(jax.pmap(lambda x: x * 6)(xv), res) - # Assertion text is for 2 devices (also works for 1 device) - assertMultiDeviceOutputEqual(self, """ - device: cpu:0 - Called with 0 - device: cpu:1 - Called with 2""") - - def test_call_vmap(self): - def f_outside(x): return x - - def fun(x): - return hcb.call(f_outside, x, result_shape=x, - callback_flavor=hcb.CallbackFlavor.PURE) - - if hcb._HOST_CALLBACK_LEGACY.value: - with self.assertRaisesRegex(NotImplementedError, - "batching rules are implemented only for id_tap, not for call"): - jax.vmap(fun)(np.ones((2, 3))) - else: - jax.vmap(fun)(np.ones((2, 3))) - - @jtu.sample_product(device_index=[0, 1]) - @jtu.skip_on_devices("cpu") # TODO: RET_CHECK failure - def test_call_pjit(self, device_index=0): - devices = np.array(local_devices()) - nr_devices = len(devices) - if nr_devices < 2: - raise SkipTest("test requires at least 2 devices") - - logging.info(f"test_call_pjit is running on devices {devices}.") - # x: i32[D, 3] = [[0, 1, 2], [10, 11, 12], ...] - # y: i32[3, 4] - x = jnp.arange(100, dtype=jnp.int32).reshape((10, 10))[:nr_devices, :3] - y = jnp.ones((3, 4), np.int32) - - def callback_x5_func(x, device=None): - testing_stream.write(f"device: {device}\n Called with {x}") - return x * np.array(5, np.int32) - - def fun(x): - xy = jnp.dot(x, y) - return hcb.call( - callback_x5_func, xy, result_shape=xy, call_with_device=True, - device_index=device_index) - - pjit_fun = pjit.pjit(fun, in_shardings=(P("d"),), out_shardings=P("d")) - with jax.sharding.Mesh(devices, ["d"]): - # Print the internal IR - helper_log_ir( - f"{self._testMethodName}.pjit", - pjit_fun, - x, - num_partitions=nr_devices) - - res = pjit_fun(x) - - expected_res = jnp.dot(x, y) * np.array(5, np.int32) - self.assertAllClose(expected_res, res, check_dtypes=False) - - hcb.barrier_wait("before assertion") - # Assertion text is for 2 devices (also works for 1 device) - assertMultiDeviceOutputEqual( - self, f""" - device: cpu:{device_index} - Called with [[ 3 3 3 3] - [33 33 33 33]]""") - - def test_call_error_bad_result_shape(self): - with self.assertRaisesRegex( - ValueError, - "The values must be either numeric scalars, or must have 'shape' and 'dtype' attributes"): - hcb.call(lambda x: x, 3., result_shape="string") - - with self.assertRaisesRegex( - ValueError, - "The values must be either numeric scalars, or must have 'shape' and 'dtype' attributes"): - hcb.call(lambda x: x, 3., result_shape=lambda x: x) - hcb.barrier_wait("wait for error") - - def helper_check_callback_errors(self, thunk: Callable, - expected_exc_txt: str): - """Calls thunk() and checks for expected exceptions. - """ - if jtu.test_device_matches(["cpu"]): - # On CPU the runtime crashes, and the tests are all aborted - raise SkipTest("TODO: CPU runtime crashes on unexpected infeed") - elif jtu.test_device_matches(["gpu"]): - # On GPU we get a nice error back to Python - with self.assertRaisesRegex( - RuntimeError, - "(.* Mismatch between infeed source buffer shape s8.12345." - "|.*The destination shape does not match the source shape.)"): - thunk() - elif jtu.test_device_matches(["tpu"]): - # On TPU we get no error!!! - raise SkipTest("TODO: TPU runtime does not check infeed, and just computes with garbage") - - # Both on GPU and TPU we also get an error during the barrier_wait at the - # end of the test. Run a barrier_wait now, to consume that error. - with self.assertRaisesRegex( - hcb.CallbackException, - re.compile( - "There were exceptions during callback processing.*Last one was:.*" + - expected_exc_txt, - re.DOTALL)): - hcb.barrier_wait("Waiting for error") - - def test_call_error_callback_throws_exception(self): - self.supported_only_in_legacy_mode() - def f_outside(x): - raise ValueError("user exception") - def fun(x): - return hcb.call(f_outside, x, result_shape=x) - - self.helper_check_callback_errors(lambda: fun(3.), - "ValueError: user exception") - - def test_call_error_callback_returns_unexpected_shape(self): - self.supported_only_in_legacy_mode() - def fun(x): - return hcb.call(lambda x: (x, x), x, result_shape=x) - - self.helper_check_callback_errors(lambda: fun(3.), - "Callback func .* should have returned a result with pytree") - - def test_call_error_then_compute(self): - self.supported_only_in_legacy_mode() - # Continue computation on device after error - def f_outside(x): - raise ValueError("user exception") - def fun(x): - x1 = hcb.call(f_outside, x, result_shape=x) - return x1 - arg = np.arange(3, dtype=np.int32) - self.helper_check_callback_errors(lambda: self.assertAllClose(arg, fun(arg)), - "ValueError: user exception") - - -def call_jax_other_device( - jax_outside_fun, arg, *, device, - callback_flavor: hcb.CallbackFlavor = hcb.CallbackFlavor.IO_CALLBACK): - """Calls a JAX function on a specific device with simple support for reverse AD. - - Functions whose name starts with "jax_outside" are called on another device, - by way of hcb.call. - """ - - def run_jax_outside_fun(arg): - return jax.jit(jax_outside_fun)(jax.device_put(arg, device)) - - @jax.custom_vjp - def make_call(arg): - return hcb.call(run_jax_outside_fun, arg, - result_shape=jax.eval_shape(jax_outside_fun, arg), - callback_flavor=callback_flavor) - - # Define the fwd and bwd custom_vjp functions - def make_call_vjp_fwd(arg): - # Return the primal argument as the residual. Use `make_call` for the - # primal computation to enable higher-order AD. - return make_call(arg), arg # Return the primal argument as the residual - - def make_call_vjp_bwd(res, ct_res): - arg = res # residual is the primal argument - - def jax_outside_vjp_fun(arg_and_ct): - arg, ct = arg_and_ct - _, f_vjp = jax.vjp(jax_outside_fun, arg) - ct_in, = f_vjp(ct) - return ct_in - - return (call_jax_other_device(jax_outside_vjp_fun, (arg, ct_res), device=device),) - - make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd) - return make_call(arg) - - -class CallJaxTest(jtu.JaxTestCase): - """Tests using `call_jax_other_device`.""" - - def setUp(self): - if not hcb._HOST_CALLBACK_LEGACY.value: - self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False") - if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1: - raise SkipTest("host_callback broken on multi-GPU platforms (#6447)") - if xla_bridge.using_pjrt_c_api(): - raise SkipTest("host_callback not implemented in PJRT C API") - - if not jtu.test_device_matches(["cpu"]): - assert jax.devices("cpu") - self.outside_device = jax.devices("cpu")[0] - else: - if len(jax.devices("cpu")) == 1: - raise SkipTest("Test needs at least two devices. On CPU use XLA_FLAGS=--xla_force_host_platform_device_count=2") - self.outside_device = jax.devices("cpu")[1] - super().setUp() - self.enter_context(jtu.ignore_warning( - category=DeprecationWarning, message="The host_callback APIs are deprecated")) - - - def test_jax_impl(self): - def f_jax(x): - return jnp.sin(x) - - def f_outside(x): - return call_jax_other_device(f_jax, x, device=self.outside_device) - - self.assertAllClose(f_jax(3.), f_outside(3.)) - self.assertAllClose(f_jax(3.), jax.jit(f_outside)(3.)) - - def test_jax_impl_pytree(self): - def f_jax(x): - # x : dict(a=..., b=...) and output is a list of two elements - return [jnp.sin(x["a"]), jnp.sin(x["b"])] - - def f_outside(x): - return call_jax_other_device(f_jax, x, device=self.outside_device) - - x = dict(a=3., b=4.) - res_jax = f_jax(x) - # print(f"outside_jaxpr = {jax.make_jaxpr(f_outside)(x)}") - res_outside = f_outside(x) - self.assertAllClose(res_jax, res_outside) - - def test_jax_grad(self): - def f_jax(x): - return 2. * jnp.sin(x) - - def f_outside(x): - return 2. * call_jax_other_device(jnp.sin, x, device=self.outside_device) - - res_jax = jax.grad(f_jax)(3.) - self.assertAllClose(res_jax, jax.grad(f_outside)(3.)) - - def test_jax_grad_pytree(self): - def f_jax(x): - # x : dict(a=..., b=...) and output is a float - return 3. * jnp.sin(x["a"]) + jnp.sin(x["b"]) - - def f_outside(x): - return call_jax_other_device(f_jax, x, device=self.outside_device) - - x = dict(a=3., b=4.) - res_jax = jax.grad(f_jax)(x) - self.assertAllClose(res_jax, jax.grad(f_outside)(x)) - - def test_jax_grad_of_grad(self): - def f_jax(x): - return 2. * x * x * x - - def f_outside(x): - return 2. * call_jax_other_device(lambda x: x * x * x, x, device=self.outside_device) - - res_jax = jax.grad(jax.grad(f_jax))(5.) - res_outside = jax.grad(jax.grad(f_outside))(5.) - self.assertAllClose(res_jax, res_outside) - - -class OutfeedRewriterTest(jtu.JaxTestCase): - - def setUp(self): - if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1: - raise SkipTest("host_callback broken on multi-GPU platforms (#6447)") - if xla_bridge.using_pjrt_c_api(): - raise SkipTest("host_callback not implemented in PJRT C API") - super().setUp() - self.enter_context(jtu.ignore_warning( - category=DeprecationWarning, message="The host_callback APIs are deprecated")) - - def supported_only_in_legacy_mode(self): - if not hcb._HOST_CALLBACK_LEGACY.value: - self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False") - - def assertRewrite(self, expected: str, func: Callable, args: Sequence, - has_input_token=True, has_output_token=True): - """Check that the rewrite of func(*args) matches expected.""" - jaxpr = jax.make_jaxpr(func)(*args) - rewritten = hcb._rewrite_closed_jaxpr(jaxpr, # noqa: F841 - has_input_token, has_output_token) - # Since it is somewhat annoying to update the Jaxpr assertions when we change - # the Jaxpr printing, we do not check these by default. It is recommended that - # before making changes to the code generation and Jaxpr rewriting, turn on - # the checking, update the expected Jaxpr, and then make the changes. - # assertMultiLineStrippedEqual(self, expected, str(rewritten)) - del rewritten - - def test_no_outfeed(self): - self.assertRewrite(""" - { lambda ; a. - let b = mul a a - c = add a b - in (c,) }""", lambda x: x + x * x, [0], has_input_token=False, - has_output_token=False) - self.assertRewrite(""" - { lambda ; a d e. - let b = mul a a - c = add a b - in (c,) }""", lambda x: x + x * x, [0], has_output_token=False) - self.assertRewrite(""" - { lambda ; a d e. - let b = mul a a - c = add a b - in (c, d, e) }""", lambda x: x + x * x, [0]) - - def test_simple_outfeed(self): - self.assertRewrite(""" - { lambda ; a d e. - let b = add a a - c f g = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] b d e - in (c, f, g) }""", lambda x: hcb_id_print(x + x), [0]) - - def test_simple_outfeed_without_input_token(self): - self.assertRewrite(""" - { lambda ; a b. - let e = create_token a b - f = create_token a b - c = add a b - d g h = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] c e f - in (d,) }""", lambda x1, x2: hcb_id_print(x1 + x2), [1, 2], - has_input_token=False, has_output_token=False) - - def test_simple_outfeed_without_input_token_nor_invars(self): - self.assertRewrite(""" - { lambda ; . - let b = create_token - c = create_token - a d e = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] 42 b c - in (a,) }""", lambda: hcb_id_print(42), [], - has_input_token=False, has_output_token=False) - - def test_multiple_tap_without_dependencies(self): - def f(x): - hcb_id_print(x, what="x") - hcb_id_print(x + 1, what="x + 1") - return 2 - - self.assertRewrite(""" - { lambda ; a c d. - let _ e f = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] a c d - b = add a 1 - _ g h = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] b e f - in (2, g, h) }""", f, [1]) - - def test_cond(self): - y = jnp.ones(5) # captured const - - def func(x, z): - return lax.cond(z > 0, (1, 2), lambda a: (a[0], jnp.zeros(5)), - z, lambda a: (hcb_id_print(a), y)) - - self.assertRewrite(""" - { lambda a ; b c h i. - let d = gt c 0 - e = convert_element_type[ new_dtype=int32 ] d - f g j k = - cond[ branches=( { lambda ; a b c d f g. - let e h i = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] d f g - in (e, a, h, i) } - { lambda ; f_ a b c g h. - let d = broadcast_in_dim[ broadcast_dimensions=( ) - shape=(5,) ] 0.00 - in (a, d, g, h) } ) ] e a 1 2 c h i - in (f, g, j, k) }""", func, [y, 5]) - - def test_while(self): - ct_body = jnp.ones(5, np.float32) # captured const for the body - ct_cond = jnp.ones(5, np.float32) # captured const for the conditional - - def func(x): - # x: f32[5] - # c: (f32[5], f32) - return lax.while_loop(lambda c: c[1] < jnp.sum(c[0] + ct_cond), - lambda c: (ct_body, hcb_id_print(c[1]) + 1.), - (x, np.float32(1.))) - - self.assertRewrite(""" - { lambda a b ; c f g. - let d e h i = - while[ body_jaxpr={ lambda ; a b c f g. - let d h i = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] c f g - e = add d 1.00 - in (a, e, h, i) } - body_nconsts=1 - cond_jaxpr={ lambda ; a b c g h. - let d = add b a - e = reduce_sum[ axes=(0,) ] d - f = lt c e - in (f,) } - cond_nconsts=1 ] a b c 1.00 f g - in (d, e, h, i) }""", func, [ct_body]) - - def test_while_pred_outfeed(self): - """A while with outfeed in the pred.""" - ct_body = jnp.ones(5) # captured const for the body - ct_cond = jnp.ones(2) # captured const for the conditional - - def func(x): - return lax.while_loop(lambda c: hcb_id_print(ct_cond, result=c[1]) < 5, - lambda c: (ct_body, hcb_id_print(c[1]) + 1), - (x, 1)) - - self.assertRewrite(""" - { lambda a b ; c f g. - let j k l = xla_call[ call_jaxpr={ lambda ; a b c g h. - let d i j = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] a g h - e = id_tap_dep c d - f = lt e 5 - in (f, i, j) } - donated_invars=(False, False, False, False, False) - name=cond_before ] a c 1 f g - bf d e h i = - while[ body_jaxpr={ lambda ; r s t u v w x. - let y z ba bb = - xla_call[ call_jaxpr={ lambda ; a b c f g. - let d h i = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] c f g - e = add d 1 - in (a, e, h, i) } - donated_invars=(False, False, False, False, False) - name=body ] s u v w x - bc bd be = - xla_call[ call_jaxpr={ lambda ; a b c g h. - let d i j = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] a g h - e = id_tap_dep c d - f = lt e 5 - in (f, i, j) } - donated_invars=(False, False, False, False, False) - name=cond_body ] r y z ba bb - in (bc, y, z, bd, be) } - body_nconsts=2 - cond_jaxpr={ lambda ; m n o p q. - let - in (m,) } - cond_nconsts=0 ] a b j c 1 k l - in (d, e, h, i) }""", func, [ct_body]) - - def test_scan(self): - y = jnp.ones(5) # captured const - - def func(x): - return lax.scan(lambda c, a: (hcb_id_print(c), y), (1, 2), x) - - self.assertRewrite(""" - { lambda a ; b f g. - let c d h i e = - scan[ jaxpr={ lambda ; a b c g h d. - let e f i j = - outside_call[ arg_treedef=PyTreeDef(tuple, [*,*]) - callback=... - has_token=True - identity=True ] b c g h - in (e, f, i, j, a) } - length=5 - linear=(False, False, False, False, False, False) - num_carry=4 - num_consts=1 - reverse=False - unroll=1 ] a 1 2 f g b - in (c, d, e, h, i) }""", func, [y]) - - def test_scan_custom_jvp(self): - """custom JVP, inside scan. - This exercises the custom_jvp_call_jaxpr primitives.""" - self.supported_only_in_legacy_mode() - @jax.custom_jvp - def f(x): - return x * hcb_id_print(x) - - @f.defjvp - def f_jvp(primals, tangents): - x, = primals - x_dot, = tangents - primal_out = f(x) - tangent_out = 3. * x * hcb_id_print(x_dot) - return primal_out, tangent_out - - def g(x): - # Sum f(x_i) - return lax.scan(lambda carry, inp: (carry + f(inp), 0.), - np.full(x.shape[1:], 0.), # Like x w/o leading dim - x)[0] - - arg = np.full((5,), 0.7) - self.assertRewrite(""" - { lambda ; a c d. - let b e f _ = - scan[ jaxpr={ lambda ; a e f b. - let c g h = custom_jvp_call_jaxpr[ fun_jaxpr={ lambda ; a d e. - let b f g = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] a d e - c = mul a b - in (c, f, g) } - num_consts=0 ] b e f - d = add a c - in (d, g, h, 0.00) } - length=5 - linear=(False, False, False, False) - num_carry=3 - num_consts=0 - reverse=False - unroll=1 ] 0.00 c d a - in (b, e, f) }""", g, [arg]) - self.assertRewrite(""" - { lambda ; a d e. - let _ _ f g _ b = - scan[ jaxpr={ lambda ; a b h i c d. - let e j k = custom_jvp_call_jaxpr[ fun_jaxpr={ lambda ; a d e. - let b f g = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] a d e - c = mul a b - in (c, f, g) } - num_consts=0 ] c h i - f = add a e - g = mul c 3.00 - in (f, *, j, k, 0.00, g) } - length=5 - linear=(False, True, False, False, False, True) - num_carry=4 - num_consts=0 - reverse=False - unroll=1 ] 0.00 * d e a * - _ _ h i _ c = - scan[ jaxpr={ lambda ; a b g h c d. - let e = mul b d - f i j = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True - transforms=(('transpose',),) ] e g h - in (*, b, i, j, *, f) } - length=5 - linear=(True, True, False, False, True, False) - num_carry=4 - num_consts=0 - reverse=True - unroll=1 ] * 1.00 f g * b - in (c, h, i) }""", jax.grad(g), [arg]) - - def test_scan_custom_vjp(self): - """custom VJP, inside scan. - This exercises the custom_vjp_call_jaxpr primitives.""" - self.supported_only_in_legacy_mode() - @jax.custom_vjp - def f(x): - return x * hcb_id_print(x) - - # f_fwd: a -> (b, residual) - def f_fwd(x): - return f(x), 3. * x - - # f_bwd: (residual, CT b) -> [CT a] - def f_bwd(residual, ct_b): - return residual * hcb_id_print(ct_b), - - f.defvjp(f_fwd, f_bwd) - - def g(x): - # Sum f(x_i) - return lax.scan(lambda carry, inp: (carry + f(inp), 0.), - np.full(x.shape[1:], 0.), # Like x w/o leading dim - x)[0] - - arg = np.full((2,), 0.7) - self.assertRewrite(""" - { lambda ; a c d. - let b e f _ = - scan[ jaxpr={ lambda ; a e f b. - let c g h = custom_vjp_call_jaxpr[ - fun_jaxpr={ lambda ; a d e. - let b f g = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] a d e - c = mul a b - in (c, f, g) } - num_consts=0 - ] b e f - d = add a c - in (d, g, h, 0.00) } - length=2 - linear=(False, False, False, False) - num_carry=3 - num_consts=0 - reverse=False - unroll=1 ] 0.00 c d a - in (b, e, f) }""", g, [arg]) - self.assertRewrite(""" - { lambda ; a d e. - let _ _ f g _ b = - scan[ jaxpr={ lambda ; a b h i c d. - let e j k = custom_vjp_call_jaxpr[ - fun_jaxpr={ lambda ; a d e. - let b f g = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] a d e - c = mul a b - in (c, f, g) } - num_consts=0 - ] c h i - f = add a e - g = mul c 3.00 - in (f, *, j, k, 0.00, g) } - length=2 - linear=(False, True, False, False, False, True) - num_carry=4 - num_consts=0 - reverse=False - unroll=1 ] 0.00 * d e a * - _ _ h i _ c = - scan[ jaxpr={ lambda ; a b g h c d. - let e i j = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] b g h - f = mul d e - in (*, b, i, j, *, f) } - length=2 - linear=(True, True, False, False, True, False) - num_carry=4 - num_consts=0 - reverse=True - unroll=1 ] * 1.00 f g * b - in (c, h, i) }""", jax.grad(g), [arg]) - - def test_remat_loop(self): - def f(k, x): - x = hcb_id_print(k + x) - return -k * x - - def loss(k): - return lax.fori_loop(0, 1, jax.remat(f), k) - - self.assertRewrite(""" - { lambda ; a c d. - let _ _ b e f = - while[ body_jaxpr={ lambda ; a b c f g. - let d = add a 1 - e h i = remat_call[ call_jaxpr={ lambda ; a b g h. - let c = add a b - d i j = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] c g h - e = neg a - f = mul e d - in (f, i, j) } - concrete=False - name=f ] a c f g - in (d, b, e, h, i) } - body_nconsts=0 - cond_jaxpr={ lambda ; a b c e f. - let d = lt a b - in (d,) } - cond_nconsts=0 ] 0 1 a c d - in (b, e, f) }""", loss, [2]) - - def test_named_call(self): - def tap_scalar(init, do_print=False): - @partial(jax.named_call, name="step") - def step(acc, step_nr): - acc = acc + step_nr - maybe_print(do_print, step_nr, what="step_nr") - return acc, None - - return lax.scan(step, init, np.arange(2, dtype=np.int32)) - - self.assertRewrite(""" - { lambda a ; b d e. - let c = scan[ jaxpr={ lambda ; a b. - let c = named_call[ call_jaxpr={ lambda ; a b. - let c = add a b - in (c,) } - name=step ] a b - in (c,) } - length=2 - linear=(False, False) - num_carry=1 - num_consts=0 - reverse=False - unroll=1 ] b a - in (c, d, e) }""", tap_scalar, [np.int32(3)]) - - def test_pmap(self): - self.supported_only_in_legacy_mode() - def f(xv): - jax.pmap(lambda x: jnp.sin(hcb_id_print(x, tap_with_device=True)), - axis_name="i")(xv) - - self.assertRewrite(""" - { lambda ; a b c. - let _ d e = xla_pmap[ axis_name=i - axis_size=1 - backend=None - call_jaxpr={ lambda ; a d e. - let b f g = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] a d e - c = sin b - in (c, f, g) } - devices=None - donated_invars=(False, False, False) - global_axis_size=None - in_axes=(0, 0, 0) - name= - out_axes=(0, 0, 0) ] a b c - in (d, e) }""", f, [np.array([2.], dtype=np.float32)]) - - -if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/host_callback_to_tf_test.py b/tests/host_callback_to_tf_test.py deleted file mode 100644 index 3a36ce1296a6..000000000000 --- a/tests/host_callback_to_tf_test.py +++ /dev/null @@ -1,279 +0,0 @@ -# Copyright 2020 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""An example of using host_callback.call to invoke on the host functions -written in Tensorflow. The interesting aspect here is how we can differentiate -through the outside computation, using tf.GradientTape on the host. - -This is separate from host_callback_test because it needs a TF dependency. -""" -from collections.abc import Callable -import unittest - -from absl.testing import absltest -from absl.testing import parameterized - -import jax -from jax import numpy as jnp -from jax._src import config -from jax._src import test_util as jtu -from jax._src import xla_bridge -from jax.experimental import host_callback as hcb - -import numpy as np - -try: - import tensorflow as tf -except ImportError: - tf = None - -config.parse_flags_with_absl() - - -def call_tf_no_ad(tf_fun: Callable, arg, *, result_shape): - """The simplest implementation of calling to TF, without AD support. - - We must use hcb.call because the TF invocation must happen outside the - JAX staged computation.""" - - def tf_to_numpy(t): - # Turn the Tensor to NumPy array without copying. - return np.asarray(memoryview(t)) if isinstance(t, tf.Tensor) else t - - return hcb.call(lambda arg: tf.nest.map_structure(tf_to_numpy, - tf_fun(arg)), - arg, result_shape=result_shape, - callback_flavor=hcb.CallbackFlavor.DEBUG) - - -def call_tf_simple_ad(tf_fun: Callable, arg, *, result_shape): - """Calls a TensorFlow function with simple support for reverse AD. - - Works only for 1st order AD and only for arguments and results being a single - ndarray (no pytrees). Functions whose name starts with "tf_" are TensorFlow - functions and must be called outside the JAX computation. - """ - - @jax.custom_vjp - def make_call(arg): - """We wrap it all in `make_call` so that we can attach custom VJP.""" - return call_tf_no_ad(tf_fun, arg, result_shape=result_shape) - - # Define the fwd and bwd custom_vjp functions - def make_call_vjp_fwd(arg): - # Return the primal argument as the residual. Use `make_call` for the - # primal computation to enable higher-order AD. - return make_call(arg), arg - - def make_call_vjp_bwd(res, ct_res): - arg = res # residual is the primal argument - - def tf_vjp_fun(arg_and_ct_res): - """Invoke TF gradient; used with hcb.call.""" - arg, ct_res = arg_and_ct_res - arg_var = tf.Variable(arg) - with tf.GradientTape(persistent=True) as tape: - res = tf_fun(arg_var) - - dres_darg = tape.gradient(res, sources=arg_var, - output_gradients=ct_res, - unconnected_gradients=tf.UnconnectedGradients.ZERO) - return dres_darg - - return (call_tf_simple_ad(tf_vjp_fun, (arg, ct_res), - result_shape=arg),) - - make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd) - return make_call(arg) - - -def call_tf_full_ad(tf_fun: Callable, arg, *, result_shape): - """Calls a TensorFlow function with support for reverse AD. - - Supports higher-order AD and pytree arguments. - """ - - @jax.custom_vjp - def make_call(arg): - """We wrap it all in `make_call` so that we can attach custom VJP.""" - return call_tf_no_ad(tf_fun, arg, result_shape=result_shape) - - # Define the fwd and bwd custom_vjp functions - def make_call_vjp_fwd(arg): - return make_call(arg), arg # Return the primal argument as the residual - - def make_call_vjp_bwd(res, ct_res): - arg = res # residual is the primal argument - - def tf_vjp_fun(arg_and_ct_res): - """Invoke TF gradient; used with hcb.call.""" - arg, ct_res = arg_and_ct_res - - def make_var(a): - return a if isinstance(a, tf.Variable) else tf.Variable(a) - - arg_var = tf.nest.map_structure(make_var, arg) - - with tf.GradientTape(persistent=True) as tape: - res = tf_fun(arg_var) - - tf.nest.assert_same_structure(res, ct_res) - accumulator = None # Accumulate argument cotangent. Same structure as "arg" - - def acc_ct(res_, ct_res_): - dres_darg = tape.gradient(res_, sources=arg_var, - unconnected_gradients=tf.UnconnectedGradients.ZERO) - tf.nest.assert_same_structure(dres_darg, arg) - scaled_dres_darg = tf.nest.map_structure(lambda d: d * ct_res_, dres_darg) - nonlocal accumulator - accumulator = (scaled_dres_darg if accumulator is None - else tf.nest.map_structure(lambda x, y: x + y, - accumulator, scaled_dres_darg)) - - tf.nest.map_structure(acc_ct, res, ct_res) - return accumulator - - return (call_tf_full_ad(tf_vjp_fun, (arg, ct_res), - result_shape=arg),) - - make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd) - return make_call(arg) - - -CALL_TF_IMPLEMENTATIONS = { - "none": call_tf_no_ad, - "simple": call_tf_simple_ad, - "full": call_tf_full_ad, -} - - -class CallToTFTest(jtu.JaxTestCase): - - def setUp(self): - if tf is None: - raise unittest.SkipTest("Test requires tensorflow") - if xla_bridge.using_pjrt_c_api(): - raise unittest.SkipTest("host_callback not implemented in PJRT C API") - super().setUp() - - def supported_only_in_legacy_mode(self): - if not hcb._HOST_CALLBACK_LEGACY.value: - self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False") - - @parameterized.named_parameters( - dict( - testcase_name=f"_{ad=}", - ad=ad) - for ad in CALL_TF_IMPLEMENTATIONS.keys()) - @jtu.ignore_warning(message="The host_callback APIs are deprecated", - category=DeprecationWarning) - def test_impl(self, ad="simple"): - self.supported_only_in_legacy_mode() - call_tf = CALL_TF_IMPLEMENTATIONS[ad] - - def f_jax(x): - return jnp.sin(x) - - def f_outside(x): - return call_tf(tf.math.sin, x, - result_shape=x) - - res = f_outside(3.) - self.assertAllClose(f_jax(3.), res) - self.assertAllClose(f_jax(3.), jax.jit(f_outside)(3.)) - - @parameterized.named_parameters( - dict( - testcase_name=f"_{ad=}", - ad=ad) - for ad in CALL_TF_IMPLEMENTATIONS.keys() - if ad != "none") - @jtu.ignore_warning(message="The host_callback APIs are deprecated", - category=DeprecationWarning) - def test_grad(self, ad="simple"): - self.supported_only_in_legacy_mode() - call_tf = CALL_TF_IMPLEMENTATIONS[ad] - - def f_jax(x): - return 3. * jnp.sin(2. * x) - - def f_outside(x): - return 3. * call_tf( - lambda x: tf.cast(tf.math.sin(x), tf.float32), 2. * x, - result_shape=jax.ShapeDtypeStruct((), np.float32)) - - x = np.float32(4.) - self.assertAllClose(f_jax(x), f_outside(x), - check_dtypes=False) - - grad_f = jax.grad(f_outside)(x) - self.assertAllClose(jax.grad(f_jax)(x), grad_f, - check_dtypes=False) - - @jtu.ignore_warning(message="The host_callback APIs are deprecated", - category=DeprecationWarning) - def test_grad_pytree(self): - self.supported_only_in_legacy_mode() - call_tf = call_tf_full_ad - - def f_jax(xy): - dict_ab = dict(a=2. * xy[0], b=xy[0] * xy[1]) - return 3. * dict_ab["a"] + 4. * dict_ab["b"] - - def f_outside(xy): - dict_ab = call_tf( - lambda xy: dict(a=tf.cast(2. * xy[0], np.float32), - b=tf.cast(xy[0] * xy[1], np.float32)), - xy, - result_shape=dict(a=jax.ShapeDtypeStruct((), np.float32), - b=jax.ShapeDtypeStruct((), np.float32))) - return 3. * dict_ab["a"] + 4. * dict_ab["b"] - - xy = (5., 6.) - self.assertAllClose(f_jax(xy), f_outside(xy), - check_dtypes=False) - res_jax = jax.grad(f_jax)(xy) - self.assertAllClose(res_jax, jax.grad(f_outside)(xy), - check_dtypes=False) - - @parameterized.named_parameters( - dict( - testcase_name=f"_degree=_{degree}", - degree=degree) - for degree in [1, 2, 3, 4]) - @jtu.ignore_warning(message="The host_callback APIs are deprecated", - category=DeprecationWarning) - def test_higher_order_grad(self, degree=4): - self.supported_only_in_legacy_mode() - call_tf = call_tf_full_ad - - def f_jax(x): - return 2. * x * x * x - - def f_outside(x): - return 2. * call_tf(lambda y: y * y * y, x, - result_shape=x) - - grad_jax = f_jax - grad_outside = f_outside - for i in range(degree): - grad_jax = jax.grad(grad_jax) - grad_outside = jax.grad(grad_outside) - - res_jax = grad_jax(5.) - self.assertAllClose(res_jax, grad_outside(5.)) - - -if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/infeed_test.py b/tests/infeed_test.py index ba47d2417f94..e378fe37a2f5 100644 --- a/tests/infeed_test.py +++ b/tests/infeed_test.py @@ -19,7 +19,6 @@ from absl.testing import absltest import jax from jax import lax, numpy as jnp -from jax.experimental import host_callback as hcb from jax._src import core from jax._src import xla_bridge from jax._src.lib import xla_client @@ -77,7 +76,6 @@ def f(x): @jax.numpy_rank_promotion("allow") # Test explicitly exercises implicit rank promotion. def testInfeedThenOutfeed(self): - hcb._deprecated_stop_outfeed_receiver() @jax.jit def f(x): @@ -99,7 +97,6 @@ def f(x): self.assertAllClose(out, y + np.float32(1)) def testInfeedThenOutfeedInALoop(self): - hcb._deprecated_stop_outfeed_receiver() def doubler(_, token): y, token = lax.infeed( diff --git a/tests/jax_to_ir_test.py b/tests/jax_to_ir_test.py index 91780e17800a..f600a08f5dc4 100644 --- a/tests/jax_to_ir_test.py +++ b/tests/jax_to_ir_test.py @@ -81,10 +81,6 @@ def test_parse_shape_str_invalid(self): jax_to_ir.parse_shape_str('foo[]') @unittest.skipIf(tf is None, 'TensorFlow not installed.') - @jtu.ignore_warning( - category=UserWarning, - message='jax2tf.convert with native_serialization=False is deprecated.' - ) def test_jax_to_tf_axpy(self): tf_proto, tf_text = jax_to_ir.jax_to_tf(axpy, [ ('y', jax_to_ir.parse_shape_str('f32[128]')), @@ -92,11 +88,6 @@ def test_jax_to_tf_axpy(self): ('x', jax_to_ir.parse_shape_str('f32[128,2]')), ]) - # Check that tf debug txt contains a broadcast, add, and multiply. - self.assertIn('BroadcastTo', tf_text) - self.assertIn('AddV2', tf_text) - self.assertIn('Mul', tf_text) - # Check that we can re-import our graphdef. gdef = tf.compat.v1.GraphDef() gdef.ParseFromString(tf_proto) diff --git a/tests/lax_metal_test.py b/tests/lax_metal_test.py index d3dada0d750a..5f1781c3be06 100644 --- a/tests/lax_metal_test.py +++ b/tests/lax_metal_test.py @@ -5590,7 +5590,6 @@ def test_isdtype(self, dtype, kind): self.assertEqual(jax_result, numpy_result) -from jaxlib import xla_client @unittest.skipIf(metal_plugin == None, "Tests require jax-metal plugin.") class ReportedIssuesTests(jtu.JaxTestCase): def dispatchOn(self, args, func, device=jax.devices('cpu')[0]): @@ -5601,11 +5600,16 @@ def dispatchOn(self, args, func, device=jax.devices('cpu')[0]): @staticmethod def compile_and_exec(module, args, run_on_cpu=False): - backend = jax.lib.xla_bridge.get_backend('METAL') - if (run_on_cpu): - backend = jax.lib.xla_bridge.get_backend('cpu') - executables = backend.compile(module) - return xla_client.execute_with_python_values(executables, args, backend) + from jax.extend.backend import get_backend + backend = get_backend('METAL') + if run_on_cpu: + backend = get_backend('cpu') + executable = backend.compile(module) + def put(arg): + return backend.buffer_from_pyval(arg, device=executable.local_devices()[0]) + arguments = [put(arg) for arg in args] + outputs = executable.execute(arguments) + return [np.asarray(x) for x in outputs] @staticmethod def jax_metal_supported(target_ver): diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index d58a5c2c3866..392af2688c1d 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -1252,7 +1252,7 @@ def _can_cast(from_, to): def _compatible_dtypes(op, dtype, inexact=False): - if op == UpdateOps.ADD: + if op == UpdateOps.ADD or op == UpdateOps.SUB: return [dtype] elif inexact: return [dt for dt in float_dtypes if _can_cast(dt, dtype)] @@ -1263,17 +1263,19 @@ def _compatible_dtypes(op, dtype, inexact=False): class UpdateOps(enum.Enum): UPDATE = 0 ADD = 1 - MUL = 2 - DIV = 3 - POW = 4 - MIN = 5 - MAX = 6 + SUB = 2 + MUL = 3 + DIV = 4 + POW = 5 + MIN = 6 + MAX = 7 def np_fn(op, indexer, x, y): x = x.copy() x[indexer] = { UpdateOps.UPDATE: lambda: y, UpdateOps.ADD: lambda: x[indexer] + y, + UpdateOps.SUB: lambda: x[indexer] - y, UpdateOps.MUL: lambda: x[indexer] * y, UpdateOps.DIV: jtu.ignore_warning(category=RuntimeWarning)( lambda: x[indexer] / y.astype(x.dtype)), @@ -1290,6 +1292,7 @@ def jax_fn(op, indexer, x, y, indices_are_sorted=False, return { UpdateOps.UPDATE: x.at[indexer].set, UpdateOps.ADD: x.at[indexer].add, + UpdateOps.SUB: x.at[indexer].subtract, UpdateOps.MUL: x.at[indexer].multiply, UpdateOps.DIV: x.at[indexer].divide, UpdateOps.POW: x.at[indexer].power, @@ -1420,7 +1423,7 @@ def testMixedAdvancedIndexing(self, name, shape, dtype, update_shape, for update_shape in _broadcastable_shapes(index_shape) ], [dict(op=op, dtype=dtype, update_dtype=update_dtype) - for op in [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE] + for op in [UpdateOps.ADD, UpdateOps.SUB, UpdateOps.MUL, UpdateOps.UPDATE] for dtype in float_dtypes for update_dtype in _compatible_dtypes(op, dtype, inexact=True) ], @@ -1447,8 +1450,9 @@ def testStaticIndexingGrads(self, name, shape, dtype, update_shape, ], [dict(op=op, dtype=dtype, update_dtype=update_dtype) for op in ( - [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE] if unique_indices - else [UpdateOps.ADD]) + [UpdateOps.ADD, UpdateOps.SUB, UpdateOps.MUL, UpdateOps.UPDATE] + if unique_indices + else [UpdateOps.ADD, UpdateOps.SUB]) for dtype in float_dtypes for update_dtype in _compatible_dtypes(op, dtype, inexact=True) ], diff --git a/tests/lax_numpy_operators_test.py b/tests/lax_numpy_operators_test.py index 45a780c9f721..744a99fb70e7 100644 --- a/tests/lax_numpy_operators_test.py +++ b/tests/lax_numpy_operators_test.py @@ -126,6 +126,8 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), op_record("nextafter", 2, [f for f in float_dtypes if f != jnp.bfloat16], all_shapes, jtu.rand_default, ["rev"], inexact=True, tolerance=0), + op_record("spacing", 1, float_dtypes, all_shapes, jtu.rand_default, ["rev"], + inexact=True, tolerance=0), op_record("not_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]), op_record("array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]), op_record("array_equiv", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]), @@ -701,6 +703,34 @@ def testI0Grad(self): dx = jax.grad(jax.numpy.i0)(0.0) self.assertArraysEqual(dx, 0.0) + @jtu.sample_product( + shape=all_shapes, + dtype=default_dtypes, + ) + def testSpacingIntegerInputs(self, shape, dtype): + rng = jtu.rand_int(self.rng(), low=-64, high=64) + args_maker = lambda: [rng(shape, dtype)] + computation_dtype = jnp.spacing(rng(shape, dtype)).dtype + np_func = lambda x: np.spacing(np.array(x).astype(computation_dtype)) + self._CheckAgainstNumpy(np_func, jnp.spacing, args_maker, check_dtypes=True, tol=0) + self._CompileAndCheck(jnp.spacing, args_maker, tol=0) + + @jtu.sample_product(dtype = float_dtypes) + @jtu.skip_on_devices("tpu") + def testSpacingSubnormals(self, dtype): + zero = np.array(0, dtype=dtype) + inf = np.array(np.inf, dtype=dtype) + x = [zero] + for i in range(5): + x.append(np.nextafter(x[-1], -inf)) # negative denormals + x = x[::-1] + for i in range(5): + x.append(np.nextafter(x[-1], inf)) # positive denormals + x = np.array(x, dtype=dtype) + args_maker = lambda: [x] + self._CheckAgainstNumpy(np.spacing, jnp.spacing, args_maker, check_dtypes=True, tol=0) + self._CompileAndCheck(jnp.spacing, args_maker, tol=0) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 6f8167df9c29..780419d0d81a 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -2424,6 +2424,17 @@ def testTrilIndicesFrom(self, shape, dtype, k): args_maker = lambda: [rng(shape, dtype), k] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + @jtu.sample_product( + n = [2, 3, 4], + k = [None, -1, 0, 1], + funcname = ['triu', 'tril'] + ) + def testMaskIndices(self, n, k, funcname): + kwds = {} if k is None else {'k': k} + jnp_result = jnp.mask_indices(n, getattr(jnp, funcname), **kwds) + np_result = np.mask_indices(n, getattr(np, funcname), **kwds) + self.assertArraysEqual(jnp_result, np_result, check_dtypes=False) + @jtu.sample_product( dtype=default_dtypes, a_shape=[(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2), (1, 2), (0, 2), (2, 3), (2, 2, 2), (2, 2, 2, 2)], @@ -3764,6 +3775,14 @@ def testMemoryView(self): np.array(bytearray(b'\x2a\xf3'), ndmin=2) ) + @jtu.sample_product(value=[False, 1, 1.0, np.int32(5), np.array(16)]) + def testIsScalar(self, value): + self.assertTrue(jnp.isscalar(value)) + + @jtu.sample_product(value=[None, [1], slice(4), (), np.array([0])]) + def testIsNotScalar(self, value): + self.assertFalse(jnp.isscalar(value)) + @jtu.sample_product(val=[1+1j, [1+1j], jnp.pi, np.arange(2)]) def testIsComplexObj(self, val): args_maker = lambda: [val] @@ -4615,11 +4634,16 @@ def args_maker(): return x, i jnp_op = lambda x, i: jnp.take_along_axis(x, i, axis=axis) + jnp_one_hot_op = lambda x, i: jnp.take_along_axis( + x, i, axis=axis, mode='one_hot' + ) if hasattr(np, "take_along_axis"): np_op = lambda x, i: np.take_along_axis(x, i, axis=axis) self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CheckAgainstNumpy(np_op, jnp_one_hot_op, args_maker) self._CompileAndCheck(jnp_op, args_maker) + self._CompileAndCheck(jnp_one_hot_op, args_maker) def testTakeAlongAxisWithUint8IndicesDoesNotOverflow(self): # https://github.com/jax-ml/jax/issues/5088 @@ -6142,6 +6166,15 @@ def testGradLogaddexp2Complex(self, shapes, dtype): tol = 3e-2 check_grads(jnp.logaddexp2, args, 1, ["fwd", "rev"], tol, tol) + @jtu.sample_product( + n=range(-4, 5), + dtype=[jnp.float32, jnp.float64], + ) + def testGradLdexp(self, n, dtype): + rng = jtu.rand_default(self.rng()) + x = rng((), dtype) + check_grads(lambda x: jnp.ldexp(x, n), (x,), 1) + class NumpySignaturesTest(jtu.JaxTestCase): @@ -6316,8 +6349,8 @@ def test_lax_numpy_docstrings(self): unimplemented = ['fromfile', 'fromiter'] aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2', - 'amax', 'amin', 'around', 'bitwise_right_shift', 'conj', 'degrees', - 'divide', 'mod', 'pow', 'radians', 'round_'] + 'amax', 'amin', 'around', 'bitwise_left_shift', 'bitwise_right_shift', + 'conj', 'degrees', 'divide', 'mod', 'pow', 'radians', 'round_'] skip_args_check = ['vsplit', 'hsplit', 'dsplit', 'array_split'] for name in dir(jnp): @@ -6381,14 +6414,13 @@ def wrapped(x, out=None): if jit: wrapped = jax.jit(wrapped) - wrapped = implements(orig, skip_params=['out'])(wrapped) + wrapped = implements(orig)(wrapped) doc = wrapped.__doc__ self.assertStartsWith(doc, "Example Docstring") self.assertIn("Original docstring below", doc) self.assertIn("Parameters", doc) self.assertIn("Returns", doc) - self.assertNotIn('out', doc) self.assertNotIn('other_arg', doc) self.assertNotIn('versionadded', doc) diff --git a/tests/lax_test.py b/tests/lax_test.py index c8f3ca797903..66739d9b520b 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -1061,19 +1061,19 @@ def testDot(self, lhs_shape, rhs_shape, lhs_dtype, rhs_dtype, precision): accumulation_type=np.float32, ), [np.float16]), ("F16_F16_F32", [np.float16]), - (lax.DotAlgorithm.Preset.DEFAULT, lax_test_util.float_dtypes), - (lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32, dtypes._float8_dtypes), - (lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM, dtypes._float8_dtypes), - (lax.DotAlgorithm.Preset.F16_F16_F16, [np.float16]), - (lax.DotAlgorithm.Preset.F16_F16_F32, [np.float16]), - (lax.DotAlgorithm.Preset.BF16_BF16_BF16, [dtypes.bfloat16]), - (lax.DotAlgorithm.Preset.BF16_BF16_F32, [dtypes.bfloat16]), - (lax.DotAlgorithm.Preset.BF16_BF16_F32_X3, [np.float32]), - (lax.DotAlgorithm.Preset.BF16_BF16_F32_X6, [np.float32]), - (lax.DotAlgorithm.Preset.TF32_TF32_F32, [np.float32]), - (lax.DotAlgorithm.Preset.TF32_TF32_F32_X3, [np.float32]), - (lax.DotAlgorithm.Preset.F32_F32_F32, [np.float32]), - (lax.DotAlgorithm.Preset.F64_F64_F64, [np.float64]), + (lax.DotAlgorithmPreset.DEFAULT, lax_test_util.float_dtypes), + (lax.DotAlgorithmPreset.ANY_F8_ANY_F8_F32, dtypes._float8_dtypes), + (lax.DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM, dtypes._float8_dtypes), + (lax.DotAlgorithmPreset.F16_F16_F16, [np.float16]), + (lax.DotAlgorithmPreset.F16_F16_F32, [np.float16]), + (lax.DotAlgorithmPreset.BF16_BF16_BF16, [dtypes.bfloat16]), + (lax.DotAlgorithmPreset.BF16_BF16_F32, [dtypes.bfloat16]), + (lax.DotAlgorithmPreset.BF16_BF16_F32_X3, [np.float32]), + (lax.DotAlgorithmPreset.BF16_BF16_F32_X6, [np.float32]), + (lax.DotAlgorithmPreset.TF32_TF32_F32, [np.float32]), + (lax.DotAlgorithmPreset.TF32_TF32_F32_X3, [np.float32]), + (lax.DotAlgorithmPreset.F32_F32_F32, [np.float32]), + (lax.DotAlgorithmPreset.F64_F64_F64, [np.float64]), ] for dtype in test_dtypes if jtu.dtypes.supported([dtype]) ]) @@ -1084,26 +1084,35 @@ def testDotAlgorithm(self, algorithm, dtype): if jaxlib_version <= (0, 4, 33): raise SkipTest( "The dot algorithm attribute is only supported for jaxlib >0.4.33.") + if jtu.test_device_matches(["cpu"]): + if algorithm not in { + lax.DotAlgorithmPreset.DEFAULT, + lax.DotAlgorithmPreset.F16_F16_F16, + lax.DotAlgorithmPreset.F32_F32_F32, + lax.DotAlgorithmPreset.F64_F64_F64, + }: + raise SkipTest( + f"The dot algorithm '{algorithm}' is not supported on CPU.") if jtu.test_device_matches(["gpu"]): # GPU algorithm support is a little spotty. It is checked in # xla/service/algorithm_util.cc and the logic is copied here. if algorithm in { - lax.DotAlgorithm.Preset.F16_F16_F32, - lax.DotAlgorithm.Preset.TF32_TF32_F32, - lax.DotAlgorithm.Preset.BF16_BF16_F32, - lax.DotAlgorithm.Preset.BF16_BF16_F32_X3, # Must have f32 input - lax.DotAlgorithm.Preset.BF16_BF16_F32_X6, # Must have f32 input + lax.DotAlgorithmPreset.F16_F16_F32, + lax.DotAlgorithmPreset.TF32_TF32_F32, + lax.DotAlgorithmPreset.BF16_BF16_F32, + lax.DotAlgorithmPreset.BF16_BF16_F32_X3, + lax.DotAlgorithmPreset.BF16_BF16_F32_X6, }: if not jtu.is_cuda_compute_capability_at_least("8.0"): raise SkipTest( f"The dot algorithm '{algorithm}' requires CUDA compute " "capability >= 8.0.") elif algorithm not in { - lax.DotAlgorithm.Preset.DEFAULT, - lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32, - lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM, - lax.DotAlgorithm.Preset.F32_F32_F32, - lax.DotAlgorithm.Preset.F64_F64_F64, + lax.DotAlgorithmPreset.DEFAULT, + lax.DotAlgorithmPreset.ANY_F8_ANY_F8_F32, + lax.DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM, + lax.DotAlgorithmPreset.F32_F32_F32, + lax.DotAlgorithmPreset.F64_F64_F64, }: raise SkipTest( f"The dot algorithm '{algorithm}' is not supported on GPU.") @@ -1111,12 +1120,8 @@ def testDotAlgorithm(self, algorithm, dtype): rhs_shape = (4, 3) rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] - self._CompileAndCheck(partial(lax.dot, algorithm=algorithm), args_maker) - # Check that accumulation type sets the output type - output = lax.dot(*args_maker(), algorithm=algorithm) - algorithm = lax_internal.canonicalize_dot_algorithm(algorithm) - expected_dtype = dtype if algorithm is None else algorithm.accumulation_type - self.assertEqual(output.dtype, expected_dtype) + self._CompileAndCheck(partial(lax.dot, precision=algorithm), args_maker) + self.assertEqual(lax.dot(*args_maker(), precision=algorithm).dtype, dtype) def testDotAlgorithmInvalidFloat8Type(self): if xla_bridge.using_pjrt_c_api(): @@ -1125,95 +1130,29 @@ def testDotAlgorithmInvalidFloat8Type(self): if jaxlib_version <= (0, 4, 33): raise SkipTest( "The dot algorithm attribute is only supported for jaxlib >0.4.33.") + if jtu.test_device_matches(["cpu"]): + raise SkipTest("Not supported on CPU.") lhs_shape = (3, 4) rhs_shape = (4, 3) rng = jtu.rand_default(self.rng()) lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, dtypes.float8_e4m3fn) with self.assertRaisesRegex(ValueError, "The dot algorithm"): - lax.dot(lhs, rhs, algorithm="ANY_F8_ANY_F8_F32") + lax.dot(lhs, rhs, precision="ANY_F8_ANY_F8_F32") - @parameterized.parameters([ - ({"precision": lax.Precision.HIGHEST}, "The dot_general precision must be None or DEFAULT"), - ({"preferred_element_type": np.float32}, "The preferred_element_type and algorithm arguments"), - ]) - def testDotAlgorithmInvalidParameters(self, kwargs, pattern): + def testDotAlgorithmCasting(self): if xla_bridge.using_pjrt_c_api(): raise SkipTest( "The dot algorithm attribute is not supported by PJRT C API.") if jaxlib_version <= (0, 4, 33): raise SkipTest( "The dot algorithm attribute is only supported for jaxlib >0.4.33.") + def fun(lhs, rhs): + return lax.dot(lhs, rhs, precision="F32_F32_F32") lhs_shape = (3, 4) rhs_shape = (4, 3) rng = jtu.rand_default(self.rng()) - lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, np.float32) - with self.assertRaisesRegex(ValueError, pattern): - lax.dot(lhs, rhs, algorithm="F32_F32_F32", **kwargs) - - def testDotAlgorithmTransposeRequired(self): - if xla_bridge.using_pjrt_c_api(): - raise SkipTest( - "The dot algorithm attribute is not supported by PJRT C API.") - if jaxlib_version <= (0, 4, 33): - raise SkipTest( - "The dot algorithm attribute is only supported for jaxlib >0.4.33.") - lhs_shape = (3, 4) - rhs_shape = (4, 3) - rng = jtu.rand_default(self.rng()) - lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, np.float32) - fun = partial(lax.dot, algorithm="F32_F32_F32") - out = fun(lhs, rhs) - _, vjp_fun = jax.vjp(fun, lhs, rhs) - with self.assertRaisesRegex( - ValueError, "When a dot_general algorithm is specified"): - vjp_fun(out) - - @parameterized.parameters([ - ("F32_F32_F32", "F16_F16_F32"), - ("F32_F32_F32", ("F16_F16_F32", "F64_F64_F64")), - ]) - def testDotAlgorithmTranspose(self, algorithm, transpose_algorithm): - if xla_bridge.using_pjrt_c_api(): - raise SkipTest( - "The dot algorithm attribute is not supported by PJRT C API.") - if jaxlib_version <= (0, 4, 33): - raise SkipTest( - "The dot algorithm attribute is only supported for jaxlib >0.4.33.") - def fun(x, y): - return lax.dot(x, y, algorithm=algorithm, - transpose_algorithm=transpose_algorithm) - - algorithm_ = lax_internal.canonicalize_dot_algorithm(algorithm) - lhs_alg, rhs_alg = lax_internal.canonicalize_dot_transpose_algorithm( - transpose_algorithm) - - lhs_shape = (3, 4) - rhs_shape = (4, 3) - rng = jtu.rand_default(self.rng()) - lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, np.float32) - out = fun(lhs, rhs) - - def check_transpose_algorithm(f, arg, alg, trans_alg, trans_trans_alg): - fun_trans = jax.linear_transpose(f, arg) - jaxpr = jax.make_jaxpr(fun_trans)(out) - eqn = next(filter(lambda eqn: eqn.primitive == lax.dot_general_p, jaxpr.eqns)) - self.assertEqual(eqn.params["algorithm"], alg) - self.assertEqual(eqn.params["transpose_algorithm"], trans_alg) - - fun_ = jax.linear_transpose(lambda x: fun_trans(x)[0], out) - jaxpr_ = jax.make_jaxpr(fun_)(arg) - eqn = next(filter(lambda eqn: eqn.primitive == lax.dot_general_p, jaxpr_.eqns)) - self.assertEqual(eqn.params["algorithm"], algorithm_) - - # Note that transposing the RHS of a dot_general introduce extra - # transposes on the input and output, so we don't actually end up with - # the same `transpose_algorithm` parameter after 2 transposes. - self.assertEqual(eqn.params["transpose_algorithm"], trans_trans_alg) - - check_transpose_algorithm(partial(fun, y=rhs), lhs, lhs_alg, - (algorithm_, rhs_alg), (lhs_alg, rhs_alg)) - check_transpose_algorithm(partial(fun, lhs), rhs, rhs_alg, - (algorithm_, lhs_alg), (rhs_alg, lhs_alg)) + lhs, rhs = rng(lhs_shape, np.float16), rng(rhs_shape, np.float16) + self.assertEqual(fun(lhs, rhs).dtype, np.float16) @jtu.sample_product( [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape) @@ -2686,6 +2625,18 @@ def testIndexTake(self, shape, dtype, idxs, axes): ((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), + lax.GatherDimensionNumbers( + offset_dims=(), collapsed_slice_dims=(1,), + start_index_map=(1,), operand_batching_dims=(0,), + start_indices_batching_dims=(0,)), + (1, 1)), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + lax.GatherDimensionNumbers( + offset_dims=(2,), collapsed_slice_dims=(), + start_index_map=(2,), operand_batching_dims=(0, 1), + start_indices_batching_dims=(1, 0)), + (1, 1, 3)) ]], dtype=lax_test_util.all_dtypes, ) @@ -2703,63 +2654,196 @@ def testGather(self, shape, dtype, idxs, dnums, slice_sizes): @parameterized.named_parameters( {"testcase_name": f"_{testcase_name}", "operand_shape": operand_shape, "indices_shape": indices_shape, - "dimension_numbers": lax.GatherDimensionNumbers( - offset_dims=offset_dims, - collapsed_slice_dims=collapsed_slice_dims, - start_index_map=start_index_map), + "dimension_numbers": dimension_numbers, "slice_sizes": slice_sizes, "msg": msg} - for (testcase_name, operand_shape, indices_shape, offset_dims, - collapsed_slice_dims, start_index_map, slice_sizes, msg) in [ + for (testcase_name, operand_shape, indices_shape, dimension_numbers, + slice_sizes, msg) in [ ("NonAscendingWindowIndices", (10, 9, 8, 7, 6), (5, 4, 3, 2, 1), - (4, 5, 6, 8, 7), (), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), - "offset_dims in gather op must be sorted"), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 8, 7), collapsed_slice_dims=(), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), "offset_dims in gather op must be sorted"), ("RepeatedWindowIndices", (10, 9, 8, 7, 6), (5, 4, 3, 2, 1), - (4, 5, 6, 7, 7), (), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), - "offset_dims in gather op must not repeat"), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 7), collapsed_slice_dims=(), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), "offset_dims in gather op must not repeat"), ("WindowIndexOutOfBounds", (10, 9, 8, 7, 6), (5, 4, 3, 2, 1), - (4, 5, 100, 101, 102), (), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), - "Offset dimension 2 in gather op is out of bounds"), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 100, 101, 102), collapsed_slice_dims=(), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), "Offset dimension 2 in gather op is out of bounds"), ("WindowIndexBarelyOutOfBounds", (10, 9, 8, 7, 6), (5, 4, 3, 2, 1), - (4, 5, 6, 7, 9), (), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), - "Offset dimension 4 in gather op is out of bounds"), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 9), collapsed_slice_dims=(), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), "Offset dimension 4 in gather op is out of bounds"), ("MismatchingElidedWindowDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7, 8), (4,), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(4,), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), + ("All components of the offset index in a gather op must either be a " + "offset dimension or explicitly collapsed/batching")), + ("MismatchingElidedWindowDimsV2", (10, 9, 8, 7, 6, 5), (10, 4, 3, 2, 4), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(4,), + start_index_map=(1, 2, 3, 4), operand_batching_dims=(0,), + start_indices_batching_dims=(0,)), + (10, 9, 8, 7, 6, 5), ("All components of the offset index in a gather op must either be a " - "offset dimension or explicitly collapsed")), + "offset dimension or explicitly collapsed/batching")), ("OutOfBoundsWindowToInputMapping", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7, 8), (0, 1, 2, 3, 19), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(0, 1, 2, 3, 19), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), "Invalid collapsed_slice_dims set in gather op; valid range is"), ("RepeatedWindowToInputMapping", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7, 8), (0, 1, 2, 3, 3), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), - "collapsed_slice_dims in gather op must not repeat"), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(0, 1, 2, 3, 3), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), "collapsed_slice_dims in gather op must not repeat"), ("MismatchingGatherToInputMapping", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7, 8), (), (0, 1, 2, 3), (10, 9, 8, 7, 6), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(), + start_index_map=(0, 1, 2, 3)), + (10, 9, 8, 7, 6), ("Gather op has 4 elements in start_index_map and the bound of " "dimension index_vector_dim=4 of indices is 5. These two " "numbers must be equal.")), ("OutOfBoundsGatherToInputMapping", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7, 8), (), (0, 1, 2, 3, 7), (10, 9, 8, 7, 6), - "Invalid start_index_map"), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(), + start_index_map=(0, 1, 2, 3, 7)), + (10, 9, 8, 7, 6), "Invalid start_index_map"), ("RepeatedGatherToInputMapping", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7, 8), (), (0, 1, 2, 3, 3), (10, 9, 8, 7, 6), - "start_index_map in gather op must not repeat"), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(), + start_index_map=(0, 1, 2, 3, 3)), + (10, 9, 8, 7, 6), "start_index_map in gather op must not repeat"), ("NonAscendingElidedWindowDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7, 8), (2, 1), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(2, 1), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), "collapsed_slice_dims in gather op must be sorted"), ("WindowBoundsTooLarge", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7), (2,), (0, 1, 2, 3, 4), (10, 9, 8, 100, 6), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7), collapsed_slice_dims=(2,), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 100, 6), "Slice size at index 3 in gather op is out of range"), ("MismatchingNumberOfWindowBounds", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7), (), (0, 1, 2, 3, 4), (10, 9, 8, 7), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7), collapsed_slice_dims=(), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7), "Gather op must have one slice size for every input dimension"), ("WindowBoundsNot1ForElidedDim", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7), (1,), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7), collapsed_slice_dims=(1,), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), ("Gather op can only collapse slice dims with bound 1, but bound " - "is 9 for index 1 at position 0.")) + "is 9 for index 1 at position 0.")), + ("RepeatedOperandBatchingDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 3), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), + start_index_map=(0, 1, 4), operand_batching_dims=(2, 3, 3)), + (10, 9, 8, 7, 6), + "operand_batching_dims in gather op must not repeat"), + ("NonAscendingOperandBatchingDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 3), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), + start_index_map=(0, 1, 4), operand_batching_dims=(3, 2)), + (10, 9, 8, 7, 6), + "operand_batching_dims in gather op must be sorted"), + ("OutOfBoundsOperandBatchingDims", (10, 9, 8, 7, 6), + (5, 4, 3, 2, 5), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), + start_index_map=(0, 1, 2, 3, 4), + operand_batching_dims=(0, 10)), + (10, 9, 8, 7, 6), + "Invalid operand_batching_dims set in gather op; valid range is"), + ("NonDisjointCollapsedAndBatchingDims", (10, 9, 8, 7, 6), + (5, 4, 3, 2, 3), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1, 2), + start_index_map=(0, 1, 4), operand_batching_dims=(2, 3)), + (10, 9, 8, 7, 6), + ("collapsed_slice_dims and operand_batching_dims in gather op must be " + "disjoint")), + ("NonDisjointStartIndexMapAndBatchingDims", (10, 9, 8, 7, 6), + (5, 4, 3, 2, 4), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), + start_index_map=(0, 1, 2, 4), operand_batching_dims=(2, 3)), + (10, 9, 8, 7, 6), + ("start_index_map and operand_batching_dims in gather op must be " + "disjoint")), + ("WindowBoundsNot1ForBatchingDim", (10, 9, 8, 7, 6), (9, 4, 3, 2, 4), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7), collapsed_slice_dims=(), + start_index_map=(0, 2, 3, 4), operand_batching_dims=(1,), + start_indices_batching_dims=(0,)), + (10, 9, 8, 7, 6), + ("Gather op can only have operand batching dims with bound 0/1, but " + "bound is 9 for index 1 at position 0.")), + ("RepeatedStartIndicesBatchingDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), + start_index_map=(0, 1, 2, 3, 4), + start_indices_batching_dims=(0, 1, 0)), + (10, 9, 8, 7, 6), + "start_indices_batching_dims in gather op must not repeat"), + ("OutOfBoundsStartIndicesBatchingDims", (10, 9, 8, 7, 6), + (5, 4, 3, 2, 5), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), + start_index_map=(0, 1, 2, 3, 4), + start_indices_batching_dims=(0, 5)), + (10, 9, 8, 7, 6), + "Invalid start_indices_batching_dims set in gather op; valid range"), + ("IndexVectorDimInStartIndicesBatchingDims", (10, 9, 8, 7, 6), + (5, 4, 3, 2, 5), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), + start_index_map=(0, 1, 2, 3, 4), + start_indices_batching_dims=(0, 4)), + (10, 9, 8, 7, 6), + ("Gather op cannot have the index vector dimension as a batching " + "dimension")), + ("MismatchingNumberOfBatchingDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 4), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(1, 2), + start_index_map=(1, 2, 3, 4), operand_batching_dims=(0,), + start_indices_batching_dims=(0, 1)), + (10, 9, 8, 7, 6), + ("Gather op requires equal numbers of operand_batching_dims and " + "start_indices_batching_dims")), + ("MismatchingBatchingDimSizes", (10, 9, 8, 7, 6), (10, 9, 3, 2, 3), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(2, 3, 4), + start_index_map=(2, 3, 4), operand_batching_dims=(0, 1), + start_indices_batching_dims=(1, 0)), + (10, 9, 8, 7, 6), + ("Gather op requires operand batching dimensions and indices batching " + "dimensions to have the same shape")) ] ) def testGatherShapeCheckingRule(self, operand_shape, indices_shape, dimension_numbers, slice_sizes, msg): + """ + + Args: + operand_shape: + indices_shape: + dimension_numbers: + slice_sizes: + msg: + """ operand = np.ones(operand_shape, dtype=np.int32) indices = np.ones(indices_shape, dtype=np.int32) @@ -2776,20 +2860,31 @@ def testGatherShapeCheckingRule(self, operand_shape, indices_shape, ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), - ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( + ((10, 5), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), + scatter_indices_batching_dims=(0,))), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + (3, 2, 3), lax.ScatterDimensionNumbers( + update_window_dims=(2,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0))) ]], dtype=lax_test_util.inexact_dtypes, mode=["clip", "fill", None], + op=[lax.scatter_add, lax.scatter_sub], ) - def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, mode): + def testScatterAddSub(self, arg_shape, dtype, idxs, update_shape, dnums, mode, op): rng = jtu.rand_default(self.rng()) rng_idx = jtu.rand_int(self.rng(), high=max(arg_shape)) rand_idxs = lambda: rng_idx(idxs.shape, idxs.dtype) args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(), rng(update_shape, dtype)] - fun = partial(lax.scatter_add, dimension_numbers=dnums, mode=mode) + fun = partial(op, dimension_numbers=dnums, mode=mode) self._CompileAndCheck(fun, args_maker) @jtu.sample_product( @@ -2802,9 +2897,19 @@ def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, mode): ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), - ((10, 5,), np.array([[0], [2], [1]], dtype=np.uint64), (3, 3), lax.ScatterDimensionNumbers( + ((10, 5), np.array([[0], [2], [1]], dtype=np.uint64), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), + scatter_indices_batching_dims=(0,))), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + (3, 2, 3), lax.ScatterDimensionNumbers( + update_window_dims=(2,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0))) ]], dtype=lax_test_util.float_dtypes, ) @@ -2827,9 +2932,19 @@ def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums): ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), - ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( + ((10, 5), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), + scatter_indices_batching_dims=(0,))), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + (3, 2, 3), lax.ScatterDimensionNumbers( + update_window_dims=(2,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0))) ]], dtype=lax_test_util.float_dtypes, ) @@ -2851,9 +2966,19 @@ def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums): ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), - ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( + ((10, 5), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), + scatter_indices_batching_dims=(0,))), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + (3, 2, 3), lax.ScatterDimensionNumbers( + update_window_dims=(2,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0))) ]], dtype=lax_test_util.float_dtypes, ) @@ -2875,9 +3000,19 @@ def testScatterApply(self, arg_shape, dtype, idxs, update_shape, dnums): ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), - ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( + ((10, 5), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), + scatter_indices_batching_dims=(0,))), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + (3, 2, 3), lax.ScatterDimensionNumbers( + update_window_dims=(2,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0))) ]], dtype=lax_test_util.float_dtypes, ) @@ -2895,84 +3030,207 @@ def testScatter(self, arg_shape, dtype, idxs, update_shape, dnums): # variations to account for the implicit setting of index_vector_dim in JAX. @parameterized.named_parameters( {"testcase_name": f"_{testcase_name}", "operand_shape": operand_shape, - "indices": indices, "update_shape": update_shape, - "dimension_numbers": lax.ScatterDimensionNumbers( - update_window_dims=update_window_dims, - inserted_window_dims=inserted_window_dims, - scatter_dims_to_operand_dims=scatter_dims_to_operand_dims), + "indices_shape": indices_shape, "update_shape": update_shape, + "dimension_numbers": dimension_numbers, "msg": msg} - for (testcase_name, operand_shape, indices, update_shape, - update_window_dims, inserted_window_dims, - scatter_dims_to_operand_dims, msg) in [ - ("ScatterWithUpdatesBiggerThanInput", (64, 48), np.zeros((32, 1)), - (65, 32), (0,), (1,), (1,), "Bounds of the window dimensions"), - ("ScatterWithUpdatesBiggerThanInputV2", (64, 48), - np.zeros((32, 1)), (32, 49), (1,), (0,), (1,), + for (testcase_name, operand_shape, indices_shape, update_shape, + dimension_numbers, msg) in [ + ("ScatterWithUpdatesBiggerThanInput", (64, 48), (32, 1), (65, 32), + lax.ScatterDimensionNumbers( + update_window_dims=(0,), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,)), "Bounds of the window dimensions"), - ("ScatterWithUpdatesNotMatchingIndices", (64, 48), - np.zeros((32, 1)), (64, 31), (0,), (1,), (1,), + ("ScatterWithUpdatesBiggerThanInputV2", (64, 48), (32, 1), + (32, 49), lax.ScatterDimensionNumbers( + update_window_dims=(1,), inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(1,)), + "Bounds of the window dimensions"), + ("ScatterWithUpdatesNotMatchingIndices", (64, 48), (32, 1), + (64, 31), lax.ScatterDimensionNumbers( + update_window_dims=(1,), inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(1,)), "Bounds of the scatter dimensions"), - ("ScatterWithUpdatesNotMatchingIndicesV2", (64, 48), - np.zeros((32, 1)), (31, 48), (1,), (0,), (1,), + ("ScatterWithUpdatesNotMatchingIndicesV2", (64, 48), (32, 1), + (31, 48), lax.ScatterDimensionNumbers( + update_window_dims=(1,), inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(1,)), "Bounds of the scatter dimensions"), ("ScatterNdWithUpdatesBiggerThanInput", (64, 48), - np.zeros((10, 9, 8, 7, 1)), (10, 9, 8, 7, 65), (4,), (1,), - (0,), "Bounds of the window dimensions"), + (10, 9, 8, 7, 1), (10, 9, 8, 7, 65), + lax.ScatterDimensionNumbers( + update_window_dims=(4,), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,)), + "Bounds of the window dimensions"), ("ScatterNdWithUpdatesNotMatchingIndices", (64, 48), - np.zeros((10, 9, 8, 7, 1)), (9, 9, 8, 7, 64), (4,), (1,), (0,), + (10, 9, 8, 7, 1), (9, 9, 8, 7, 64), + lax.ScatterDimensionNumbers( + update_window_dims=(4,), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(0,)), "Bounds of the scatter dimensions"), - ("InvalidUpdates", (50, 49, 48, 47, 46), - np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4, 1), - (4, 5, 6), (1, 2), (0, 1, 2, 3, 4), + ("InvalidUpdates", (50, 49, 48, 47, 46), (10, 9, 8, 7, 5), + (10, 9, 8, 7, 3, 2, 4, 1), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1, 2), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), "Updates tensor must be of rank 7; got 8."), - ("NonAscendingUpdateWindowDims", (6, 5, 4, 3, 2), - np.zeros((5, 4, 3, 2, 1)), (10, 9, 8, 7, 6, 5, 4, 3, 2), - (4, 5, 6, 8, 7), (), (0, 1, 2, 3, 4), + ("NonAscendingUpdateWindowDims", (6, 5, 4, 3, 2), (5, 4, 3, 2, 1), + (10, 9, 8, 7, 6, 5, 4, 3, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6, 8, 7), inserted_window_dims=(), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), "update_window_dims in scatter op must be sorted"), - ("RepeatedUpdateWindowDims", (6, 5, 4, 3, 2), - np.zeros((5, 4, 3, 2, 1)), (10, 9, 8, 7, 6, 5, 4, 3, 2), - (4, 5, 6, 7, 7), (), (0, 1, 2, 3, 4), + ("RepeatedUpdateWindowDims", (6, 5, 4, 3, 2), (5, 4, 3, 2, 1), + (10, 9, 8, 7, 6, 5, 4, 3, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6, 7, 7), inserted_window_dims=(), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), "update_window_dims in scatter op must not repeat"), - ("OutOfBoundsUpdateWindowDims", (6, 5, 4, 3, 2), - np.zeros((5, 4, 3, 2, 1)), (10, 9, 8, 7, 6, 5, 4, 3, 2), - (4, 5, 6, 7, 9), (), (0, 1, 2, 3, 4), + ("OutOfBoundsUpdateWindowDims", (6, 5, 4, 3, 2), (5, 4, 3, 2, 1), + (10, 9, 8, 7, 6, 5, 4, 3, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6, 7, 9), inserted_window_dims=(), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), "Invalid update_window_dims set in scatter op"), ("NonAscendingInsertedWindowDims", (50, 49, 48, 47, 46), - np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), - (4, 5, 6), (2, 1), (0, 1, 2, 3, 4), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(2, 1), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), "inserted_window_dims in scatter op must be sorted"), ("RepeatedInsertedWindowDims", (50, 49, 48, 47, 46), - np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), - (4, 5, 6), (1, 1), (0, 1, 2, 3, 4), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1, 1), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), "inserted_window_dims in scatter op must not repeat"), ("OutOfBoundsInsertedWindowDims", (50, 49, 48, 47, 46), - np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), - (4, 5, 6), (1, 5), (0, 1, 2, 3, 4), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1, 5), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), "Invalid inserted_window_dims set in scatter op"), ("MismatchingScatterDimsToOperandDims", (50, 49, 48, 47, 46), - np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), - (4, 5, 6), (1, 2), (0, 1, 2, 3), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1, 2), + scatter_dims_to_operand_dims=(0, 1, 2, 3)), ("Scatter op has 4 elements in scatter_dims_to_operand_dims and " "the bound of dimension index_vector_dim=4 of indices " "is 5. These two numbers must be equal")), ("OutOfBoundsScatterDimsToOperandDims", (50, 49, 48, 47, 46), - np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), - (4, 5, 6), (1, 2), (0, 1, 2, 3, 10), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1, 2), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 10)), "Invalid scatter_dims_to_operand_dims mapping"), ("RepeatedValuesInScatterDimsToOperandDims", (50, 49, 48, 47, 46), - np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), - (4, 5, 6), (1, 2), (0, 1, 2, 2, 3), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1, 2), + scatter_dims_to_operand_dims=(0, 1, 2, 2, 3)), "scatter_dims_to_operand_dims in scatter op must not repeat"), ("InsufficientWindowDims", (50, 49, 48, 47, 46), - np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), - (4, 5, 6), (1,), (0, 1, 2, 3), + (10, 9, 8, 7, 4), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(0, 1, 2, 3)), ("Scatter op has window of size 4; doesn't match operand of " - "rank 5.")) + "rank 5.")), + ("InsufficientWindowDimsV2", (10, 49, 48, 47, 46, 45), + (10, 9, 8, 7, 3), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1, 2, 3), + operand_batching_dims=(0,), + scatter_indices_batching_dims=(0,)), + ("Scatter op has window of size 5; doesn't match operand of " + "rank 6.")), + ("RepeatedOperandBatchingDims", (50, 49, 48, 47, 46), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1, 4), + operand_batching_dims=(2, 3, 3)), + "operand_batching_dims in scatter op must not repeat"), + ("NonAscendingOperandBatchingDims", (50, 49, 48, 47, 46), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1, 4), + operand_batching_dims=(3, 2)), + "operand_batching_dims in scatter op must be sorted"), + ("OutOfBoundsOperandBatchingDims", (50, 49, 48, 47, 46), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4), + operand_batching_dims=(0, 10)), + ("Invalid operand_batching_dims set in scatter op; valid range " + "is")), + ("NonDisjointCollapsedAndBatchingDims", (50, 49, 48, 47, 46, 45), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1, 4), + operand_batching_dims=(1, 2)), + ("inserted_window_dims and operand_batching_dims in scatter op " + "must be disjoint")), + ("NonDisjointScatterDimsToOperandDimsAndBatchingDims", + (50, 49, 48, 47, 46), (10, 9, 8, 7, 5), + (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1, 2, 4), + operand_batching_dims=(2, 3)), + ("scatter_dims_to_operand_dims and operand_batching_dims in " + "scatter op must be disjoint")), + ("RepeatedScatterIndicesBatchingDims", (50, 49, 48, 47, 46), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4), + scatter_indices_batching_dims=(0, 1, 0)), + "scatter_indices_batching_dims in scatter op must not repeat"), + ("OutOfBoundsScatterIndicesBatchingDims", (50, 49, 48, 47, 46), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4), + scatter_indices_batching_dims=(0, 5)), + ("Invalid scatter_indices_batching_dims set in scatter op; " + "valid range")), + ("IndexVectorDimInScatterIndicesBatchingDims", + (50, 49, 48, 47, 46), (10, 9, 8, 7, 5), + (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4), + scatter_indices_batching_dims=(0, 4)), + ("Scatter op cannot have the index vector dimension as a " + "batching dimension")), + ("MismatchingNumberOfBatchingDims", (50, 49, 48, 47, 46, 45), + (10, 9, 8, 7, 4), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1, 2), + scatter_dims_to_operand_dims=(1, 2, 3, 4), + operand_batching_dims=(0,), + scatter_indices_batching_dims=(0, 1)), + ("Scatter op requires equal numbers of operand_batching_dims " + "and scatter_indices_batching_dims")), + ("MismatchingBatchingDimSizes", (10, 9, 48, 47, 46, 45), + (10, 9, 8, 7, 2), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(2,), + scatter_dims_to_operand_dims=(2, 3), + operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0)), + ("Scatter op requires operand batching dimensions and indices " + "batching dimensions to have the same shape")) ] ) - def testScatterShapeCheckingRule(self, operand_shape, indices, + def testScatterShapeCheckingRule(self, operand_shape, indices_shape, update_shape, dimension_numbers, msg): - + indices = np.zeros(indices_shape, dtype=np.int32) def f(x, y): operand = lax.broadcast(x, operand_shape) updates = lax.broadcast(y, update_shape) @@ -3974,8 +4232,7 @@ def _testOnComplexPlaneWorker(self, name, dtype, kind): size_im = 11 atol = None - if (name in {"arccos", "arcsin", "arcsinh", "arccosh"} - or name in {"arctan", "arctanh"} and jax._src.lib.version > (0, 4, 31)): + if name in {"arccos", "arcsin", "arcsinh", "arccosh", "arctan", "arctanh"}: # TODO(pearu): eliminate this if-block when a fix to mpmath#787 # becomes available extra_prec_multiplier = 20 @@ -4131,16 +4388,6 @@ def regions_with_inaccuracies_keep(*to_keep): elif name == 'arccos': regions_with_inaccuracies_keep('q4.imag', 'ninf', 'pinf', 'ninfj', 'pinfj.real') - elif name == 'arctan' and jax._src.lib.version <= (0, 4, 31): - if dtype == np.complex64: - regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', - 'mq1.imag', 'mq2.imag', 'mq3.imag', 'mq4.imag', 'mnegj.real', 'mnegj.imag', 'mposj.imag') - if dtype == np.complex128: - regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mnegj.real') - - elif name == 'arctanh' and jax._src.lib.version <= (0, 4, 31): - regions_with_inaccuracies_keep('pos.imag', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos.imag') - elif name in {'cos', 'sin'}: regions_with_inaccuracies_keep('ninf.imag', 'pinf.imag') diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py index 37a0011e7bd0..83d4d657751b 100644 --- a/tests/lax_vmap_test.py +++ b/tests/lax_vmap_test.py @@ -32,7 +32,6 @@ from jax._src import test_util as jtu from jax._src.internal_test_util import lax_test_util from jax._src.lax import windowed_reductions as lax_windowed_reductions -from jax._src.lib import xla_client from jax._src.util import safe_map, safe_zip jax.config.parse_flags_with_absl() @@ -546,7 +545,7 @@ def testFft(self, fft_ndims, shape, bdims): ndims = len(shape) axes = range(ndims - fft_ndims, ndims) fft_lengths = tuple(shape[axis] for axis in axes) - op = lambda x: lax.fft(x, xla_client.FftType.FFT, fft_lengths) + op = lambda x: lax.fft(x, lax.FftType.FFT, fft_lengths) self._CheckBatching(op, 5, bdims, [shape], [np.complex64], rng, rtol=1e-5) @@ -566,6 +565,18 @@ def testFft(self, fft_ndims, shape, bdims): ((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), + lax.GatherDimensionNumbers( + offset_dims=(), collapsed_slice_dims=(1,), + start_index_map=(1,), operand_batching_dims=(0,), + start_indices_batching_dims=(0,)), + (1, 1)), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + lax.GatherDimensionNumbers( + offset_dims=(2,), collapsed_slice_dims=(), + start_index_map=(2,), operand_batching_dims=(0, 1), + start_indices_batching_dims=(1, 0)), + (1, 1, 3)) ] for bdims in lax_test_util.all_bdims(shape, idxs.shape)], dtype=lax_test_util.all_dtypes @@ -590,6 +601,16 @@ def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims): ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), + scatter_indices_batching_dims=(0,))), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + (3, 2, 3), lax.ScatterDimensionNumbers( + update_window_dims=(2,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0))) ] for bdims in lax_test_util.all_bdims(arg_shape, idxs.shape, update_shape)], dtype=lax_test_util.float_dtypes @@ -613,6 +634,16 @@ def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, bdims): ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), + scatter_indices_batching_dims=(0,))), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + (3, 2, 3), lax.ScatterDimensionNumbers( + update_window_dims=(2,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0))) ] for bdims in lax_test_util.all_bdims(arg_shape, idxs.shape)], dtype=lax_test_util.float_dtypes, diff --git a/tests/layout_test.py b/tests/layout_test.py index 1d18179ccfee..600be653da3b 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -25,7 +25,6 @@ from jax._src.layout import Layout, DeviceLocalLayout as DLL from jax._src import test_util as jtu from jax._src.util import safe_zip -from jax._src.lib import xla_extension_version config.parse_flags_with_absl() @@ -46,9 +45,6 @@ def setUp(self): super().setUp() def test_auto_layout(self): - # Remove this condition when xla_extension_version >= 285 - if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285: - self.skipTest("Requires xla_extension_version >= 285 for GPU backend.") mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape1 = (128, 128) shape2 = (128, 128) @@ -114,9 +110,6 @@ def init(x, y): self.assertArraysEqual(apply_out[1], (np_inp2 * 2).T) def test_default_layout(self): - # Remove this condition when xla_extension_version >= 285 - if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285: - self.skipTest("Requires xla_extension_version >= 285 for GPU backend.") mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (4, 4, 2) np_inp = np.arange(math.prod(shape)).reshape(shape) @@ -156,9 +149,6 @@ def f(x): out_shardings=DLL.AUTO).lower(sds).compile() def test_in_layouts_out_layouts(self): - # Remove this condition when xla_extension_version >= 285 - if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285: - self.skipTest("Requires xla_extension_version >= 285 for GPU backend.") mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (8, 8) np_inp = np.arange(math.prod(shape)).reshape(shape) @@ -183,9 +173,6 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x'))) def test_sharding_and_layouts(self): - # Remove this condition when xla_extension_version >= 285 - if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285: - self.skipTest("Requires xla_extension_version >= 285 for GPU backend.") mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (4, 8) np_inp = np.arange(math.prod(shape)).reshape(shape) @@ -477,9 +464,6 @@ def test_incompatible_aval_error_device_put(self): jax.device_put(inp, l) def test_concrete_layout_in_shardings(self): - # Remove this condition when xla_extension_version >= 285 - if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285: - self.skipTest("Requires xla_extension_version >= 285 for GPU backend.") mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) shape = (16, 128) @@ -600,6 +584,23 @@ def f(x): f.lower(sds).compile()(arr) self.assertFalse(arr.is_deleted()) + def test_donation_error_on_auto(self): + @partial(jax.jit, donate_argnums=0, in_shardings=Layout(DLL.AUTO)) + def f(x): + return x * 2 + + with self.assertRaisesRegex( + ValueError, ".*Did you mean to set the.*output layout.*AUTO.*"): + f(jnp.arange(8)) + + @partial(jax.jit, donate_argnums=0, out_shardings=Layout(DLL.AUTO)) + def g(x): + return x * 2 + + with self.assertRaisesRegex( + ValueError, ".*Did you mean to set the.*input layout.*AUTO.*"): + g(jnp.arange(8)) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index e52582eb7526..5ace4b5ecf18 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -1007,7 +1007,7 @@ def testQrInvalidDtypeCPU(self, shape=(5, 6), dtype=np.float16): if jtu.test_device_matches(['cpu']): err, msg = NotImplementedError, "Unsupported dtype float16" else: - err, msg = ValueError, r"Unsupported dtype dtype\('float16'\)" + err, msg = Exception, "Unsupported dtype" with self.assertRaisesRegex(err, msg): jnp.linalg.qr(arr) diff --git a/tests/logging_test.py b/tests/logging_test.py index 70f619de5ee6..a1d6695a1e37 100644 --- a/tests/logging_test.py +++ b/tests/logging_test.py @@ -92,10 +92,14 @@ def test_no_log_spam(self): python = sys.executable assert "python" in python env_variables = {"TF_CPP_MIN_LOG_LEVEL": "1"} + if os.getenv("ASAN_OPTIONS"): + env_variables["ASAN_OPTIONS"] = os.getenv("ASAN_OPTIONS") if os.getenv("PYTHONPATH"): env_variables["PYTHONPATH"] = os.getenv("PYTHONPATH") if os.getenv("LD_LIBRARY_PATH"): env_variables["LD_LIBRARY_PATH"] = os.getenv("LD_LIBRARY_PATH") + if os.getenv("LD_PRELOAD"): + env_variables["LD_PRELOAD"] = os.getenv("LD_PRELOAD") # Make sure C++ logging is at default level for the test process. proc = subprocess.run( [python, f.name], diff --git a/tests/memories_test.py b/tests/memories_test.py index 6959aa7535b8..781172f880c5 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -25,6 +25,7 @@ from jax import lax from jax._src import test_util as jtu from jax._src import xla_bridge as xb +from jax._src.lib import xla_extension_version from jax._src import config from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint import jax.numpy as jnp @@ -35,7 +36,6 @@ TransferToMemoryKind, PartitionSpec as P) from jax.experimental.compute_on import compute_on from jax.experimental.shard_map import shard_map -from jax._src.lib import xla_extension_version import numpy as np config.parse_flags_with_absl() @@ -416,8 +416,6 @@ def f(a, b): out, np_inp * np_inp, s_dev, "device") def test_parameter_streaming(self): - if jtu.test_device_matches(["gpu"]): - self.skipTest("This test does not work on GPU backend.") _, s_host, np_inp, inp_host = _create_inputs( (8, 2), P("x", "y"), mem_kind="pinned_host") s_dev = s_host.with_memory_kind('device') @@ -444,7 +442,7 @@ def test_zero_size_parameter(self): if jtu.test_device_matches(["gpu"]): self.skipTest("This test does not work on GPU backend.") _, s_host, np_inp, inp_host = _create_inputs( - (0,), P("x"), mem_kind="pinned_host") + (0,), P(), mem_kind="pinned_host") s_dev = s_host.with_memory_kind('device') @functools.partial(jax.jit, out_shardings=s_host) @@ -461,8 +459,6 @@ def f(a): out, np_inp, s_host, 'pinned_host') def test_parameter_streaming_with_scalar_and_constant(self): - if jtu.test_device_matches(["gpu"]): - self.skipTest("This test does not work on GPU backend.") mesh = jtu.create_mesh((2, 2), ("x", "y")) scalar_inp = 1 s_host = NamedSharding(mesh, P(), memory_kind="pinned_host") @@ -512,8 +508,6 @@ def f(x): ) def test_parameter_and_output_streaming_with_scalar(self): - if jtu.test_device_matches(["gpu"]): - self.skipTest("This test is flaky on GPU backend.") if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: self.skipTest("This test requires an xla_version >= 2.") @@ -581,8 +575,6 @@ def body(carry, x): self.assertEqual(out_hbm.sharding, out_s) def test_output_streaming(self): - if jtu.test_device_matches(["gpu"]): - self.skipTest("This test is flaky on GPU backend.") mesh = jtu.create_mesh((1, 1), ("x", "y")) np_inp = np.arange(16.0).reshape(8, 2) s_hbm = NamedSharding(mesh, P("x", "y"), memory_kind="device") @@ -599,8 +591,6 @@ def f(xs): self.assertEqual(out_host.sharding, s_host) def test_weight_offload_with_dp_on_output(self): - if jtu.test_device_matches(["gpu"]): - self.skipTest("This test is flaky on GPU backend.") _, s_dev, np_inp, inp_dev = _create_inputs( (8, 2), P("x", "y"), mem_kind="device") s_host = s_dev.with_memory_kind('pinned_host') @@ -616,8 +606,6 @@ def f(x): out_host, np_inp * 2, s_host, 'pinned_host') def test_output_streaming_inside_scan(self): - if jtu.test_device_matches(["gpu"]): - self.skipTest("This test does not work on GPU backend.") if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: self.skipTest("This test requires an xla_version >= 2.") mesh = jtu.create_mesh((1, 1, 2), ("x", "y", "z")) @@ -650,8 +638,6 @@ def test_deepcopy(self): self.assertEqual(t.shape, t_copy.shape) def test_close_over_host_constant_and_stream(self): - if jtu.test_device_matches(["gpu"]): - self.skipTest("This test does not work on GPU backend.") _, s_host, np_inp, inp_host = _create_inputs( (8, 2), P("x", "y"), mem_kind="pinned_host") @@ -666,6 +652,51 @@ def f(): out = f() self._check_device_put_addressable_shards(out, np_inp * 2, s_dev, 'device') + @jtu.run_on_devices('tpu') + def test_ragged_copy_on_host(self): + if xla_extension_version < 290: + self.skipTest('Requires xla_extension_version >= 290') + mesh = jtu.create_mesh((2,), ('x')) + sharding = jax.sharding.NamedSharding(mesh, P(('x'))) + cpu_sharding = sharding.with_memory_kind('pinned_host') + + num_pages = 512 * 1024 + page_size = 1024 + + x = jnp.full((num_pages, page_size), 1, dtype=jnp.bfloat16, device=sharding) + + def write(x): + return x.at[16 * 1024:].set(0) + x = shard_map(write, mesh, P(('x'),), P(('x')))(x) + + chunk_size = 8 + def inner(state): + idx, x, output = state + chunk = jax.lax.dynamic_slice_in_dim(x, idx * chunk_size, chunk_size) + chunk_host = jax.device_put(chunk, TransferToMemoryKind('pinned_host')) + output = jax.lax.dynamic_update_slice_in_dim( + output, chunk_host, idx * chunk_size, axis=0) + return (idx + 1, x, output) + + def cond(state): + idx, x, _ = state + chunk = jax.lax.dynamic_slice_in_dim(x, idx * chunk_size, chunk_size) + return (idx * chunk_size < x.shape[0]) & jnp.any(chunk > 0) + + def foo(x): + output = jnp.zeros_like(x, device=cpu_sharding) + _, _, cpu_x = jax.lax.while_loop(cond, inner, (0, x, output)) + return cpu_x + + fn = jax.jit(shard_map(foo, mesh, P(('x'),), P(('x')), + check_rep=False), + out_shardings=cpu_sharding) + y = fn(x) + jax.block_until_ready(y) + compiled_text = fn.lower(x).compile().as_text() + if compiled_text is not None: + self.assertIn('custom_call_target="AllocateBuffer"', compiled_text) + class ComputeOffload(jtu.BufferDonationTestCase): @@ -699,6 +730,9 @@ def init(): def test_compute_no_inputs_host_replicated(self): if xb.backend_xla_version() is not None and xb.backend_xla_version() < 3: self.skipTest("This test requires an xla_version >= 3.") + if config.use_shardy_partitioner.value: + self.skipTest("XLA failure due to b/370786664 and b/366411266. " + "Enable when fixed.") mesh = jtu.create_mesh((4,), ('data')) tpu_sharding = NamedSharding(mesh, P('data')) @@ -706,8 +740,8 @@ def test_compute_no_inputs_host_replicated(self): @functools.partial(jax.jit, out_shardings=(tpu_sharding, cpu_sharding)) def init(): - tpu_array = jax.random.normal(jax.random.key(42), (16,16)) - cpu_array = jax.random.normal(jax.random.key(42), (16,16)) + tpu_array = jax.random.normal(jax.random.key(42), (16, 16)) + cpu_array = jax.random.normal(jax.random.key(42), (16, 16)) return tpu_array, cpu_array tpu_array, cpu_array = init() @@ -799,6 +833,20 @@ def h(x): self.assertArraysEqual(out2, np.sum(inp) * np.sum(inp)) self.assertEqual(out2.sharding.memory_kind, 'pinned_host') + def test_compute_host_loop(self): + @compute_on('device_host') + @jax.jit + def fn(): + k = jax.random.key(0) + return jax.nn.initializers.lecun_normal()(k, (2, 2), jnp.float32) + fn() # doesn't crash + + @compute_on('device_host') + def fn(): + k = jax.random.key(0) + return jax.nn.initializers.lecun_normal()(k, (2, 2), jnp.float32) + fn() # doesn't crash + def test_nested_compute_error(self): @compute_on('device') @jax.jit @@ -1200,6 +1248,8 @@ def test_jit_cpp_cache_hit(self): self.assertArraysEqual(out2, np_inp @ np_inp.T) def test_jit_compilation_cache_hit(self): + if config.use_shardy_partitioner.value: + self.skipTest("Shardy doesn't support GSPMDSharding") mesh, s, np_inp, inp = _create_inputs((8, 2), P("x", "y")) inp2 = jax.device_put( np_inp, GSPMDSharding(tuple(mesh.devices.flat), @@ -1351,6 +1401,9 @@ def h(x): self.assertArraysAllClose(out, expected_out, rtol=1e-3) def test_mem_kind_donation_pinned_host(self): + if config.use_shardy_partitioner.value: + self.skipTest("XLA failure due to b/370786664 and b/366411266. " + "Enable when fixed.") mesh = jtu.create_mesh((2,), "x") s = NamedSharding(mesh, P(), memory_kind='pinned_host') s_dev = s.with_memory_kind('device') @@ -1389,29 +1442,6 @@ def f(inp): self.assertIn("input_output_alias", lowered_text) self.assertDeleted(x) - @jtu.run_on_devices('tpu') - def test_aot_device_implicit_transfer(self): - mesh = jtu.create_mesh((1,), 'x') - np_inp = np.arange(8) - arr = jax.device_put(np_inp, NamedSharding(mesh, P())) - - @jax.jit - def f(x): - return x * 2 - - compiled = f.lower(arr).compile() - - cpu_dev = jax.devices('cpu')[0] - with jax.default_device(cpu_dev): - cpu_arr = jnp.arange(8) - self.assertEqual(cpu_arr.sharding, SingleDeviceSharding(cpu_dev)) - self.assertFalse(cpu_arr._committed) - - out = compiled(cpu_arr) - self.assertArraysEqual(out, np_inp * 2) - self.assertEqual(out.sharding, NamedSharding(mesh, P())) - self.assertEqual(out.sharding.memory_kind, 'device') - def test_compute_offload_with_donation(self): sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0]) p_sharding = jax.sharding.SingleDeviceSharding( @@ -1460,6 +1490,23 @@ def f(x): # 2 for `f` and `2` for `mul` (compute type changes for `mul`) self.assertEqual(count[0], 4) + def test_offload_take_host(self): + @compute_on('device_host') + @jax.jit + def peer_forward(x, experts, indices, scores): + w = jnp.take(experts, indices.astype(int), axis=0) + w_gate, w_down, w_up = w[..., 0], w[..., 1], w[..., 2] + g = jnp.einsum('btd, bthkd->bthk', x, w_gate) + x = jnp.einsum('btd, bthkd->bthk', x, w_down) + x = x * jax.nn.gelu(g) * scores + return jnp.einsum('bthk, bthkd->btd', x, w_up) + + x = jnp.ones((16, 4, 32)) + experts = jnp.ones((128, 32, 3)) + indices = jnp.ones((16, 4, 4, 2), dtype=jnp.int32) + scores = jnp.ones((16, 4, 4, 2)) + jax.jit(peer_forward)(x, experts, indices, scores) # doesn't crash + class ActivationOffloadingTest(jtu.JaxTestCase): @@ -1568,8 +1615,6 @@ def g(ys, _): self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) def test_remat_scan_layout_change_offloadable(self): - if jtu.test_device_matches(["gpu"]) and xla_extension_version < 289: - self.skipTest("Requires xla_extension_version >= 289") mesh = jtu.create_mesh((2,), ("x",)) shape = (256, 128) np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index f949b63c7844..cd99416933f1 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -297,6 +297,23 @@ def kernel(ctx, inp, out, _): )(x) np.testing.assert_array_equal(y, x.reshape(out_ty.shape)) + @parameterized.named_parameters( + ("drop_1s", (1, 1, 5, 1, 1, 2, 1, 1), (5, 1, 2)), + ("add_1s", (5, 1, 2), (1, 1, 5, 1, 1, 2, 1, 1)), + ("fold", (1, 5, 2, 1,), (1, 10, 1)), + ("un", (1, 10, 1), (1, 5, 2, 1,)), + ) + def test_reshape(self, inp_shape, out_shape): + def kernel(ctx, inp, out, _): + copy(memref_reshape(inp, out_shape), out) + + x = np.arange(math.prod(inp_shape), dtype=jnp.float32).reshape(*inp_shape) + out_ty = jax.ShapeDtypeStruct(out_shape, jnp.float32) + y = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, out_ty, () + )(x) + np.testing.assert_array_equal(y, x.reshape(*out_shape)) + @parameterized.named_parameters([ ("packed", (4, 4, 4), (16, 4, 1), 1, 2, False), ("strided_end", (4, 4, 4, 4), (256, 64, 16, 4), 1, 2, False), diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index fe7ddd618ee1..d3e32873c597 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -223,5 +223,14 @@ def body_fun(_, index_x): _, xs = doit() self.assertAllClose(xs, (np.arange(5) * 2), check_dtypes=False) + def test_double_jit_mutable_array(self): + @jax.jit + @jax.jit + def f(): + x_ref = core.mutable_array(jnp.zeros(8)) + return x_ref[...] + x = f() + self.assertArraysEqual(x, jnp.zeros(8)) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/nn_test.py b/tests/nn_test.py index d6153d32c63e..df719256a921 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -50,6 +50,8 @@ def _check_cudnn_backend(fn, *args, **kwargs): hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo')) return '__cudnn$fmha' in hlo +_cudnn_dbias_error = 'cuDNN only supports bias gradient' + @jtu.with_config(jax_legacy_prng_key="allow", jax_numpy_dtype_promotion="standard") class NNFunctionsTest(jtu.JaxTestCase): @@ -167,6 +169,63 @@ def testDotProductAttentionMask(self, mask_mode): self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02) self.assertAllClose(dbias_ref, dbias_ans, rtol=.03, atol=.03) + @parameterized.product( + batch_size=[1, 16], + use_vmap=[False, True], + ) + def testDotProductAttentionBiasGradient(self, batch_size, use_vmap): + if not _is_required_cudnn_version_satisfied(8904): + raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") + + dtype = jnp.bfloat16 + B, S, N, H = batch_size, 128, 4, 32 + keys = random.split(random.PRNGKey(0), 2) + x = random.normal(keys[0], (B, S, N, H), dtype) + bias = random.normal(keys[1], (B, N, S, S), dtype=dtype) + mask = jnp.ones((1, 1, S), dtype=jnp.bool_) + + def attention(x, bias, mask, impl): + return jax.nn.dot_product_attention( + query=x, + key=x, + value=x, + bias=bias, + mask=mask, + is_causal=False, + implementation=impl, + ) + attn_ref = partial(attention, impl=None) + attn_ans = partial(attention, impl='cudnn') + if use_vmap: + attn_batched_ref = jax.vmap(attn_ref, in_axes=(0, 0, None)) + attn_batched_ans = jax.vmap(attn_ans, in_axes=(0, 0, None)) + else: + attn_batched_ref = attn_ref + attn_batched_ans = attn_ans + + fwd_ref = jax.jit(attn_batched_ref) + fwd_ans = jax.jit(attn_batched_ans) + y_ref = fwd_ref(x, bias, mask) + y_ans = fwd_ans(x, bias, mask) + self.assertAllClose(y_ref, y_ans) + + @jax.jit + def bwd_ref(x, bias, mask): + _, f_vjp = jax.vjp(attn_ref, x, bias, mask) + return f_vjp(x) + @jax.jit + def bwd_ans(x, bias, mask): + _, f_vjp = jax.vjp(attn_ans, x, bias, mask) + return f_vjp(x) + + if batch_size != 1: + with self.assertRaisesRegex(ValueError, _cudnn_dbias_error): + _, dbias_ans, _ = bwd_ans(x, bias, mask) + else: + _, dbias_ref, _ = bwd_ref(x, bias, mask) + _, dbias_ans, _ = bwd_ans(x, bias, mask) + self.assertAllClose(dbias_ans, dbias_ref, rtol=.03, atol=.03) + @jtu.skip_on_flag("jax_skip_slow_tests", True) def testSoftplusGrad(self): check_grads(nn.softplus, (1e-8,), order=4, diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 044f82067510..b5af90272510 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -16,6 +16,7 @@ load( "//jaxlib:jax.bzl", "jax_generate_backend_suites", "jax_multiplatform_test", + "jax_py_test", "py_deps", ) @@ -33,11 +34,6 @@ jax_multiplatform_test( srcs = [ "pallas_test.py", ], - config_tags_overrides = { - "gpu_a100_x32": { - "ondemand": False, # Include in presubmit. - }, - }, enable_backends = [ "cpu", "tpu", @@ -85,11 +81,6 @@ jax_multiplatform_test( srcs = [ "ops_test.py", ], - config_tags_overrides = { - "gpu_a100_x32": { - "ondemand": False, # Include in presubmit. - }, - }, disable_configs = [ "gpu_v100", "gpu_x32", @@ -145,11 +136,6 @@ jax_multiplatform_test( srcs = [ "pallas_vmap_test.py", ], - config_tags_overrides = { - "gpu_a100_x32": { - "ondemand": False, # Include in presubmit. - }, - }, enable_backends = ["cpu"], enable_configs = [ "gpu_a100_x32", @@ -170,11 +156,6 @@ jax_multiplatform_test( srcs = [ "mosaic_gpu_test.py", ], - config_tags_overrides = { - "gpu_h100_x32": { - "ondemand": False, # Include in presubmit. - }, - }, enable_backends = [], enable_configs = [ "gpu_h100_x32", @@ -184,7 +165,7 @@ jax_multiplatform_test( }, deps = [ "//jax:pallas", - "//jax:pallas_gpu", # build_cleaner: keep + "//jax:pallas_mosaic_gpu", # build_cleaner: keep "//jax/_src/pallas/mosaic_gpu", ] + py_deps("absl/testing") + py_deps("numpy"), ) @@ -192,11 +173,6 @@ jax_multiplatform_test( jax_multiplatform_test( name = "export_back_compat_pallas_test", srcs = ["export_back_compat_pallas_test.py"], - config_tags_overrides = { - "gpu_a100_x32": { - "ondemand": False, # Include in presubmit. - }, - }, enable_backends = ["cpu"], enable_configs = [ "gpu_a100_x32", @@ -215,11 +191,6 @@ jax_multiplatform_test( jax_multiplatform_test( name = "export_pallas_test", srcs = ["export_pallas_test.py"], - config_tags_overrides = { - "gpu_a100_x32": { - "ondemand": False, # Include in presubmit. - }, - }, enable_backends = ["cpu"], enable_configs = [ "gpu_a100_x32", @@ -235,11 +206,6 @@ jax_multiplatform_test( jax_multiplatform_test( name = "pallas_shape_poly_test", srcs = ["pallas_shape_poly_test.py"], - config_tags_overrides = { - "gpu_a100_x32": { - "ondemand": False, # Include in presubmit. - }, - }, disable_configs = [ "gpu_x32", "gpu_h100", @@ -269,9 +235,7 @@ jax_multiplatform_test( "//jax:pallas", "//jax:pallas_tpu", "//jax/_src/pallas/mosaic:random", - "//third_party/py/absl/testing:absltest", - "//third_party/py/absl/testing:parameterized", - ] + py_deps("numpy"), + ] + py_deps("absl/testing") + py_deps("numpy"), ) jax_multiplatform_test( @@ -280,6 +244,9 @@ jax_multiplatform_test( "tpu_all_gather_test.py", ], enable_backends = ["tpu"], + enable_configs = [ + "tpu_v5e_4x2", + ], deps = [ "//jax:pallas_tpu_ops", ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), @@ -313,6 +280,10 @@ jax_multiplatform_test( # The flag is necessary for ``pl.debug_print`` tests to work on TPU. args = ["--logtostderr"], enable_backends = ["tpu"], + enable_configs = [ + "tpu_v5e", + "tpu_v5p_1x1", + ], deps = [ "//jax:extend", "//jax:pallas_tpu", @@ -341,6 +312,12 @@ jax_multiplatform_test( name = "tpu_pallas_distributed_test", srcs = ["tpu_pallas_distributed_test.py"], enable_backends = ["tpu"], + enable_configs = [ + "tpu_v5e_4x2", + "tpu_v5p_2x2", + "tpu_v4_2x2", + "tpu_v3_2x2", + ], deps = [ "//jax:extend", "//jax:pallas_tpu", @@ -352,6 +329,10 @@ jax_multiplatform_test( name = "tpu_pallas_pipeline_test", srcs = ["tpu_pallas_pipeline_test.py"], enable_backends = ["tpu"], + enable_configs = [ + "tpu_v5e_4x2", + "tpu_v5p_1x1", + ], shard_count = 5, tags = [ "noasan", # Times out. @@ -369,7 +350,9 @@ jax_multiplatform_test( name = "tpu_pallas_async_test", srcs = ["tpu_pallas_async_test.py"], enable_backends = ["tpu"], - tags = [ + enable_configs = [ + "tpu_v5e_4x2", + "tpu_v5p_1x1", ], deps = [ "//jax:pallas_tpu", @@ -377,8 +360,8 @@ jax_multiplatform_test( ) jax_multiplatform_test( - name = "tpu_pallas_mesh_test", - srcs = ["tpu_pallas_mesh_test.py"], + name = "tpu_pallas_state_test", + srcs = ["tpu_pallas_state_test.py"], enable_backends = ["tpu"], tags = [ "noasan", @@ -401,9 +384,7 @@ jax_multiplatform_test( "//jax:pallas", "//jax:pallas_tpu", "//jax/_src/pallas/mosaic:random", - "//third_party/py/absl/testing:absltest", - "//third_party/py/absl/testing:parameterized", - ] + py_deps("numpy"), + ] + py_deps("absl/testing") + py_deps("numpy"), ) jax_multiplatform_test( @@ -438,17 +419,16 @@ jax_multiplatform_test( ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), ) -jax_multiplatform_test( +# This test doesn't need a TPU; it only tests numpy-using helpers. +jax_py_test( name = "tpu_splash_attention_mask_test", srcs = [ "tpu_splash_attention_mask_test.py", ], - enable_backends = [ - "cpu", - "tpu", - ], deps = [ + "//jax", "//jax:pallas_tpu_ops", + "//jax:test_util", ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), ) @@ -457,11 +437,6 @@ jax_multiplatform_test( srcs = [ "gpu_attention_test.py", ], - config_tags_overrides = { - "gpu_a100_x32": { - "ondemand": False, # Include in presubmit. - }, - }, enable_backends = ["cpu"], enable_configs = [ "gpu_a100_x32", @@ -480,11 +455,6 @@ jax_multiplatform_test( srcs = [ "gpu_ops_test.py", ], - config_tags_overrides = { - "gpu_a100_x32": { - "ondemand": False, # Include in presubmit. - }, - }, enable_backends = ["cpu"], enable_configs = [ "gpu_a100_x32", diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index b35658ed4845..bbc2c9298a6a 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -14,14 +14,15 @@ import functools import math +import traceback from absl.testing import absltest from absl.testing import parameterized import jax from jax._src import config from jax._src import test_util as jtu -import jax._src.pallas.mosaic_gpu as plgpu from jax.experimental import pallas as pl +from jax.experimental.pallas import mosaic_gpu as plgpu import jax.numpy as jnp import numpy as np @@ -42,16 +43,34 @@ def setUp(self): class PallasCallTest(PallasTest): - def test_add_one(self): + @parameterized.named_parameters( + ("add_one", lambda x: x + 1.), + ("logistic", jax.lax.logistic), + ("square", lambda x: x ** 2), + ("rsqrt", jax.lax.rsqrt), + ) + def test_unary_ops(self, unary): @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), ) def kernel(x_ref, o_ref): - o_ref[...] = x_ref[...] + 1.0 + o_ref[...] = unary(x_ref[...]) x = jnp.arange(256).astype(jnp.float32) - np.testing.assert_array_equal(kernel(x), x + 1.0) + np.testing.assert_array_equal(kernel(x), unary(x)) + + def test_add_first(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + ) + def kernel(x_ref, y_ref, o_ref): + o_ref[...] = x_ref[...] + y_ref[0] + + x = jnp.arange(256).astype(jnp.float32) + y = jnp.flip(x).reshape(1, 256) + np.testing.assert_array_equal(kernel(x, y), x + y[0]) def test_add_xy(self): @functools.partial( @@ -65,6 +84,19 @@ def kernel(x_ref, y_ref, o_ref): y = x + 1 np.testing.assert_array_equal(kernel(x, y), x + y) + def test_add_xy_indexed(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([128], jnp.float32), + ) + def kernel(x_ref, y_ref, o_ref): + idx = jnp.sum(y_ref[...]) + o_ref[...] = x_ref[idx] + + x = jnp.arange(4 * 128).reshape(4, 128).astype(jnp.float32) + y = jnp.zeros(128, dtype=jnp.int32) + np.testing.assert_array_equal(kernel(x, y), x[jnp.sum(y)]) + def test_add_one_grid(self): @functools.partial( pl.pallas_call, @@ -96,7 +128,7 @@ def kernel(x_ref, o_ref, scratch_ref): x = jnp.arange(256).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x + 1.0) - @parameterized.product(max_concurrent_steps=[1, 2, 3, 4]) + @parameterized.product(max_concurrent_steps=[1, 2, 3, 4, 16]) def test_add_one_grid_pipelined(self, max_concurrent_steps): @functools.partial( @@ -116,42 +148,129 @@ def kernel(x_ref, o_ref): x = jnp.arange(128 * 2 * 64).reshape((128 * 2, 64)).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x + 1.0) - def test_add_one_with_async_copy_smem_to_gmem(self): + def test_add_one_grid_pipelined_program_id(self): + @functools.partial( pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([128], jnp.float32), + out_specs=pl.BlockSpec((16, 16), lambda i, j: (i, j)), + out_shape=jax.ShapeDtypeStruct([16, 64], jnp.int32), + compiler_params=plgpu.GPUCompilerParams( + dimension_semantics=["parallel", "sequential"], + max_concurrent_steps=2, + ), + grid=(4, 4), + ) + def kernel(o_ref): + o_ref[...] = jnp.broadcast_to(pl.program_id(1), o_ref.shape) + + np.testing.assert_array_equal( + kernel(), + jnp.repeat(jnp.repeat(jnp.arange(4), 16)[None], 16, axis=0), + ) + + def test_add_one_grid_pipelined_sequential_invariant_output(self): + @functools.partial( + pl.pallas_call, + in_specs=[pl.BlockSpec((32, 16), lambda i, j: (i, j))], + out_specs=pl.BlockSpec((32, 16), lambda i, j: (i, 0)), + out_shape=jax.ShapeDtypeStruct([32 * 2, 64], jnp.float32), + compiler_params=plgpu.GPUCompilerParams( + dimension_semantics=["parallel", "sequential"], + max_concurrent_steps=2, + ), + grid=(2, 4), + ) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + 1.0 + + x = jnp.arange(32 * 2 * 64).reshape((32 * 2, 64)).astype(jnp.float32) + y = jnp.empty_like(x) + for i in range(2): + i_slice = slice(32 * i, 32 * (i + 1)) + for j in range(4): + j_slice = slice(16 * j, 16 * (j + 1)) + y = y.at[i_slice, :16].set(x[i_slice, j_slice] + 1) + + # We only compare the elements in the first 16 columns, because the rest + # are never written to. + np.testing.assert_array_equal(kernel(x)[:, :16], y[:, :16]) + + @parameterized.product(indexer=[..., slice(128), slice(None, 128)]) + def test_copy_smem_to_gmem(self, indexer): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.float32), out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), - scratch_shapes=[plgpu.SMEM((128,), jnp.float32)], + scratch_shapes=[plgpu.SMEM((256,), jnp.float32)], ) def kernel(x_ref, o_ref_gmem, scratch_ref): scratch_ref[...] = x_ref[...] + 1 - plgpu.async_copy_smem_to_gmem(scratch_ref, o_ref_gmem) + plgpu.copy_smem_to_gmem(scratch_ref.at[indexer], o_ref_gmem.at[indexer]) plgpu.wait_smem_to_gmem(0) - x = jnp.arange(128).astype(jnp.float32) - np.testing.assert_array_equal(kernel(x), x + 1.0) + x = jnp.arange(256).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x)[indexer], x[indexer] + 1.0) - def test_add_one_with_async_copy_gmem_to_smem(self): + @parameterized.product(indexer=[..., slice(128), slice(None, 128)]) + def test_copy_gmem_to_smem(self, indexer): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), + scratch_shapes=[ + plgpu.SMEM((256,), jnp.float32), + plgpu.Barrier(num_arrivals=1), + ], + ) + def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): + plgpu.copy_gmem_to_smem( + x_ref_gmem.at[indexer], scratch_ref.at[indexer], barrier=barrier_ref + ) + plgpu.wait_barrier(barrier_ref) + o_ref[...] = scratch_ref[...] + 1 + x = jnp.arange(256).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x)[indexer], x[indexer] + 1.0) + + @parameterized.product(indexer=[0, 1, 2, 3]) + def test_copy_gmem_to_smem_with_indexed_barrier(self, indexer): @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), scratch_shapes=[ plgpu.SMEM((128,), jnp.float32), - plgpu.Barrier(num_arrivals=1), + plgpu.Barrier(num_arrivals=1, num_barriers=4), ], ) def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): - plgpu.async_copy_gmem_to_smem( - x_ref_gmem, scratch_ref, barrier=barrier_ref + plgpu.copy_gmem_to_smem( + x_ref_gmem, scratch_ref, barrier=barrier_ref.at[indexer] ) - plgpu.wait_barrier(barrier_ref) + plgpu.wait_barrier(barrier_ref.at[indexer]) o_ref[...] = scratch_ref[...] + 1 x = jnp.arange(128).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x + 1.0) + def test_copy_gmem_to_smem_in_run_scoped(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), + ) + def kernel(x_ref_gmem, o_ref): + def body(barrier_ref): + def inner_body(scratch_ref): + plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier=barrier_ref) + plgpu.wait_barrier(barrier_ref) + o_ref[...] = scratch_ref[...] + 1 + pl.run_scoped(inner_body, plgpu.SMEM((256,), jnp.float32)) + pl.run_scoped(body, plgpu.Barrier(num_arrivals=1)) + + x = jnp.arange(256).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x + 1.0) + def test_add_doubled_sum(self): @functools.partial( pl.pallas_call, @@ -274,7 +393,7 @@ def kernel(x_ref, o_ref): self.assertIn(f"x: [1, 0, 43, 23]/{in_shape}: 6871\n", output()) - def test_scoped_allocation(self): + def test_run_scoped(self): def kernel(x_ref, o_ref): def body(tmp_ref): self.assertEqual(tmp_ref.shape, (8, 128)) @@ -309,6 +428,25 @@ def kernel(o_ref): jnp.array([0] * 128 + [1] * 128, dtype=jnp.int32), ) + def test_program_id_in_block_spec(self): + @functools.partial( + pl.pallas_call, + out_specs=pl.BlockSpec((128,), lambda *_: pl.program_id(0)), + out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32), + grid=2, + ) + def kernel(o_ref): + del o_ref + + # ``assertRaises`` have no way of asserting against the cause, so we + # have to use ``traceback.format_exception`` manually. + with self.assertRaises(Exception) as exc_info: + kernel() + self.assertIn( + "not supported in this context", + "".join(traceback.format_exception(exc_info.exception)), + ) + def test_num_programs(self): @functools.partial( pl.pallas_call, @@ -327,27 +465,27 @@ def kernel(o_ref): def test_swizzled_blockspec_shapes(self): + spec = plgpu.GPUBlockSpec( + (128, 64), + lambda *i: i, + transforms=( + plgpu.TilingTransform((64, 64)), + plgpu.SwizzleTransform(128), + ), + ) @functools.partial( pl.pallas_call, - in_specs=[ - plgpu.GPUBlockSpec( - (128, 64), - lambda *i: i, - transforms=plgpu.TilingTransform((64, 64)), - swizzle=128, - ), - ], - out_specs=pl.BlockSpec((2, 1, 64, 64), lambda i, j: (i, j, 64, 64)), - out_shape=jax.ShapeDtypeStruct((4, 2, 64, 64), jnp.float16), + in_specs=[spec], + out_specs=spec, + out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float16), grid=(2, 2), ) def kernel(x_ref, o_ref): - assert x_ref.shape == (2, 1, 64, 64), x_ref.shape + assert x_ref.shape == (128, 64), x_ref.shape o_ref[...] = x_ref[...] - x = jnp.zeros((256, 128), dtype=jnp.float16) - result = kernel(x) - self.assertEqual(result.shape, (4, 2, 64, 64)) + x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128) + np.testing.assert_array_equal(kernel(x), x) def test_fori_loop_array(self): @functools.partial( @@ -376,6 +514,22 @@ def kernel(o_ref): kernel(), jnp.full([256], 5.0, dtype=jnp.float32) ) + def test_fori_loop_indexed_store(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([4, 128], jnp.float32), + ) + def kernel(x_ref, y_ref, o_ref): + def body(idx, _): + o_ref[idx] = x_ref[idx] + y_ref[idx] + return () + + jax.lax.fori_loop(0, 4, body, ()) + + x = jnp.arange(4 * 128).reshape(4, 128).astype(jnp.float32) + y = x + 1 + np.testing.assert_array_equal(kernel(x, y), x + y) + def test_cond(self): @functools.partial( @@ -406,7 +560,7 @@ def test_wgmma(self, dtype): elems_128b = swizzle // jnp.dtype(dtype).itemsize def kernel(a_ref, b_ref, o_ref): def scope(acc_ref): - plgpu.wgmma(acc_ref, a_ref, b_ref, rhs_transpose=rhs_transpose) + plgpu.wgmma(acc_ref, a_ref, b_ref) return acc_ref[...] o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 128), jnp.float32)) @@ -424,14 +578,15 @@ def scope(acc_ref): plgpu.GPUBlockSpec( (64, 128), lambda i, j: (i, j), - transforms=plgpu.TilingTransform((64, elems_128b)), - swizzle=128, + transforms=( + plgpu.TilingTransform((64, elems_128b)), + plgpu.SwizzleTransform(128), + ), ), plgpu.GPUBlockSpec( (128, 128), lambda *i: i, - transforms=rhs_transforms, - swizzle=128, + transforms=(*rhs_transforms, plgpu.SwizzleTransform(128)), ), ], out_specs=plgpu.GPUBlockSpec((64, 128), lambda *i: i), @@ -442,6 +597,45 @@ def scope(acc_ref): res, a @ (b.T if rhs_transpose else b), rtol=1e-3 ) + def test_wgmma_sliced(self): + swizzle = 128 + elems_128b = swizzle // jnp.dtype(jnp.float16).itemsize + def kernel(a_ref, b_ref, o_ref): + def scope(acc_ref): + plgpu.wgmma(acc_ref, a_ref, b_ref) + return acc_ref[:, :64], acc_ref[:, 64:] + + o_ref[:, :64], o_ref[:, 64:] = pl.run_scoped(scope, plgpu.ACC((64, 128), jnp.float32)) + + key1, key2 = jax.random.split(jax.random.key(42), 2) + a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16) + b = jax.random.uniform(key2, shape=(128, 128), dtype=jnp.float16) + res = pl.pallas_call( + kernel, + in_specs=[ + plgpu.GPUBlockSpec( + (64, 128), + lambda i, j: (i, j), + transforms=( + plgpu.TilingTransform((64, elems_128b)), + plgpu.SwizzleTransform(128), + ), + ), + plgpu.GPUBlockSpec( + (128, 128), + lambda *i: i, + transforms=( + plgpu.TilingTransform((elems_128b, elems_128b)), + plgpu.SwizzleTransform(128), + ), + ), + ], + out_specs=plgpu.GPUBlockSpec((64, 128), lambda *i: i), + out_shape=jax.ShapeDtypeStruct((64, 128), jnp.float32), + grid=(1, 1), + )(a, b) + np.testing.assert_allclose(res, a @ b, rtol=1e-3) + def test_input_output_aliases(self): # Note that we're writing to the input pointer, which should alias b_ptr. def kernel(a_ref, b_ref): @@ -467,16 +661,16 @@ def test_realistic_matmul(self): tile_k = elems_128b m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n def kernel(a_ref, b_ref, o_ref, acc_ref): + # Make sure tiling does not alter the shape of references + assert a_ref.shape == (tile_m, tile_k) + assert b_ref.shape == (tile_k, tile_n) + assert o_ref.shape == acc_ref.shape == (tile_m, tile_n) plgpu.wgmma(acc_ref, a_ref, b_ref) - plgpu.wgmma_wait(0) # TODO(apaszke): Delay the pipeline to avoid memory races - # TODO(apaszke): Only store in the last step. It doesn't work because we - # don't have partial discharge for control flow. - # is_last_step = pl.program_id(2) == grid_k - 1 - # @pl.when(is_last_step) - # def _epilogue(): - # pl.debug_print("{}", acc_ref[...]) - # TODO(apaszke): This is an untiled store! It's slow!! - o_ref[...] = acc_ref[...] + is_last_step = pl.program_id(2) == grid_k - 1 + @pl.when(is_last_step) + def _epilogue(): + o_ref[...] = acc_ref[...].astype(dtype) + plgpu.wgmma_wait(1) # We don't await the last WGMMA, hence delay_release=1 key1, key2 = jax.random.split(jax.random.key(42), 2) a = jax.random.uniform(key1, shape=(m, k), dtype=dtype) @@ -488,27 +682,129 @@ def kernel(a_ref, b_ref, o_ref, acc_ref): plgpu.GPUBlockSpec( (tile_m, tile_k), lambda m, n, k: (m, k), - transforms=plgpu.TilingTransform((64, elems_128b)), - swizzle=128, + transforms=( + plgpu.TilingTransform((64, elems_128b)), + plgpu.SwizzleTransform(128), + ), ), plgpu.GPUBlockSpec( (tile_k, tile_n), lambda m, n, k: (k, n), - transforms=plgpu.TilingTransform((elems_128b, elems_128b)), - swizzle=128, + transforms=( + plgpu.TilingTransform((elems_128b, elems_128b)), + plgpu.SwizzleTransform(128), + ), ), ], - out_specs=plgpu.GPUBlockSpec((tile_m, tile_n), lambda m, n, k: (m, n)), - out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), + out_specs=plgpu.GPUBlockSpec( + (tile_m, tile_n), + lambda m, n, k: (m, n), + transforms=( + plgpu.TilingTransform((64, elems_128b)), + plgpu.SwizzleTransform(128), + ), + ), + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16), scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)], grid=(grid_m, grid_n, grid_k), compiler_params=plgpu.GPUCompilerParams( dimension_semantics=["parallel", "parallel", "sequential"], max_concurrent_steps=2, + delay_release=1, ), )(a, b) np.testing.assert_allclose(res, a @ b, rtol=1e-3) + def test_slicing(self): + left = upper = slice(None, 64) + right = lower = slice(64, None) + # We rotate the four quadrants of the input clockwise. + def rotate(src, dst): + dst[upper, left] = src[lower, left] + dst[upper, right] = src[upper, left] + dst[lower, right] = src[upper, right] + dst[lower, left] = src[lower, right] + + x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128) + spec = plgpu.GPUBlockSpec( + (128, 128), + lambda: (0, 0), + transforms=( + plgpu.TilingTransform((64, 64)), + plgpu.SwizzleTransform(128), + ), + ) + f = pl.pallas_call(rotate, out_shape=x, in_specs=[spec], out_specs=spec) + expected = np.empty_like(x) + rotate(x, expected) + np.testing.assert_array_equal(f(x), expected) + + +class PipelineTest(PallasTest): + + def test_manual(self, max_concurrent_steps=2, num_steps=4): + + def kernel(x_gmem, o_gmem): + return pl.run_scoped( + functools.partial(scoped_kernel, x_gmem, o_gmem), + plgpu.SMEM((max_concurrent_steps, 32, 16), jnp.float32), + plgpu.SMEM((max_concurrent_steps, 32, 16), jnp.float32), + plgpu.Barrier(1, num_barriers=max_concurrent_steps), + ) + + def scoped_kernel(x_gmem, o_gmem, x_smem, o_smem, barrier): + gmem_slice = pl.ds(pl.program_id(0) * 32, 32) + + def body(step, _): + slot = step % max_concurrent_steps + + # Wait for the current GMEM->SMEM copy to complete. + plgpu.wait_barrier(barrier.at[slot]) + # Wait for the previous output SMEM->GMEM copy to complete. + plgpu.wait_smem_to_gmem(max_concurrent_steps - 1) + + o_smem[...] = x_smem[...] + 1.0 + + plgpu.copy_smem_to_gmem( + o_smem.at[slot], o_gmem.at[gmem_slice, pl.ds(step * 16, 16)] + ) + + fetch_step = step + max_concurrent_steps + fetch_slot = slot # (x + y) % y == x % y + jax.lax.cond( + fetch_step < num_steps, + lambda: plgpu.copy_gmem_to_smem( + x_gmem.at[gmem_slice, pl.ds(fetch_step * 16, 16)], + x_smem.at[fetch_slot], + barrier=barrier.at[fetch_slot], + ), + lambda: None, + ) + return () + + # Initialize the pipeline. + for slot in range(min(max_concurrent_steps, num_steps)): + plgpu.copy_gmem_to_smem( + x_gmem.at[gmem_slice, pl.ds(slot * 16, 16)], + x_smem.at[slot], + barrier=barrier.at[slot], + ) + + jax.lax.fori_loop(0, num_steps, body, ()) + + # Finalize the pipeline. + plgpu.wait_smem_to_gmem(0) + + x = jnp.arange(32 * 4 * 64).reshape(32 * 4, 64).astype(jnp.float32) + kernel_fn = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + grid=(4, 1), + ) + np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + if __name__ == "__main__": absltest.main() diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 7bba9f01bec9..5135d77119ac 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -15,7 +15,6 @@ """Tests for common JAX operations within pallas_call.""" from collections.abc import Sequence -import contextlib import functools import itertools import sys @@ -30,7 +29,6 @@ import jax.numpy as jnp from jax import lax from jax import random -from jax._src import config from jax._src import dtypes from jax._src import linear_util as lu from jax._src import state @@ -39,7 +37,7 @@ from jax.experimental import pallas as pl if sys.platform != "win32": - from jax.experimental.pallas import gpu as plgpu + from jax.experimental.pallas import triton as plgpu from jax.experimental.pallas import tpu as pltpu else: plgpu = None @@ -561,14 +559,6 @@ def test_cast(self, from_dtype, to_dtype, data): self.skipTest("Not supported: bad canonicalization") if from_dtype == "bool" and to_dtype in {"int16", "int8"}: self.skipTest("Not supported: cannot extend to sub-32 bit types") - if from_dtype in {"bfloat16", "float32"} and to_dtype == "bool": - self.skipTest("Not supported: unsupported relayout") - if from_dtype == "bool" and to_dtype in {"int32", "bfloat16", "float32"}: - self.skipTest("Not supported: unsupported relayout") - if from_dtype in {"int16", "int8"} and to_dtype == "bool": - self.skipTest("Not supported: cannot truncate from sub-32 bit types") - if from_dtype in {"int16", "int8"} and to_dtype == "bool": - self.skipTest("Not supported: cannot truncate from sub-32 bit types") if jtu.test_device_matches(["gpu"]): if (from_dtype in {"bfloat16", "float32"} and to_dtype in {"int8", "int16", "int32"}): @@ -707,16 +697,10 @@ def run(interpret=False): for value in values ) def test_sign(self, dtype, value): - if ( - not jax.config.x64_enabled - and dtype in (jnp.uint64, jnp.int64, jnp.float64) - ): + if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: self.skipTest("64-bit types require x64_enabled") - if ( - jtu.test_device_matches(["tpu"]) - and dtype in (jnp.uint16, jnp.int16, jnp.bfloat16, jnp.float16) - ): + if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: self.skipTest("16-bit types are not supported on TPU") @functools.partial( @@ -753,37 +737,6 @@ def kernel(x_ref, o_ref): expected = lax.erf_inv(x) np.testing.assert_array_equal(out, expected) - -class OpsInterpretTest(OpsTest): - INTERPRET = True - - def test_debug_print(self): - @functools.partial( - self.pallas_call, - out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), - grid=1, - ) - def kernel(x_ref, o_ref): - jax.debug.print("x = {}", x_ref) - - x = jnp.array([4.2, 2.4]).astype(jnp.float32) - with jtu.capture_stdout() as output: - jax.block_until_ready(kernel(x)) - jax.effects_barrier() - - self.assertIn("x = [4.2 2.4]", output()) - - -class OpsExtraTest(PallasBaseTest): - """These are additional ops tests that have not been ported to TPU yet.""" - # TODO: fix these for TPU and merge with OpsTest. - - def setUp(self): - super().setUp() - if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: - # TODO: most tests fail on TPU in non-interpret mode - self.skipTest("On TPU the test works only in interpret mode") - ELEMENTWISE_OPS = [ ( [jnp.abs, jnp.negative], @@ -811,17 +764,69 @@ def setUp(self): for fn, dtype in itertools.product(*args) ) def test_elementwise(self, fn, dtype): + if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: + self.skipTest("64-bit types require x64_enabled") + + if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: + self.skipTest("16-bit types are not supported on TPU") + + # TODO(b/370578663): implement these lowerings on TPU + if jtu.test_device_matches(["tpu"]) and fn in ( + jnp.acos, jnp.acosh, jnp.asin, jnp.asinh, jnp.atan, jnp.atanh, + jnp.cbrt, jnp.cosh, jnp.expm1, jnp.sinh, + ): + self.skipTest(f"{fn.__name__} not implemented on TPU") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), dtype), grid=1 ) def kernel(x_ref, o_ref): o_ref[:] = fn(x_ref[...]) - with contextlib.ExitStack() as stack: - if jnp.dtype(dtype).itemsize == 8: - stack.enter_context(config.enable_x64(True)) - x = jnp.array([0.42, 2.4]).astype(dtype) - np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6) + x = jnp.array([0.42, 2.4]).astype(dtype) + np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6) + + @parameterized.named_parameters( + (f"{fn.__name__}_{dtype}", fn, dtype) + for args in ELEMENTWISE_OPS + for fn, dtype in itertools.product(*args) + ) + def test_elementwise_scalar(self, fn, dtype): + if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: + self.skipTest("64-bit types require x64_enabled") + + if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: + self.skipTest("16-bit types are not supported on TPU") + + if ( + jtu.test_device_matches(["tpu"]) + and fn == lax.population_count + and not self.INTERPRET + ): + self.skipTest( + "Scalar population count on TPU is only supported in interpret mode" + ) + + # TODO(b/370578663): implement these lowerings on TPU + if jtu.test_device_matches(["tpu"]) and fn in ( + jnp.abs, jnp.acos, jnp.acosh, jnp.asin, jnp.asinh, jnp.atan, + jnp.atanh, jnp.cbrt, jnp.cos, jnp.cosh, jnp.expm1, + jnp.sin, jnp.sinh, jnp.tan, lax.rsqrt, + ): + self.skipTest(f"{fn.__name__} not implemented on TPU") + + @functools.partial( + self.pallas_call, + in_specs=(pl.BlockSpec(memory_space=smem_on_tpu()),), + out_specs=pl.BlockSpec(memory_space=smem_on_tpu()), + out_shape=jax.ShapeDtypeStruct((2,), dtype), + ) + def kernel(x_ref, o_ref): + o_ref[0] = fn(x_ref[0]) + o_ref[1] = fn(x_ref[1]) + + x = jnp.array([0.42, 2.4]).astype(dtype) + np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6) def test_abs_weak_type(self): # see https://github.com/jax-ml/jax/issues/23191 @@ -841,18 +846,24 @@ def kernel(x_ref, o_ref): ("float64", "float64"), ) def test_pow(self, x_dtype, y_dtype): + if jtu.test_device_matches(["tpu"]): + self.skipTest("TODO: Error on TPU") + + if not jax.config.x64_enabled and jnp.dtype(x_dtype).itemsize == 8: + self.skipTest("64-bit types require x64_enabled") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), x_dtype), grid=1 ) def kernel(x_ref, y_ref, o_ref): o_ref[:] = lax.pow(x_ref[...], y_ref[...]) - with contextlib.ExitStack() as stack: - if jnp.dtype(x_dtype).itemsize == 8: - stack.enter_context(config.enable_x64(True)) - x = jnp.array([1, 2, 3, 4]).astype(x_dtype) - y = jnp.array([1, 2, 3, 4]).astype(y_dtype) - np.testing.assert_allclose(kernel(x, y), lax.pow(x, y)) + if not jax.config.x64_enabled and jnp.dtype(x_dtype).itemsize == 8: + self.skipTest("64-bit types require x64_enabled") + + x = jnp.array([1, 2, 3, 4]).astype(x_dtype) + y = jnp.array([1, 2, 3, 4]).astype(y_dtype) + np.testing.assert_allclose(kernel(x, y), lax.pow(x, y)) @parameterized.parameters(0, 1, 2, 3, 4, 5, -1, -2, -3) def test_integer_pow(self, y): @@ -867,20 +878,22 @@ def kernel(x_ref, o_ref): @parameterized.parameters("float32", "float64") def test_nextafter(self, dtype): - if jtu.test_device_matches(["tpu"]) and dtype == "float64": - self.skipTest("float64 disabled on TPU.") + if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: + self.skipTest("64-bit types require x64_enabled") + + # TODO: implement this on TPU + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not implemented: nextafter") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), dtype), grid=1 ) def kernel(x_ref, y_ref, o_ref): o_ref[:] = jnp.nextafter(x_ref[...], y_ref[...]) - with contextlib.ExitStack() as stack: - if jnp.dtype(dtype).itemsize == 8: - stack.enter_context(config.enable_x64(True)) - x = jnp.array([1, 2, 3, 4]).astype(dtype) - y = jnp.array([1, 2, 3, 4]).astype(dtype) - np.testing.assert_allclose(kernel(x, y), jnp.nextafter(x, y)) + x = jnp.array([1, 2, 3, 4]).astype(dtype) + y = jnp.array([1, 2, 3, 4]).astype(dtype) + np.testing.assert_allclose(kernel(x, y), jnp.nextafter(x, y)) COMPARISON_OPS = [ jnp.equal, @@ -901,6 +914,13 @@ def test_comparison(self, fn, dtype): if jtu.test_device_matches(["gpu"]) and dtype == "bool": self.skipTest("Not implemented on GPU.") + if jtu.test_device_matches(["tpu"]) and dtype == "float16": + self.skipTest("float16 is not supported on TPU") + + # TODO: skipped due to https://github.com/jax-ml/jax/issues/24030 + if jtu.test_device_matches(["tpu"]) and dtype == "bool": + self.skipTest("Not supported on TPU") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.bool_), grid=1) @@ -979,6 +999,21 @@ def kernel(x_ref, y_ref, o_ref): for fn, dtype in itertools.product(*args) ) def test_binary(self, f, dtype): + if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: + self.skipTest("16-bit types are not supported on TPU") + + # TODO: skipped due to https://github.com/jax-ml/jax/issues/24027 + if ( + jtu.test_device_matches(["tpu"]) + and f == jnp.remainder + and not self.INTERPRET + ): + self.skipTest("jnp.remainder on TPU is only supported in interpret mode") + + # TODO(ayx): fix this on TPU + if jtu.test_device_matches(["tpu"]) and dtype == "uint32": + self.skipTest("Not supported on TPU") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), dtype), grid=1 ) @@ -986,19 +1021,60 @@ def kernel(x_ref, y_ref, o_ref): o_ref[...] = f(x_ref[...], y_ref[...]) x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7]).astype(dtype) - if (f == jnp.bitwise_left_shift): + if f == jnp.bitwise_left_shift: y = jnp.array([3, 1, 4, 5, 2, 2, 2, 4]).astype(dtype) else: y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4]).astype(dtype) np.testing.assert_allclose(f(x, y), kernel(x, y)) + @parameterized.named_parameters( + (f"{fn.__name__}_{dtype}", fn, dtype) + for args in BINARY_OPS + for fn, dtype in itertools.product(*args) + ) + def test_binary_scalar(self, f, dtype): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Test only supported on TPU.") + if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: + self.skipTest("16-bit types are not supported on TPU") + # TODO: skipped due to https://github.com/jax-ml/jax/issues/24027 + if ( + jtu.test_device_matches(["tpu"]) + and f == jnp.remainder + and not self.INTERPRET + ): + self.skipTest("jnp.remainder on TPU is only supported in interpret mode") + + # TODO: skipped due to https://github.com/jax-ml/jax/issues/23972 + if jtu.test_device_matches(["tpu"]) and dtype == "uint32": + self.skipTest("Not supported on TPU") + + @functools.partial( + self.pallas_call, + in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + out_shape=jax.ShapeDtypeStruct((1,), dtype), grid=1 + ) + def kernel(x_ref, y_ref, o_ref): + o_ref[0] = f(x_ref[0], y_ref[0]) + + x = jnp.array([1,]).astype(dtype) + y = jnp.array([18,]).astype(dtype) + + np.testing.assert_allclose(f(x, y), kernel(x, y)) + @parameterized.parameters( ((8, 4), jnp.int32, 0), ((8, 16), jnp.float32, 1), ((8, 16, 2), jnp.int8, 1), ) def test_broadcasted_iota(self, shape, dtype, dimension): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Only 32-bit integer iota supported") + f = lambda: jax.lax.broadcasted_iota(dtype, shape, dimension) @functools.partial( @@ -1011,8 +1087,12 @@ def kernel(o_ref): @parameterized.parameters("float16", "bfloat16", "float32") def test_approx_tanh(self, dtype): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not implemented on TPU") + if self.INTERPRET: self.skipTest("approx_tanh is not supported in interpret mode") + if (dtype == "bfloat16" and not jtu.is_cuda_compute_capability_at_least("9.0")): self.skipTest("tanh.approx.bf16 requires a GPU with capability >= sm90") @@ -1034,6 +1114,9 @@ def kernel(x_ref, o_ref): ) def test_elementwise_inline_asm(self): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not implemented: elementwise_inline_asm_p") + if self.INTERPRET: self.skipTest( "elementwise_inline_asm is not supported in interpret mode" @@ -1057,6 +1140,9 @@ def kernel(x_ref, o_ref): np.testing.assert_allclose(kernel(x), jnp.tanh(x), atol=5e-3, rtol=5e-3) def test_debug_barrier(self): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not implemented: debug_barrier_p") + if self.INTERPRET: self.skipTest("debug_barrier is not supported in interpret mode") @@ -1077,9 +1163,13 @@ def kernel(x_ref, o_ref): "plgpu.TritonCompilerParams unavailable on Windows", ) def test_debug_print(self): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + # TODO: this test flakes on gpu if jtu.test_device_matches(["gpu"]): self.skipTest("This test flakes on gpu") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), @@ -1101,9 +1191,13 @@ def kernel(x_ref, o_ref): "plgpu.TritonCompilerParams unavailable on Windows", ) def test_debug_print_with_values(self): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + # TODO: this test flakes on gpu if jtu.test_device_matches(["gpu"]): self.skipTest("This test flakes on gpu") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), @@ -1127,6 +1221,9 @@ def kernel(x_ref, o_ref): ((64,), (32, 2)), ) def test_reshape(self, in_shape, out_shape): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32), @@ -1156,6 +1253,10 @@ def f(x_ref, o_ref): # fmt: on ) def test_reshape_noop_or_singleton_dims(self, in_shape, out_shape): + # Unsupported implicit dim change: from "32,{0,0},(2,128),-1" to none + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32), @@ -1171,6 +1272,7 @@ def f(x_ref, o_ref): def test_num_programs(self): @functools.partial( self.pallas_call, + out_specs=pl.BlockSpec(memory_space=smem_on_tpu()), out_shape=jax.ShapeDtypeStruct((4,), intx), grid=4, ) @@ -1182,6 +1284,10 @@ def kernel(o_ref): ) def test_where_broadcasting(self): + # The Pallas TPU lowering currently supports only blocks of rank >= 1 + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((4, 2, 2), floatx), @@ -1207,6 +1313,10 @@ def copyitem(x_ref, in_idx_ref, out_idx_ref, o_ref): ((), (2, 2), ()), ) def test_broadcast_in_dim(self, in_shape, out_shape, dims): + # The Pallas TPU lowering currently supports only blocks of rank >= 1 + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32), @@ -1227,6 +1337,12 @@ def f(x_ref, o_ref): trans_y=[False, True], ) def test_dot(self, size, dtype, trans_x, trans_y): + if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: + self.skipTest("16-bit types are not supported on TPU") + + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not implemented: Transposed LHS") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((size, size), dtype), @@ -1249,6 +1365,9 @@ def dot(x_ref, y_ref, o_ref): block_size=[1, 2, 32, 64, 128], ) def test_masked_load_store(self, size, block_size): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not implemented") + @functools.partial( self.pallas_call, out_shape=(jax.ShapeDtypeStruct((size,), floatx)), @@ -1266,6 +1385,10 @@ def kernel(x_ref, o_ref): np.testing.assert_allclose(kernel(x), x + 1.0, atol=1e-5, rtol=1e-5) def test_masked_oob_load_store_slice(self): + # The Pallas TPU lowering currently supports only blocks of rank >= 1 + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + n = 16 @functools.partial( @@ -1290,15 +1413,18 @@ def test_strided_load(self): # Reproducer from https://github.com/jax-ml/jax/issues/20895. @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct((4,), jnp.float32), + out_shape=jax.ShapeDtypeStruct((4, 4), jnp.float32), ) def kernel(x_ref, o_ref): o_ref[...] = x_ref[::4] - x = jnp.arange(16, dtype=jnp.float32) + x = jnp.arange(64, dtype=jnp.float32).reshape((16, 4)) np.testing.assert_array_equal(kernel(x), x[::4]) def test_broadcasted_load_store(self): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Unimplemented primitive: broadcast_to") + m, n = 16, 32 @functools.partial( @@ -1320,6 +1446,10 @@ def load(x_ref, o_ref): ((16, 32), (16, 16)), ) def test_invalid_broadcasted_load(self, x_shape, mask_shape): + # The Pallas TPU lowering currently supports only blocks of rank >= 1 + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + if self.INTERPRET: self.skipTest("No broadcasting checks in pl.load in interpret mode") @@ -1342,6 +1472,10 @@ def kernel(x_ref, mask_ref, o_ref): self.fail("Expected exception due to invalid broadcasting") def test_swap(self): + # TODO: skipped due to https://github.com/jax-ml/jax/issues/24023 + if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: + self.skipTest("On TPU this is only supported in interpret mode") + m, n = 16, 32 @functools.partial( @@ -1362,6 +1496,9 @@ def swap(_, _2, x_ref, y_ref): np.testing.assert_array_equal(out[1], x) def test_masked_swap(self): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not implemented on TPU") + m, n = 16, 32 @functools.partial( @@ -1383,6 +1520,10 @@ def masked_swap(_, _2, mask_ref, x_ref, y_ref): np.testing.assert_array_equal(out[1], jnp.where(mask, x, y)) def test_masked_oob_swap_slice(self): + # The Pallas TPU lowering currently supports only blocks of rank >= 1 + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + m, n = 32, 16 @functools.partial( @@ -1421,6 +1562,10 @@ def masked_oob_swap_slice(_, _2, mask_ref, start_idx_ref, x_ref, y_ref): ("min_f32", pl.atomic_min, np.array([1, 2, 3, 4], np.float32), np.min), ) def test_scalar_atomic(self, op, value, numpy_op): + # The Pallas TPU lowering currently supports only blocks of rank >= 1 + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((), value.dtype), @@ -1452,6 +1597,9 @@ def atomic_kernel(x_ref, _, o_ref): @parameterized.parameters((0,), (1,)) def test_array_atomic_add(self, axis): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Unimplemented primitive: broadcast_to") + m, n = 32, 8 if axis == 0: grid = m @@ -1489,6 +1637,10 @@ def reduce(x_ref, _, y_ref): (2, 1, 1), ) def test_atomic_cas(self, init_value, cmp, new_value): + # The Pallas TPU lowering currently supports only blocks of rank >= 1 + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + if jax.config.x64_enabled and jtu.test_device_matches(["gpu"]): self.skipTest("Not supported on GPU in 64-bit mode") @@ -1507,6 +1659,10 @@ def swap(_, lock_ref, out_ref): @parameterized.parameters(1, 2, 3, 4, 8) def test_atomic_counter(self, num_threads): + # The Pallas TPU lowering currently supports only blocks of rank >= 1 + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + if self.INTERPRET: self.skipTest("While loop not supported in interpret mode.") @@ -1532,6 +1688,10 @@ def _cond(_): @parameterized.parameters(False, True) def test_reduce_only_dim(self, use_store): + # The Pallas TPU lowering currently supports only blocks of rank >= 1 + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + m = 32 x = random.normal(random.key(0), (m,), dtype=jnp.float32) out_shape = jax.ShapeDtypeStruct((), x.dtype) @@ -1573,9 +1733,10 @@ def reduce(x_ref, y_ref): if isinstance(axis, int) or "arg" not in op_name ]) def test_array_reduce(self, op, dtype, axis): - m, n = 32, 8 + if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: + self.skipTest("16-bit types are not supported on TPU") - if not jax.config.x64_enabled and dtype in ("float64", "int64", "uint64"): + if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: self.skipTest("64-bit types require x64_enabled") # Skip argmin/argmax on GPU in 64-bit mode because Pallas expects @@ -1587,6 +1748,12 @@ def test_array_reduce(self, op, dtype, axis): ): self.skipTest("Not supported on GPU in 64-bit mode") + # The Pallas TPU lowering currently supports only blocks of rank >= 1 + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + + m, n = 32, 8 + def make_x(key): if jnp.issubdtype(dtype, jnp.integer): return random.permutation( @@ -1623,6 +1790,9 @@ def reduce(x_ref, y_ref): dtype=["float16", "float32", "int32", "uint32"], ) def test_cumsum(self, dtype, axis): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not implemented on TPU") + m, n = 32, 8 out_dtype = dtype @@ -1649,9 +1819,25 @@ def reduce(x_ref, y_ref): np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i) -class OpsExtraInterpretTest(OpsExtraTest): +class OpsInterpretTest(OpsTest): INTERPRET = True + def test_debug_print(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), + grid=1, + ) + def kernel(x_ref, o_ref): + jax.debug.print("x = {}", x_ref) + + x = jnp.array([4.2, 2.4]).astype(jnp.float32) + with jtu.capture_stdout() as output: + jax.block_until_ready(kernel(x)) + jax.effects_barrier() + + self.assertIn("x = [4.2 2.4]", output()) + class PallasPrimitivesTest(PallasBaseTest): diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 6df31b55f8e7..1cdc4075d20a 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -2033,11 +2033,43 @@ def _(): np.testing.assert_allclose(out, expected, atol=atol) -class PallasCheckifyInterpretTest(PallasBaseTest): - # TODO(b/346651778): Support non-interpret mode checkify. - INTERPRET = True +class PallasCheckifyTest(PallasBaseTest): + INTERPRET = False + + def test_basic_runtime_assert(self): + # TODO(justinfu): Move to non-interpret checkify class. + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Runtime check only implemented on TPU.") + # Run this test manually, since we cannot recover from a halt. + self.skipTest("Cannot recover from halt.") + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...] + checkify.check(True, "first check passed") + checkify.check(False, "second check failed") + input_ = jnp.arange(4, dtype=jnp.int32) + out_shape = jax.ShapeDtypeStruct(input_.shape, input_.dtype) + with pltpu.enable_runtime_assert(True): + pallas_call = pl.pallas_call(kernel, out_shape=out_shape) + pallas_call(input_) # This should log "second check failed" + + def test_runtime_assert_is_noop_when_not_enabled(self): + # TODO(justinfu): Move to non-interpret checkify class. + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Runtime check only implemented on TPU.") + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...] + checkify.check(False, "failed check", + debug=True) # This check always fails. + input_ = jnp.arange(4, dtype=jnp.int32) + out_shape = jax.ShapeDtypeStruct(input_.shape, input_.dtype) + with pltpu.enable_runtime_assert(False): + pallas_call = pl.pallas_call(kernel, out_shape=out_shape) + result = pallas_call(input_) + np.testing.assert_allclose(result, input_) def test_no_checkify(self,): + if jtu.test_device_matches(["gpu"]): + self.skipTest("Not supported on GPU.") def kernel(y_ref): y_ref[...] = jnp.zeros_like(y_ref[...]) out_shape = jax.ShapeDtypeStruct((2, 2), jnp.float32) @@ -2049,6 +2081,8 @@ def kernel(y_ref): np.testing.assert_allclose(result, jnp.zeros_like(result)) def test_does_not_clobber_previous_error(self,): + if jtu.test_device_matches(["gpu"]): + self.skipTest("Not supported on GPU.") def kernel(y_ref): y_ref[...] = jnp.zeros_like(y_ref[...]) checkify.check(False, "error in kernel") @@ -2067,6 +2101,8 @@ def error_before_call(): @parameterized.parameters((False,), (True,)) def test_trivial_check(self, assert_cond): + if jtu.test_device_matches(["gpu"]): + self.skipTest("Not supported on GPU.") def kernel(x_ref, y_ref): y_ref[...] = x_ref[...] checkify.check(assert_cond, "pallas check failed") @@ -2083,6 +2119,8 @@ def kernel(x_ref, y_ref): np.testing.assert_allclose(result, input) def test_nan_error(self): + if not self.INTERPRET: + self.skipTest("Not supported in non-interpret mode.") def kernel(x_ref, y_ref): y_ref[...] = jnp.log(x_ref[...]) input = jnp.arange(4, dtype=jnp.float32) - 2 @@ -2090,7 +2128,7 @@ def kernel(x_ref, y_ref): pallas_call = self.pallas_call(kernel, out_shape=out_shape) checked_call = checkify.checkify(pallas_call, - errors=checkify.all_checks) + errors=checkify.nan_checks) err, result = checked_call(input) with self.assertRaisesRegex( checkify.JaxRuntimeError, "nan generated by primitive: log"): @@ -2119,6 +2157,8 @@ def kernel(x_ref, y_ref): @parameterized.parameters((5, 0), (8, 3), (4, 3)) def test_checkify_returns_first_error_in_grid( self, num_loops, fail_iteration): + if not self.INTERPRET: + self.skipTest("Not supported in non-interpret mode.") # Check that checkify returns the first error that occurs # TODO(justinfu): This test doesn't make sense on GPU, where threads run # in parallel. Update checkify to return a grid of errors. @@ -2137,12 +2177,42 @@ def kernel(x_ref, _): out_shape=out_shape) checked_call = checkify.checkify(pallas_call, - errors=checkify.all_checks) + errors=checkify.user_checks) err, _ = checked_call(input_arr) with self.assertRaisesRegex( checkify.JaxRuntimeError, f"failed on loop {fail_iteration}"): err.throw() + def test_checkify_on_oob_grid_access(self): + if not self.INTERPRET: + self.skipTest("Not supported in non-interpret mode.") + if config.enable_x64.value: + self.skipTest("Not supported in x64 mode.") + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + input_arr = jnp.arange(18, dtype=jnp.float32) + in_specs = [pl.BlockSpec((8,), lambda x: (x,))] + out_specs = pl.BlockSpec((8,), lambda x: (x,)) + out_shape = jax.ShapeDtypeStruct((18,), dtype=jnp.float32) + pallas_call = self.pallas_call(kernel, + grid=(3,), + in_specs=in_specs, + out_specs=out_specs, + out_shape=out_shape) + + checked_call = checkify.checkify(pallas_call, + errors=checkify.index_checks) + err, result = checked_call(input_arr) + with self.assertRaisesRegex(checkify.JaxRuntimeError, + (r"out-of-bounds indexing for array of shape \(18,\): index 16 " + r"is out of bounds for axis 0 with size 18")): + err.throw() + np.testing.assert_array_equal(result, input_arr) + + +class PallasCheckifyInterpretTest(PallasCheckifyTest): + INTERPRET = True + class PallasCallNamedGridTest(PallasBaseTest): def test_named_grid(self): diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index 1d57dc164294..ca5361a70051 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -208,6 +208,31 @@ def body(x_ref, o_ref): ), ) + def test_select_with_scalar_condition(self): + def kernel(cond, lhs, rhs, out): + out[:] = jax.lax.select(cond[0] != 0, lhs[:], rhs[:]) + + def run(cond, lhs, rhs): + return pl.pallas_call( + kernel, + out_shape=lhs, + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.SMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), + ], + ), + name="select_kernel", + )(cond, lhs, rhs) + + cond = jnp.array([1], dtype=jnp.int32) + lhs = jnp.zeros((8, 128), dtype=jnp.float32) + rhs = jnp.ones((8, 128), dtype=jnp.float32) + + assert (run(cond, lhs, rhs) == lhs).all() + class OpsInterpretTest(OpsTest): INTERPRET = True diff --git a/tests/pallas/tpu_pallas_async_test.py b/tests/pallas/tpu_pallas_async_test.py index 4f9d591dbea4..ef8d3ea8986d 100644 --- a/tests/pallas/tpu_pallas_async_test.py +++ b/tests/pallas/tpu_pallas_async_test.py @@ -20,6 +20,7 @@ from absl.testing import parameterized import jax from jax._src import test_util as jtu +from jax._src.state import discharge as state_discharge from jax.experimental import pallas as pl from jax.experimental import shard_map from jax.experimental.pallas import tpu as pltpu @@ -755,5 +756,123 @@ def body(_, x): np.testing.assert_array_equal(y, expected) +def make_stateful_async_copy(): + @jax.named_call + def copy_start(x_ref, o_ref) -> Future: + + def copy_start_kernel(sem): + pltpu.make_async_copy(x_ref, o_ref, sem).start() + sem = pl.pallas_call( + copy_start_kernel, + out_shape=pltpu.SemaphoreType.DMA(()), + out_specs=pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + )() + return sem + + @jax.named_call + def copy_done(x_ref, o_ref, future): + sem = future + + def copy_done_kernel(sem): + pltpu.make_async_copy(x_ref, o_ref, sem).wait() + + () = pl.pallas_call( + copy_done_kernel, + out_shape=(), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ], + )(sem) + + return copy_start, copy_done + + +def make_stateful_async_slice(i: int): + @jax.named_call + def copy_start(x_ref, o_ref) -> Future: + + def copy_start_kernel(sem): + pltpu.make_async_copy(x_ref.at[i], o_ref, sem).start() + sem = pl.pallas_call( + copy_start_kernel, + out_shape=pltpu.SemaphoreType.DMA(()), + out_specs=pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + )() + return sem + + @jax.named_call + def copy_done(x_ref, o_ref, future): + sem = future + + def copy_done_kernel(sem): + pltpu.make_async_copy(x_ref.at[i], o_ref, sem).wait() + + () = pl.pallas_call( + copy_done_kernel, + out_shape=(), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ], + )(sem) + + return copy_start, copy_done + + +class PallasCallStatefulAsyncTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + if not jtu.is_device_tpu_at_least(4): + self.skipTest('DMAs only guaranteed to work ou TPU v4+') + + def test_basic_stateful_async_copy(self): + @jax.jit + def f(x): + y = jnp.zeros_like(x) + def body(refs): + copy_start, copy_done = make_stateful_async_copy() + x_ref, y_ref = refs + fut = copy_start(x_ref, y_ref) + copy_done(x_ref, y_ref, fut) + _, y = state_discharge.run_state(body)((x, y)) + return y + x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32) + y = f(x) + np.testing.assert_array_equal(y, x) + + def test_multiple_stateful_async_copy(self): + @jax.jit + def f(x): + y = y2 = jnp.zeros_like(x) + def body(refs): + copy_start, copy_done = make_stateful_async_copy() + x_ref, y_ref, y2_ref = refs + fut = copy_start(x_ref, y_ref) + fut2 = copy_start(x_ref, y2_ref) + copy_done(x_ref, y_ref, fut) + copy_done(x_ref, y2_ref, fut2) + _, y, y2 = state_discharge.run_state(body)((x, y, y2)) + return y, y2 + x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32) + y, y2 = f(x) + np.testing.assert_array_equal(y, x) + np.testing.assert_array_equal(y2, x) + + def test_basic_stateful_async_slice(self): + @jax.jit + def f(x): + y = jnp.zeros(x.shape[1:], x.dtype) + def body(refs): + copy_start, copy_done = make_stateful_async_slice(2) + x_ref, y_ref = refs + fut = copy_start(x_ref, y_ref) + copy_done(x_ref, y_ref, fut) + _, y = state_discharge.run_state(body)((x, y)) + return y + x = jax.random.normal(jax.random.key(0), (4, 8, 128), dtype=jnp.float32) + y = f(x) + np.testing.assert_array_equal(y, x[2]) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_pallas_mesh_test.py b/tests/pallas/tpu_pallas_mesh_test.py deleted file mode 100644 index 0df759aec724..000000000000 --- a/tests/pallas/tpu_pallas_mesh_test.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for Pallas mesh API.""" - -from absl.testing import absltest -import jax -from jax._src import test_util as jtu -from jax._src.state import discharge as state_discharge -from jax.experimental import pallas as pl -from jax.experimental import shard_map -from jax.experimental.pallas import tpu as pltpu -import jax.numpy as jnp -import numpy as np - - -jax.config.parse_flags_with_absl() - - -class ShmallasTest(jtu.JaxTestCase): - - def setUp(self): - super().setUp() - if not jtu.is_device_tpu_at_least(4): - self.skipTest("Only supported on TPU v4+") - - def test_can_create_tensorcore_mesh(self): - _ = pltpu.create_tensorcore_mesh("x") - - def test_can_trivially_shard_map_with_pallas_mesh(self): - mesh = pltpu.create_tensorcore_mesh("x") - _ = shard_map.shard_map(lambda: None, mesh, in_specs=(), out_specs=None)() - - def test_can_run_basic_pallas_kernel_with_shard_map(self): - mesh = pltpu.create_tensorcore_mesh("x") - - @jax.jit - def f(x): - y = jnp.zeros_like(x) - def inner(refs): - x_ref, y_ref = refs - def kernel(): - def alloc(sem): - pltpu.async_copy(x_ref, y_ref, sem).wait() - pl.run_scoped(alloc, pltpu.SemaphoreType.DMA) - shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None, - check_rep=False)() - _, y = state_discharge.run_state(inner)((x, y)) - return y - x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) - y = f(x) - np.testing.assert_array_equal(y, x) - - def test_can_query_core_index_pallas_kernel_with_shard_map(self): - mesh = pltpu.create_tensorcore_mesh("x") - - @jax.jit - def f(x): - y = jnp.zeros_like(x) - def inner(refs): - x_ref, y_ref = refs - def kernel(): - num_cores = jax.lax.psum(1, "x") - slc_size = 16 // num_cores - def alloc(x_vmem_ref, y_vmem_ref, sem): - core_index = jax.lax.axis_index("x") - slc = pl.ds(core_index * slc_size, slc_size) - pltpu.async_copy( - x_ref.at[slc], - x_vmem_ref, - sem, - ).wait() - y = x_vmem_ref[...] + jax.lax.axis_index("x") - y_vmem_ref[...] = y - pltpu.async_copy(y_vmem_ref, y_ref.at[slc], sem).wait() - pl.run_scoped( - alloc, - pltpu.VMEM((slc_size, 128), x_ref.dtype), - pltpu.VMEM((slc_size, 128), y_ref.dtype), - pltpu.SemaphoreType.DMA, - ) - shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None, - check_rep=False)() - _, y = state_discharge.run_state(inner)((x, y)) - return y - num_cores = jax.devices()[0].num_cores - x = jnp.arange(16 * 128, dtype=jnp.int32).reshape((16, 128)) - expected_out = ( - x.reshape((num_cores, -1, 128)) + jnp.arange(num_cores)[..., None, None] - ).reshape(x.shape) - y = f(x) - np.testing.assert_array_equal(y, expected_out) - - -if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_pallas_random_test.py b/tests/pallas/tpu_pallas_random_test.py index e3d43125c9ab..2b5c315263c9 100644 --- a/tests/pallas/tpu_pallas_random_test.py +++ b/tests/pallas/tpu_pallas_random_test.py @@ -220,5 +220,39 @@ def body(key_ref, o_ref): np.testing.assert_array_equal(result_16x128, result_32x256) +class ThreefryTest(parameterized.TestCase): + + def setUp(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Need TPU devices") + super().setUp() + + @parameterized.parameters( + ((8, 128),), + ((32, 256),), + ((4, 16, 128),), + ) + def test_uniform_matches_jax_threefry(self, shape): + def body(key_ref, o_ref): + key = jax.random.wrap_key_data(key_ref[0, ...], impl='threefry2x32') + o_ref[...] = jax_random.uniform( + key, shape=o_ref[...].shape, minval=0.0, maxval=1.0 + ) + + threefry_key = jax_random.key(0, impl="threefry2x32").reshape((1,)) + o_shape = jax.ShapeDtypeStruct(shape, jnp.float32) + with jax.threefry_partitionable(True): + # TODO(justinfu): support passing keys into VMEM. + result = pl.pallas_call( + body, + in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)], + out_shape=o_shape, + )(jax.random.key_data(threefry_key)) + jax_result = jax_random.uniform( + threefry_key[0], shape=o_shape.shape, minval=0.0, maxval=1.0 + ) + np.testing.assert_array_equal(result, jax_result) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_pallas_state_test.py b/tests/pallas/tpu_pallas_state_test.py new file mode 100644 index 000000000000..b017cac2fba0 --- /dev/null +++ b/tests/pallas/tpu_pallas_state_test.py @@ -0,0 +1,271 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Pallas mesh API.""" +import functools +from absl.testing import absltest +import jax +from jax._src import test_util as jtu +from jax._src.state import discharge as state_discharge +from jax.experimental import pallas as pl +from jax.experimental import shard_map +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp +import numpy as np + + +jax.config.parse_flags_with_absl() + + +class PallasCallStatefulTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if not jtu.is_device_tpu_at_least(4): + self.skipTest("Only supported on TPU v4+") + + def test_basic_stateful_kernel(self): + + def copy_kernel(x_ref, y_ref): + def body(sem): + pltpu.make_async_copy(x_ref, y_ref, sem).start() + pltpu.make_async_copy(x_ref, y_ref, sem).wait() + pl.run_scoped(body, pltpu.SemaphoreType.DMA) + + def f_stateful(refs): + x_ref, y_ref = refs + + pl.pallas_call(functools.partial(copy_kernel, x_ref, y_ref), + out_shape=[])() + + @jax.jit + def f(x): + _, y = state_discharge.run_state(f_stateful)((x, jnp.zeros_like(x))) + return y + + x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) + y = f(x) + np.testing.assert_array_equal(y, x) + + def test_basic_stateful_kernel_with_scratch_sem(self): + + def copy_kernel(x_ref, y_ref, sem): + pltpu.make_async_copy(x_ref, y_ref, sem).start() + pltpu.make_async_copy(x_ref, y_ref, sem).wait() + + def f_stateful(refs): + x_ref, y_ref = refs + + pl.pallas_call(functools.partial(copy_kernel, x_ref, y_ref), + scratch_shapes=[pltpu.SemaphoreType.DMA], + out_shape=[])() + + @jax.jit + def f(x): + _, y = state_discharge.run_state(f_stateful)((x, jnp.zeros_like(x))) + return y + + x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) + y = f(x) + np.testing.assert_array_equal(y, x) + + def test_basic_stateful_kernel_with_scalar_prefetch(self): + + def copy_kernel(x_ref, y_ref, index_ref, sem): + i = index_ref[0] + pltpu.make_async_copy(x_ref.at[i], y_ref, sem).start() + pltpu.make_async_copy(x_ref.at[i], y_ref, sem).wait() + + def f_stateful(refs): + x_ref, y_ref = refs + + pl.pallas_call( + functools.partial(copy_kernel, x_ref, y_ref), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + scratch_shapes=[pltpu.SemaphoreType.DMA], + ), + out_shape=[], + )(jnp.array([0])) + + @jax.jit + def f(x): + _, y = state_discharge.run_state(f_stateful)((x[None], jnp.zeros_like(x))) + return y + + x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) + y = f(x) + np.testing.assert_array_equal(y, x) + + def test_basic_stateful_kernel_with_io_aliasing(self): + + def copy_kernel(x_ref, y_ref, x_old_ref, x_old_ref2, sem): + del x_old_ref, x_old_ref2 + pltpu.make_async_copy(x_ref, y_ref, sem).start() + pltpu.make_async_copy(x_ref, y_ref, sem).wait() + + def f_stateful(refs): + x_ref, y_ref, o_ref = refs + + x = pl.pallas_call( + functools.partial(copy_kernel, x_ref, y_ref), + scratch_shapes=[pltpu.SemaphoreType.DMA], + out_shape=jax.ShapeDtypeStruct(x_ref.shape, x_ref.dtype), + input_output_aliases={0: 0}, + )(x_ref[...]) + o_ref[...] = x + + @jax.jit + def f(x): + _, y, o = state_discharge.run_state(f_stateful)( + (x, jnp.zeros_like(x), jnp.zeros_like(x)) + ) + return y, o + + x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) + y, o = f(x) + np.testing.assert_array_equal(y, x) + np.testing.assert_array_equal(o, x) + + def test_stateful_matmul(self): + + m, k, n = 512, 512, 512 + bm, bk, bn = 128, 128, 128 + + def matmul_kernel(acc_ref, x_ref, y_ref, o_ref): + @pl.when(pl.program_id(2) == 0) + def _(): + acc_ref[...] = jnp.zeros_like(acc_ref) + + acc_ref[...] += jnp.dot( + x_ref[...], y_ref[...], preferred_element_type=jnp.float32 + ) + + @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) + def _(): + o_ref[...] = acc_ref[...].astype(o_ref.dtype) + + def matmul(x, y): + + def run_matmul(refs): + x_ref, y_ref, o_ref = refs + + def matmul_pipeline_kernel(acc_ref): + pltpu.emit_pipeline( + functools.partial(matmul_kernel, acc_ref), + grid=(m // bm, n // bn, k // bk), + in_specs=[ + pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)), + pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)), + ], + out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)), + )(x_ref, y_ref, o_ref) + + pl.pallas_call( + matmul_pipeline_kernel, + out_shape=[], + scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)], + )() + + _, _, o = state_discharge.run_state(run_matmul)( + (x, y, jnp.ones((m, n), dtype=x.dtype)) + ) + return o + + x = jax.random.normal(jax.random.key(0), (m, k), jnp.float32) + y = jax.random.normal(jax.random.key(1), (k, n), jnp.float32) + o = matmul(x, y) + atol = 0 + if jtu.is_device_tpu(6): + atol = 2e-5 + np.testing.assert_allclose(o, x @ y, atol=atol) + + +class ShmallasTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if not jtu.is_device_tpu_at_least(4): + self.skipTest("Only supported on TPU v4+") + + def test_can_create_tensorcore_mesh(self): + _ = pltpu.create_tensorcore_mesh("x") + + def test_can_trivially_shard_map_with_pallas_mesh(self): + mesh = pltpu.create_tensorcore_mesh("x") + _ = shard_map.shard_map(lambda: None, mesh, in_specs=(), out_specs=None)() + + def test_can_run_basic_pallas_kernel_with_shard_map(self): + mesh = pltpu.create_tensorcore_mesh("x") + + @jax.jit + def f(x): + y = jnp.zeros_like(x) + def inner(refs): + x_ref, y_ref = refs + def kernel(): + def alloc(sem): + pltpu.async_copy(x_ref, y_ref, sem).wait() + pl.run_scoped(alloc, pltpu.SemaphoreType.DMA) + shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None, + check_rep=False)() + _, y = state_discharge.run_state(inner)((x, y)) + return y + x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) + y = f(x) + np.testing.assert_array_equal(y, x) + + def test_can_query_core_index_pallas_kernel_with_shard_map(self): + mesh = pltpu.create_tensorcore_mesh("x") + + @jax.jit + def f(x): + y = jnp.zeros_like(x) + def inner(refs): + x_ref, y_ref = refs + def kernel(): + num_cores = jax.lax.psum(1, "x") + slc_size = 16 // num_cores + def alloc(x_vmem_ref, y_vmem_ref, sem): + core_index = jax.lax.axis_index("x") + slc = pl.ds(core_index * slc_size, slc_size) + pltpu.async_copy( + x_ref.at[slc], + x_vmem_ref, + sem, + ).wait() + y = x_vmem_ref[...] + jax.lax.axis_index("x") + y_vmem_ref[...] = y + pltpu.async_copy(y_vmem_ref, y_ref.at[slc], sem).wait() + pl.run_scoped( + alloc, + pltpu.VMEM((slc_size, 128), x_ref.dtype), + pltpu.VMEM((slc_size, 128), y_ref.dtype), + pltpu.SemaphoreType.DMA, + ) + shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None, + check_rep=False)() + _, y = state_discharge.run_state(inner)((x, y)) + return y + num_cores = jax.devices()[0].num_cores + x = jnp.arange(16 * 128, dtype=jnp.int32).reshape((16, 128)) + expected_out = ( + x.reshape((num_cores, -1, 128)) + jnp.arange(num_cores)[..., None, None] + ).reshape(x.shape) + y = f(x) + np.testing.assert_array_equal(y, expected_out) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index c20084c3c8e2..9d0389a4799f 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -32,6 +32,7 @@ import jax.numpy as jnp from jax._src import core from jax._src import config +from jax._src import dispatch from jax._src import test_util as jtu from jax import dtypes from jax import stages @@ -41,6 +42,7 @@ from jax._src import prng from jax.sharding import PartitionSpec as P, Mesh from jax.experimental import multihost_utils +from jax.experimental.shard_map import shard_map from jax.experimental.custom_partitioning import custom_partitioning from jax._src import array from jax._src.sharding import Sharding, common_devices_indices_map @@ -57,7 +59,6 @@ from jax._src import xla_bridge from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension -from jax._src.lib import xla_extension_version from jax._src.util import curry, unzip2 config.parse_flags_with_absl() @@ -659,10 +660,7 @@ def testAutodiffCache(self): jax.grad(f)(x) # Warm up the cache. with jtu.count_pjit_cpp_cache_miss() as count: jax.grad(f)(x) - if xla_extension_version >= 286: - self.assertEqual(count[0], 0) # no cache miss i.e. cache hit - else: - self.assertEqual(count[0], 2) + self.assertEqual(count[0], 0) # no cache miss i.e. cache hit @jtu.with_mesh([('x', 2), ('y', 1)]) def testEvalJaxpr(self): @@ -3780,6 +3778,29 @@ def f(x): self.assertArraysEqual(out1[0], inp * 2) self.assertArraysEqual(out2[0], inp * 2) + @jtu.run_on_devices('tpu', 'gpu') + def test_aot_device_implicit_transfer(self): + mesh = jtu.create_mesh((1,), 'x') + np_inp = np.arange(8) + arr = jax.device_put(np_inp, NamedSharding(mesh, P())) + + @jax.jit + def f(x): + return x * 2 + + compiled = f.lower(arr).compile() + + cpu_dev = jax.devices('cpu')[0] + with jax.default_device(cpu_dev): + cpu_arr = jnp.arange(8) + self.assertEqual(cpu_arr.sharding, SingleDeviceSharding(cpu_dev)) + self.assertFalse(cpu_arr._committed) + + out = compiled(cpu_arr) + self.assertArraysEqual(out, np_inp * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P())) + self.assertEqual(out.sharding.memory_kind, 'device') + def test_most_recent_executable_outer_inner_cache(self): x = np.zeros((20, 20), dtype=jnp.float64) @@ -4329,7 +4350,6 @@ def test_device_put_efficient_reshard_complex_mesh(self, shape): out = jax.device_put(x_s1, s2) self.assertArraysEqual(out, np_inp) self.assertEqual(out.sharding, s2) - del out s3 = NamedSharding(mesh2, P('model_q')) x_s3 = jax.device_put(np_inp, s3) @@ -4338,6 +4358,42 @@ def test_device_put_efficient_reshard_complex_mesh(self, shape): self.assertArraysEqual(out2, np_inp) self.assertEqual(out2.sharding, s1) + def test_device_put_donate_pytree(self): + shape1 = (8, 2) + shape2 = (8, 384) + if config.use_shardy_partitioner.value: + self.skipTest( + '_different_device_order_reshard is creating a GSPMDSharding') + if jax.device_count() < 8: + self.skipTest('Requires >= 8 devices') + + dev = jax.devices() + mesh1 = jax.sharding.Mesh( + np.asarray(dev).reshape([1, 2, 2, 2]), + ('replica', 'data', 'seq', 'model')) + mesh2 = jax.sharding.Mesh( + np.asarray(jax.devices()) + .reshape([1, 1, 2, 2, 2, 1]) + .swapaxes(2, 3) + .reshape([1, 1, 4, 2, 1]), + ('replica', 'data', 'seq', 'model_q', 'model_kv')) + + np_inp1 = jnp.arange(math.prod(shape1)).reshape(shape1) + np_inp2 = jnp.arange(math.prod(shape2)).reshape(shape2) + s1 = NamedSharding(mesh1, P('model')) + s2 = NamedSharding(mesh2, P('model_q')) + + x1 = jax.device_put(np_inp1, s1) + x2 = jax.device_put(np_inp2, s1) + # Reshard! + out1, out2 = jax.device_put((x1, x2), s2, donate=(True, False)) + self.assertArraysEqual(out1, np_inp1) + self.assertArraysEqual(out2, np_inp2) + self.assertEqual(out1.sharding, s2) + self.assertEqual(out2.sharding, s2) + self.assertTrue(x1.is_deleted()) + self.assertFalse(x2.is_deleted()) + def test_convert_element_type_sharding(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) @@ -4530,8 +4586,6 @@ def test_wsc_abstract_mesh_errors(self): ' match the mesh shape of the target sharding.*'): with_sharding_constraint(arr, NamedSharding(abs_mesh2, P('y'))) - @unittest.skipIf(xla_extension_version < 286, - "Requires xla_extension_version >= 286") def test_global_jit_cpp_cache_hit_out_shardings(self): mesh = jtu.create_mesh((2,), 'x') s = NamedSharding(mesh, P('x')) @@ -4577,6 +4631,47 @@ def f(x): else: self.assertEqual(lowered_text.count('@Sharding'), 2) + @config.sharding_in_types(True) + def test_fully_replicated_array_mul(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + np_inp1 = np.arange(16).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr1 = jax.device_put(np_inp1, s) + + np_inp2 = np.arange(2).reshape(1, 2) + arr2 = jax.device_put(np_inp2, NamedSharding(mesh, P(None, None))) + + @jax.jit + def f(x, y): + self.assertEqual(x.sharding.spec, s.spec) + out = x * y + self.assertEqual(out.sharding.spec, s.spec) + return out + + out = f(arr1, arr2) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, (np_inp1 * np_inp2)) + + out = f(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P(('x',), ('y',))))) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, (np_inp1 * np_inp1)) + + out = f(arr1, jax.device_put(np_inp2, NamedSharding(mesh, P()))) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, (np_inp1 * np_inp2)) + + @jax.jit + def g(x, y): + return x * y + + with self.assertRaisesRegex( + TypeError, "mul got incompatible shardings for broadcasting"): + g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P('y', 'x')))) + + with self.assertRaisesRegex( + TypeError, "mul got incompatible shardings for broadcasting"): + g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P(('x', 'y'))))) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): @@ -5236,6 +5331,30 @@ def _init(): self.assertArraysEqual(w, w_gt) + def test_get_intermediate_shardings(self): + mesh = jtu.create_mesh((2, 1), ('x', 'y')) + s = NamedSharding(mesh, P('x')) + arr = jax.device_put(np.arange(8), s) + + @jax.jit + def g(x): + x = with_sharding_constraint(x, s) + return with_sharding_constraint(x, s) + + @jax.jit + def f(x, y): + x, y = with_sharding_constraint((x, y), s) + x, y = shard_map(lambda x, y: (x, y), mesh=mesh, in_specs=P('x'), + out_specs=P('x'))(x, y) + x, y = jax.device_put((x, y), s) + x, y = jax.jit(lambda x, y: (x, y), in_shardings=s, out_shardings=s)(x, y) + return g(x), y + + jaxpr = f.trace(arr, arr).jaxpr + out = dispatch.get_intermediate_shardings(jaxpr) + self.assertLen(out, 16) + + @jtu.with_config(jax_use_shardy_partitioner=True) class SdyIntegrationTest(jtu.JaxTestCase): diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index d9887cf7b482..389a4181ebd9 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -635,13 +635,13 @@ def test_can_vmap_pure_callback(self): @jax.jit @jax.vmap def f(x): - return jax.pure_callback(np.sin, x, x) + return jax.pure_callback(np.sin, x, x, vmap_method="sequential") out = f(jnp.arange(4.)) np.testing.assert_allclose(out, np.sin(np.arange(4.))) @jax.jit def g(x): - return jax.pure_callback(np.sin, x, x) + return jax.pure_callback(np.sin, x, x, vmap_method="sequential") out = jax.vmap(g, in_axes=1)(jnp.arange(8.).reshape((4, 2))) np.testing.assert_allclose(out, np.sin(np.arange(8.).reshape((4, 2))).T) @@ -649,7 +649,8 @@ def g(x): @functools.partial(jax.vmap, in_axes=(0, None)) def h(x, y): out_shape = jax.ShapeDtypeStruct(x.shape, np.result_type(x.dtype, y.dtype)) - return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y) + return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y, + vmap_method="sequential") out = h(jnp.arange(4.), 4.) self.assertArraysAllClose(out, np.sin(np.arange(4.)) + 4., rtol=1E-7, check_dtypes=False) @@ -658,7 +659,8 @@ def h(x, y): @functools.partial(jax.vmap) def h(x, y): out_shape = jax.ShapeDtypeStruct(x.shape, np.result_type(x.dtype, y.dtype)) - return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y) + return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y, + vmap_method="sequential") out = h(jnp.arange(4.), jnp.arange(10., 14.)) self.assertArraysAllClose(out, np.sin(np.arange(4.)) + np.arange(10., 14.), rtol=1E-7, check_dtypes=False) @@ -667,7 +669,8 @@ def h(x, y): @functools.partial(jax.vmap, in_axes=1, out_axes=1) def h(x, y): out_shape = jax.ShapeDtypeStruct(x.shape, np.result_type(x.dtype, y.dtype)) - return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y) + return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y, + vmap_method="sequential") out = h(jnp.arange(4.)[None], jnp.arange(10., 14.)[None]) self.assertArraysAllClose(out, np.sin(np.arange(4.)) + np.arange(10., 14.)[None], @@ -682,7 +685,7 @@ def cb(x): @jax.jit @jax.vmap def f(x): - return jax.pure_callback(cb, x, x) + return jax.pure_callback(cb, x, x, vmap_method="sequential") np.testing.assert_allclose(f(jnp.arange(4.)), np.sin(np.arange(4.))) @@ -693,7 +696,7 @@ def cb2(x): @jax.jit @jax.vmap def g(x): - return jax.pure_callback(cb2, x, x, vectorized=True) + return jax.pure_callback(cb2, x, x, vmap_method="broadcast") np.testing.assert_allclose(g(jnp.arange(4.)), np.sin(np.arange(4.))) @@ -701,7 +704,7 @@ def g(x): @functools.partial(jax.vmap, in_axes=(0, None)) def h(x, y): return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y, - vectorized=True) + vmap_method="broadcast") out = h(jnp.arange(4.), 4.) np.testing.assert_allclose(out, np.sin(np.arange(4.)) + 4.) @@ -709,7 +712,7 @@ def h(x, y): @functools.partial(jax.vmap, in_axes=(1, None), out_axes=1) def h(x, y): return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y, - vectorized=True) + vmap_method="legacy_vectorized") out = h(jnp.arange(4.)[None], 4.) np.testing.assert_allclose(out, np.sin(np.arange(4.)[None]) + 4.) @@ -722,7 +725,7 @@ def cb(x): @jax.jit @jax.vmap def f(x): - return jax.pure_callback(cb, x, x, vectorized=True) + return jax.pure_callback(cb, x, x, vmap_method="broadcast") with self.assertRaises(RuntimeError): f(jnp.arange(4.)) @@ -981,6 +984,52 @@ def f(x): out = jax.pure_callback(f, jax.ShapeDtypeStruct(x.shape, x.dtype), x) np.testing.assert_allclose(out, 2 * jnp.log(x + 1)) + def test_vmap_method_raise(self): + @jax.vmap + def f(x): + # Setting vectorized to None disables the current default behavior of + # falling back on sequential. + return jax.pure_callback(np.sin, x, x, vectorized=None) + + with self.assertRaisesRegex(NotImplementedError, "vmap is only supported"): + f(jnp.arange(4.)) + + def test_deprecated_vectorized(self): + def f(x, **kwargs): + return jax.pure_callback(np.sin, x, x, **kwargs) + + with self.assertWarnsRegex(DeprecationWarning, "The default behavior"): + jax.vmap(f)(jnp.arange(4.0)) + + with self.assertWarnsRegex(DeprecationWarning, "The vectorized argument"): + f(jnp.arange(4.0), vectorized=True) + + with self.assertWarnsRegex(DeprecationWarning, "The vectorized argument"): + f(jnp.arange(4.0), vectorized=False) + + def test_vmap_method_broadcast(self): + def callback(x, y): + self.assertTupleEqual(x.shape, (4,)) + self.assertTupleEqual(y.shape, (1,)) + return x + y + + def f(x, y): + return jax.pure_callback(callback, x, x, y, vmap_method="broadcast") + + jax.vmap(f, in_axes=(0, None))(jnp.arange(4.0), 1.0) # doesn't error + + def test_vmap_method_broadcast_fullrank(self): + def callback(x, y): + self.assertTupleEqual(x.shape, (4,)) + self.assertTupleEqual(y.shape, (4,)) + return x + y + + def f(x, y): + return jax.pure_callback(callback, x, x, y, + vmap_method="broadcast_fullrank") + + jax.vmap(f, in_axes=(0, None))(jnp.arange(4.0), 1.0) # doesn't error + class IOCallbackTest(jtu.JaxTestCase): diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 77e5273d172a..75a5cee44ea2 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -48,7 +48,6 @@ from jax._src.export import shape_poly_decision from jax._src.lax import lax as lax_internal from jax._src.lax import control_flow as lax_control_flow -from jax._src.lib import xla_client from jax._src.lib import version as jaxlib_version import numpy as np @@ -2533,11 +2532,11 @@ def test_vmap_error(self): lambda x, fft_type, nr_fft_lengths: lax.fft_p.bind( x, fft_type=fft_type, fft_lengths=tuple( - x.shape[-nr_fft_lengths:] if fft_type != xla_client.FftType.IRFFT else + x.shape[-nr_fft_lengths:] if fft_type != lax.FftType.IRFFT else [(x.shape[-1] - 1) * 2])), arg_descriptors=[ RandArg((3, 4, 5, 6), - np.float32 if fft_type == xla_client.FftType.RFFT else np.complex64), + np.float32 if fft_type == lax.FftType.RFFT else np.complex64), StaticArg(fft_type), StaticArg(nr_fft_lengths)], # All axes but the last one are dynamic. This means that the test @@ -2545,8 +2544,8 @@ def test_vmap_error(self): polymorphic_shapes=["b0, b1, b2, ..."], tol=1e-4) - for fft_type in (xla_client.FftType.FFT, xla_client.FftType.IFFT, - xla_client.FftType.RFFT, xla_client.FftType.IRFFT) + for fft_type in (lax.FftType.FFT, lax.FftType.IFFT, + lax.FftType.RFFT, lax.FftType.IRFFT) for nr_fft_lengths in (1, 2) ], PolyHarness("full", "", @@ -2769,23 +2768,36 @@ def test_vmap_error(self): lambda x: jnp.nanquantile(x, .5, axis=0), arg_descriptors=[RandArg((3, 5), _f32)], polymorphic_shapes=["b, ..."]), + PolyHarness("inv", "", + lambda x: jnp.linalg.inv(jnp.eye(x.shape[0])), + arg_descriptors=[RandArg((3, 3), _f32)], + polymorphic_shapes=["b, b, ..."], + override_jax_config_flags={"jax_export_ignore_forward_compatibility": True}), [ - PolyHarness( + PolyHarness( # pylint: disable=g-complex-comprehension "qr", f"shape={jtu.format_shape_dtype_string(shape, dtype)}_poly={poly}_{full_matrices=}", lambda x, full_matrices: lax.linalg.qr(x, full_matrices=full_matrices), arg_descriptors=[RandArg(shape, dtype), StaticArg(full_matrices)], - polymorphic_shapes=[poly]) + polymorphic_shapes=[poly], + symbolic_constraints=constraints) for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes() - # m and n must be static for now - for shape, poly, full_matrices in [ - ((2, 0, 4), "b, ...", False), # m = 0 - ((2, 4, 0), "b, ...", False), # n = 0 - ((2, 3, 4, 4), "b1, b2, ...", False), # m == n - ((2, 3, 4, 4), "b1, b2, ...", True), - ((2, 3, 4, 5), "b1, b2, ...", False), # m < n - ((2, 3, 4, 5), "b1, b2, ...", True), - ((2, 3, 8, 4), "b1, b2, ...", False), # m > n - ((2, 3, 8, 4), "b1, b2, ...", True), + for shape, poly, full_matrices, constraints in [ + ((2, 0, 4), "b, ...", False, ()), # m = 0 + ((2, 4, 0), "b, ...", False, ()), # n = 0 + ((2, 3, 4, 4), "b1, b2, ...", False, ()), # m == n + ((2, 3, 4, 4), "b1, b2, ...", True, ()), + ((2, 3, 4, 5), "b1, b2, ...", False, ()), # m < n + ((2, 3, 4, 5), "b1, b2, ...", True, ()), + ((2, 3, 8, 4), "b1, b2, ...", False, ()), # m > n + ((2, 3, 8, 4), "b1, b2, ...", True, ()), + # Dynamic shapes are also supported for non-batch dimensions with + # some constraints. + ((2, 3, 4, 4), "b1, b2, m, m", False, ()), # m == n + ((2, 3, 4, 4), "b1, b2, m, m", True, ()), + ((2, 3, 4, 5), "b1, b2, m, n", False, ["m + 1 <= n"]), # m < n + ((2, 3, 4, 5), "b1, b2, m, n", True, ["m + 1 <= n"]), + ((2, 3, 8, 4), "b1, b2, m, n", False, ["n <= m"]), # m > n + ((2, 3, 8, 4), "b1, b2, m, n", True, ["n <= m"]), ] ], [ @@ -2828,11 +2840,7 @@ def test_vmap_error(self): "lu", f"shape={jtu.format_shape_dtype_string(shape, dtype)}_poly={poly}", lax.linalg.lu, arg_descriptors=[RandArg(shape, dtype)], - polymorphic_shapes=[poly], - # TODO(b/360788062): Remove once the forward compatibility window is - # closed. - override_jax_config_flags={ - "jax_export_ignore_forward_compatibility": True}) + polymorphic_shapes=[poly]) for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes() for shape, poly in [ ((5, 4), "m, n"), @@ -2844,6 +2852,34 @@ def test_vmap_error(self): ((2, 3, 4, 5), "b1, b2, m, n"), ] ], + [ + PolyHarness( # pylint: disable=g-complex-comprehension + "eigh", f"shape={jtu.format_shape_dtype_string(shape, dtype)}_poly={poly}_{lower=}", + lambda x, lower: lax.linalg.eigh(x, lower=lower), + arg_descriptors=[RandArg(shape, dtype), StaticArg(lower)], + polymorphic_shapes=[poly]) + for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes() + for lower in [True, False] + for shape, poly in [ + ((4, 4), "n, n"), + ((2, 3, 4, 4), "b1, b2, ..."), + ((2, 3, 4, 4), "b1, b2, n, n"), + ] + ], + [ + PolyHarness( # pylint: disable=g-complex-comprehension + "eigh_shape_error", f"shape={jtu.format_shape_dtype_string(shape, dtype)}_poly={poly}", + lambda x: lax.linalg.eigh(x, symmetrize_input=False), + arg_descriptors=[RandArg(shape, dtype)], + polymorphic_shapes=[poly], + expect_error=(ValueError, "Argument to symmetric eigendecomposition")) + for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes() + for shape, poly in [ + ((4, 5), "m, n"), + ((2, 3, 4, 5), "b1, b2, ..."), + ((2, 3, 4, 5), "b1, b2, m, n"), + ] + ], [ # The random primitive tests, with threefry (both partitionable and # non-partitionable), and unsafe_rbg. @@ -3465,9 +3501,6 @@ def test_harness(self, harness: PolyHarness): # Exclude some harnesses that are known to fail for native serialization # Set of harness.group_name:platform that are implemented with custom call custom_call_harnesses = { - "householder_product:gpu", - "vmap_geqrf:gpu", # used for linalg.qr - "vmap_qr:gpu", "qr:gpu", "vmap_svd:gpu", } name_device_key = f"{harness.group_name}:{jtu.device_under_test()}" @@ -3478,14 +3511,7 @@ def test_harness(self, harness: PolyHarness): # polymorphism for some new primitives as we add them. This check is # required so that we can still run the test suite with older versions of # jaxlib. - version_gated = { - # TODO(danfm): remove these checks when jaxlib 0.4.32 is released. - "lu_pivots_to_permutation:gpu": (0, 4, 32), - "lu_pivots_to_permutation_error:gpu": (0, 4, 32), - "lu:gpu": (0, 4, 32), - "vmap_lu:gpu": (0, 4, 32), - "vmap_custom_linear_solve:gpu": (0, 4, 32), - } + version_gated = {} if version_gated.get(name_device_key, jaxlib_version) > jaxlib_version: raise unittest.SkipTest(f"shape polymorphism not supported by jaxlib version {jaxlib_version}") @@ -3496,13 +3522,6 @@ def test_harness(self, harness: PolyHarness): if "nr_fft_lengths_2" in harness.fullname: raise unittest.SkipTest("native serialization with shape polymorphism not implemented for fft with non-constant fft_lengths on GPU and TPU") - if harness.group_name == "vmap_eigh" and jtu.test_device_matches(["gpu"]): - # For eigh on GPU with shape polymorphism under native serialization, - # we use a different lowering for small matrices. - shape = harness.original_harness.params["shape"] - if 0 < shape[-1] <= 32: - harness.check_result = False - if harness.group_name == "vmap_eigh": raise unittest.SkipTest( "Should not compare eigendecompositions for equality directly" @@ -3534,6 +3553,12 @@ def test_harness(self, harness: PolyHarness): if harness.group_name == "eig" and not jtu.test_device_matches(["cpu"]): raise unittest.SkipTest("JAX implements eig only on CPU.") + if (harness.group_name == "eigh" and + not harness.polymorphic_shapes[0].endswith("...") and + jtu.test_device_matches(["tpu"])): + raise unittest.SkipTest( + "Shape polymorphsim for Eigh is only supported for batch dimensions on TPU.") + config_flags = harness.override_jax_config_flags # Update this here rather than in harness object because vmap_random_gamma is derived # from test_harnesses.all_harnesses, which strips override_jax_config_flags. @@ -3544,12 +3569,6 @@ def test_harness(self, harness: PolyHarness): if "cholesky" in harness.group_name and jtu.test_device_matches(["tpu"]): harness.tol = 5e-5 - # TODO(b/360788062): Clean up after the compatibility period. - if harness.group_name in [ - "lu", "vmap_lu", "custom_linear_solve", "vmap_custom_linear_solve" - ] and jtu.test_device_matches(["gpu"]): - config_flags = {**config_flags, "jax_export_ignore_forward_compatibility": True} - with jtu.global_config_context(**config_flags): harness.run_test(self) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 0c1155ddf1ab..0ddbd4b50bd0 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -37,6 +37,7 @@ from jax._src import core from jax._src import prng from jax._src import test_util as jtu +from jax._src.lib.mlir.dialects import sdy from jax._src.util import safe_zip, safe_map, partition_list, merge_lists from jax._src.ad_checkpoint import saved_residuals from jax._src.mesh import AbstractMesh @@ -2636,6 +2637,8 @@ def fwd(a): @jtu.with_config(jax_use_shardy_partitioner=True) +# TODO(phawkins): enable this test unconditionally once shardy is the default. +@unittest.skipIf(sdy is None, "shardy is not enabled") class SdyIntegrationTest(jtu.JaxTestCase): # Verify we can lower to a `ManualComputationOp`. def test_shardy_collective_permute(self): diff --git a/tests/state_test.py b/tests/state_test.py index d04a674ab8c0..0d6cddfc88c8 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -739,6 +739,20 @@ def f(ref): in_avals = [shaped_array_ref((), jnp.float32)] pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), in_avals) + def test_partial_discharge(self): + def f(a_ref, b_ref): + a_ref[...] = jnp.array(0., dtype=jnp.float32) + b_ref[...] = jnp.array(1., dtype=jnp.float32) + return a_ref[...], b_ref[...] + + scalar_ref = shaped_array_ref((), jnp.float32) + jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(f), [scalar_ref, scalar_ref]) + + discharged_jaxpr, _ = discharge_state(jaxpr, (), should_discharge=[False, True]) + prim_count = lambda p, jaxpr: sum(eqn.primitive == swap_p for eqn in jaxpr.eqns) + self.assertEqual(prim_count(swap_p, jaxpr) // 2, prim_count(swap_p, discharged_jaxpr)) + self.assertEqual(prim_count(get_p, jaxpr) // 2, prim_count(get_p, discharged_jaxpr)) if CAN_USE_HYPOTHESIS: @@ -1061,6 +1075,27 @@ def false_fun(): out = jax.jit(f)(False) self.assertTupleEqual(out, (0., 5.)) + def test_cond_discharge(self): + def f0(pred, x_ref, y_ref): + def true_fun(): + x_ref[...] = 1. + def false_fun(): + y_ref[...] = 2. + lax.cond(pred, true_fun, false_fun) + return x_ref[...], y_ref[...] + ref = lambda x: AbstractRef(core.raise_to_shaped(core.get_aval(x))) + f_jaxpr = jax.make_jaxpr(f0)(False, ref(3.), ref(4.)) + jaxpr, _ = discharge_state(f_jaxpr.jaxpr, (), should_discharge=[False, False, True]) + # Effects on y_ref were discharged away but not the effects on x_ref + self.assertEqual(f_jaxpr.effects, {ReadEffect(1), WriteEffect(1), ReadEffect(2), WriteEffect(2)}) + self.assertEqual(jaxpr.effects, {ReadEffect(1), WriteEffect(1)}) + # x_ref arg is still a reference but y_ref is discharged + self.assertNotIsInstance(jaxpr.invars[2].aval, AbstractRef) + self.assertIsInstance(jaxpr.invars[1].aval, AbstractRef) + # x_ref value is returned as part of the discharged refs set. + self.assertLen(f_jaxpr.out_avals, 2) + self.assertLen(jaxpr.outvars, 3) + def test_cond_with_ref_reuse(self): def f(pred): def body(x_ref): @@ -1079,6 +1114,25 @@ def false_fun(): expected_false = 2. self.assertAllClose(out_false, expected_false) + def test_cond_readonly_refs(self): + def f(pred): + def body(refs): + x_ref, y_ref, z_ref = refs + def true_fun(): + y_ref[()] = x_ref[()] + def false_fun(): + y_ref[()] = x_ref[()] + z_ref[()] + lax.cond(pred, true_fun, false_fun) + return run_state(body)((1., 0., 2.)) + jaxpr = jax.make_jaxpr(f)(True).jaxpr + [run_state_eqn] = jaxpr.eqns + *_, cond_eqn = discharge_state(run_state_eqn.params["jaxpr"], ())[0].eqns + self.assertIs(cond_eqn.primitive, lax.cond_p) + self.assertLen(cond_eqn.invars, 4) # pred + 3x ref values + self.assertLen(cond_eqn.outvars, 1) # only the updated ref value + self.assertAllClose(jax.jit(f)(True), (1., 1., 2.)) + self.assertAllClose(jax.jit(f)(False), (1., 3., 2.)) + def test_simple_cond_using_multiple_refs_with_interleaved_consts(self): def f(pred): def body(refs): diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index c5342a99365d..a8e54537d14f 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -24,7 +24,6 @@ import jax from jax import flatten_util from jax import tree_util -from jax._src.lib import xla_extension_version from jax._src import test_util as jtu from jax._src.tree_util import flatten_one_level, prefix_errors import jax.numpy as jnp @@ -485,10 +484,8 @@ def testFlattenUpTo(self, tree, xs, expected): [([1], (2,), {"a": [1]})], re.escape("Custom node type mismatch"), ), - *( - [] - if xla_extension_version < 288 - else [(None, [2], re.escape("Expected None, got [2]."))] + ( + (None, [2], re.escape("Expected None, got [2].")) ), ) def testFlattenUpToErrors(self, tree, xs, error): diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py index 7e778cc99d2c..9c64734645d2 100644 --- a/tests/xla_bridge_test.py +++ b/tests/xla_bridge_test.py @@ -25,7 +25,6 @@ from jax._src import config from jax._src import test_util as jtu from jax._src import xla_bridge as xb -from jax._src.interpreters import xla from jax._src.lib import xla_client as xc config.parse_flags_with_absl() @@ -120,19 +119,6 @@ def test_deterministic_serialization(self): # Map order does not matter. self.assertEqual(c1str, c2.SerializeAsString()) - def test_parameter_replication_default(self): - c = xc.XlaBuilder("test") - _ = xla.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ())) - built_c = c.Build() - assert "replication" not in built_c.as_hlo_text() - - def test_parameter_replication(self): - c = xc.XlaBuilder("test") - _ = xla.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()), "", - False) - built_c = c.Build() - assert "parameter_replication={false}" in built_c.as_hlo_text() - def test_local_devices(self): self.assertNotEmpty(xb.local_devices()) with self.assertRaisesRegex(ValueError, "Unknown process_index 100"): diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index ca84599d6cff..d0acbbf2889f 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "0e732d65bdf8fb158c7b01e18139e5ba59ca7025" -XLA_SHA256 = "16e4aeca04ce94bd0fcfa32990d76be3779c026c2b649478bf27d0db0679e65c" +XLA_COMMIT = "f12e5d4d538ef7b15ec56ef31b942a6f14d19634" +XLA_SHA256 = "f82822bb427338866463641c793500c898d6797f3880458a83172d8058482c90" def repo(): tf_http_archive( From f40dec60a60ae1257acf07d71b1f771a29f2a89c Mon Sep 17 00:00:00 2001 From: Jehandad Khan Date: Fri, 11 Oct 2024 21:10:47 +0000 Subject: [PATCH 2/2] remove container --- .github/workflows/ci-build.yaml | 6 ------ 1 file changed, 6 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 03e7a040570f..14f54944a9f0 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -39,8 +39,6 @@ jobs: build: name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})" runs-on: ROCM-Ubuntu - container: - image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04 timeout-minutes: 60 strategy: matrix: @@ -58,10 +56,6 @@ jobs: num_generated_cases: 1 steps: - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 - - name: Image Setup - run: | - apt update - apt install -y libssl-dev - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: