Skip to content

Commit

Permalink
Re-factor build CLI to a subcommand based approach
Browse files Browse the repository at this point in the history
This commit reworks the JAX build CLI to a subcommand based approach where CLI use cases are now defined as subcommands. Two subcommands are defined: build and requirements_update. "build" is to be used when wanting to build a JAX wheel package. "requirements_update" is to be used when wanting to update the requirements_lock.txt files. The new structure offers a clear and organized CLI that enables users to execute specific build tasks without having to navigate through a monolithic script.

Each subcommand has specific arguments that apply to its respective build process. In addition, arguments are separated into groups to achieve a cleaner separation and improves the readability when the CLI subcommands are run with `--help`. It also makes it clear as to which parts of the build they affect. E.g: CUDA arguments only apply to CUDA builds, ROCM arguments only apply to ROCM builds, etc. This reduces the complexity and the potential for errors during the build process. Segregating functionalities into distinct subcommands also simplifies the code which should help with the maintenance and future extensions.

There is also a transition from using `subprocess.check_output` to `asyncio.create_subprocess_shell` for executing the build commands which allows for streaming logs and helps in showing the build progress in real time.

Usage:
* Building `jaxlib`:
```
python build/build.py build --wheels=jaxlib --python_version=3.10
```
* Building `jax-cuda-plugin`:
```
python build/build.py build --wheels=jax-cuda-plugin --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10
```
* Building multiple packages:
```
python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10
```
* Building `jax-rocm-pjrt`:
```
python build/build.py build --wheels=jax-rocm-pjrt --rocm_version=60 --rocm_path=/path/to/rocm
```
* Using a local XLA path:
```
python build/build.py build --wheels=jaxlib --local_xla_path=/path/to/xla
```
* Updating requirements_lock.txt files:
```
python build/build.py requirements_update --python_version=3.10
```

For more details on each argument and to see available options, run:
```
python build/build.py build --help
```
or
```
python build/build.py requirements_update --help
```

PiperOrigin-RevId: 700075411
  • Loading branch information
nitins17 authored and Google-ML-Automation committed Nov 25, 2024
1 parent f22bafa commit 6761512
Show file tree
Hide file tree
Showing 10 changed files with 713 additions and 525 deletions.
5 changes: 5 additions & 0 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ build:macos_cache_push --config=macos_cache --remote_upload_local_results=true -
build:ci_linux_x86_64 --config=avx_linux --config=avx_posix
build:ci_linux_x86_64 --config=mkl_open_source_only
build:ci_linux_x86_64 --config=clang --verbose_failures=true
build:ci_linux_x86_64 --color=yes

# TODO(b/356695103): We do not have a CPU only toolchain so we use the CUDA
# toolchain for both CPU and GPU builds.
Expand All @@ -203,6 +204,7 @@ build:ci_linux_x86_64_cuda --config=ci_linux_x86_64
# Linux Aarch64 CI configs
build:ci_linux_aarch64_base --config=clang --verbose_failures=true
build:ci_linux_aarch64_base --action_env=TF_SYSROOT="/dt10"
build:ci_linux_aarch64_base --color=yes

build:ci_linux_aarch64 --config=ci_linux_aarch64_base
build:ci_linux_aarch64 --host_crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain"
Expand All @@ -221,18 +223,21 @@ build:ci_linux_aarch64_cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm
build:ci_darwin_x86_64 --macos_minimum_os=10.14
build:ci_darwin_x86_64 --config=macos_cache_push
build:ci_darwin_x86_64 --verbose_failures=true
build:ci_darwin_x86_64 --color=yes

# Mac Arm64 CI configs
build:ci_darwin_arm64 --macos_minimum_os=11.0
build:ci_darwin_arm64 --config=macos_cache_push
build:ci_darwin_arm64 --verbose_failures=true
build:ci_darwin_arm64 --color=yes

# Windows x86 CI configs
build:ci_windows_amd64 --config=avx_windows
build:ci_windows_amd64 --compiler=clang-cl --config=clang --verbose_failures=true
build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win/20240424:toolchain"
build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win/20240424:cc-toolchain-x64_windows-clang-cl"
build:ci_windows_amd64 --host_linkopt=/FORCE:MULTIPLE --linkopt=/FORCE:MULTIPLE
build:ci_windows_amd64 --color=yes

# #############################################################################
# RBE config options below. These inherit the CI configs above and set the
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/asan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ jobs:
run: |
source ${GITHUB_WORKSPACE}/venv/bin/activate
cd jax
python build/build.py \
python build/build.py build --wheels=jaxlib --verbose \
--bazel_options=--color=yes \
--bazel_options=--copt=-fsanitize=address \
--clang_path=/usr/bin/clang-18
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/wheel_win_x64.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ jobs:
python -m pip install -r build/test-requirements.txt
python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1
"C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH
python.exe build\build.py `
python.exe build\build.py build --wheels=jaxlib `
--bazel_options=--color=yes `
--bazel_options=--config=win_clang `
--verbose
Expand Down
5 changes: 3 additions & 2 deletions .github/workflows/windows_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ jobs:
python -m pip install -r build/test-requirements.txt
python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1
"C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH
python.exe build\build.py `
python.exe build\build.py build --wheels=jaxlib `
--bazel_options=--color=yes `
--bazel_options=--config=win_clang
--bazel_options=--config=win_clang `
--verbose
- uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
with:
Expand Down
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
`platforms` instead.
* Hashing of tracers, which has been deprecated since version 0.4.30, now
results in a `TypeError`.
* Refactor: JAX build CLI (build/build.py) now uses a subcommand structure and
replaces previous build.py usage. Run `python build/build.py --help` for
more details. Brief overview of the new subcommand options:
* `build`: Builds JAX wheel packages. For e.g., `python build/build.py build --wheels=jaxlib,jax-cuda-pjrt`
* `requirements_update`: Updates requirements_lock.txt files.
* {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional
inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel`
on the function inputs.
Expand Down
Loading

0 comments on commit 6761512

Please sign in to comment.