From be38a8f8af9607fe40193f4f49d2758f5b392b1e Mon Sep 17 00:00:00 2001
From: Nitin Srinivasan
Date: Wed, 30 Oct 2024 10:20:15 -0700
Subject: [PATCH] Refactor build CLI to a subcommand-based approach

This commit reworks the JAX build CLI to a subcommand-based approach, where
CLI use cases are now defined as subcommands. Two subcommands are defined:
build and requirements_update. "build" is used when building a JAX wheel
package. "requirements_update" is used when updating the
requirements_lock.txt files.

The new structure offers a clear and organized CLI that enables users to
execute specific build tasks without having to navigate through a monolithic
script. Each subcommand has specific arguments that apply to its respective
build process.

In addition, arguments are separated into groups to achieve a cleaner
separation and to improve readability when the CLI subcommands are run with
`--help`. It also makes it clear which parts of the build they affect, e.g.
CUDA arguments only apply to CUDA builds, ROCm arguments only apply to ROCm
builds, etc. This reduces the complexity and the potential for errors during
the build process. Segregating functionality into distinct subcommands also
simplifies the code, which should help with maintenance and future
extensions.

There is also a transition from `subprocess.check_output` to
`asyncio.create_subprocess_shell` for executing the build commands, which
allows logs to be streamed and the build progress to be shown in real time.

Usage:

* Building `jaxlib`:
```
python build/build.py build --wheels=jaxlib --python_version=3.10
```

* Building `jax-cuda-plugin`:
```
python build/build.py build --wheels=jax-cuda-plugin --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10
```

* Building multiple packages:
```
python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10
```

* Building `jax-rocm-pjrt`:
```
python build/build.py build --wheels=jax-rocm-pjrt --rocm_version=60 --rocm_path=/path/to/rocm
```

* Using a local XLA path:
```
python build/build.py build --wheels=jaxlib --local_xla_path=/path/to/xla
```

* Updating requirements_lock.txt files:
```
python build/build.py requirements_update --python_version=3.10
```

For more details on each argument and to see available options, run:
```
python build/build.py build --help
```
or
```
python build/build.py requirements_update --help
```

PiperOrigin-RevId: 691466647
---
 .bazelrc                            |   5 +
 .github/workflows/asan.yaml         |   2 +-
 .github/workflows/wheel_win_x64.yml |   2 +-
 .github/workflows/windows_ci.yml    |   5 +-
 CHANGELOG.md                        |   5 +
 build/build.py                      | 957 +++++++++++++++-------------
 build/tools/command.py              | 111 ++++
 build/tools/utils.py                |  89 ++-
 docs/developer.md                   |  60 +-
 third_party/xla/workspace.bzl       |   2 +-
 10 files changed, 713 insertions(+), 525 deletions(-)
 create mode 100644 build/tools/command.py

diff --git a/.bazelrc b/.bazelrc
index 98bca5901d47..6ef7d4493937 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -183,6 +183,7 @@ build:macos_cache_push --config=macos_cache --remote_upload_local_results=true -
 build:ci_linux_x86_64 --config=avx_linux --config=avx_posix
 build:ci_linux_x86_64 --config=mkl_open_source_only
 build:ci_linux_x86_64 --config=clang --verbose_failures=true
+build:ci_linux_x86_64 --color=yes
 
 # TODO(b/356695103): We do not have a CPU only toolchain so we use the CUDA
 # toolchain for both CPU and GPU builds.
@@ -203,6 +204,7 @@ build:ci_linux_x86_64_cuda --config=ci_linux_x86_64 # Linux Aarch64 CI configs build:ci_linux_aarch64_base --config=clang --verbose_failures=true build:ci_linux_aarch64_base --action_env=TF_SYSROOT="/dt10" +build:ci_linux_aarch64_base --color=yes build:ci_linux_aarch64 --config=ci_linux_aarch64_base build:ci_linux_aarch64 --host_crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain" @@ -221,11 +223,13 @@ build:ci_linux_aarch64_cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm build:ci_darwin_x86_64 --macos_minimum_os=10.14 build:ci_darwin_x86_64 --config=macos_cache_push build:ci_darwin_x86_64 --verbose_failures=true +build:ci_darwin_x86_64 --color=yes # Mac Arm64 CI configs build:ci_darwin_arm64 --macos_minimum_os=11.0 build:ci_darwin_arm64 --config=macos_cache_push build:ci_darwin_arm64 --verbose_failures=true +build:ci_darwin_arm64 --color=yes # Windows x86 CI configs build:ci_windows_amd64 --config=avx_windows @@ -233,6 +237,7 @@ build:ci_windows_amd64 --compiler=clang-cl --config=clang --verbose_failures=tru build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win/20240424:toolchain" build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win/20240424:cc-toolchain-x64_windows-clang-cl" build:ci_windows_amd64 --host_linkopt=/FORCE:MULTIPLE --linkopt=/FORCE:MULTIPLE +build:ci_windows_amd64 --color=yes # ############################################################################# # RBE config options below. These inherit the CI configs above and set the diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml index ea87d4e29e40..d261ba3a09c2 100644 --- a/.github/workflows/asan.yaml +++ b/.github/workflows/asan.yaml @@ -65,7 +65,7 @@ jobs: run: | source ${GITHUB_WORKSPACE}/venv/bin/activate cd jax - python build/build.py \ + python build/build.py build --wheels=jaxlib --verbose \ --bazel_options=--color=yes \ --bazel_options=--copt=-fsanitize=address \ --clang_path=/usr/bin/clang-18 diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index 2b4a616e224a..3904bf1b8f10 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -40,7 +40,7 @@ jobs: python -m pip install -r build/test-requirements.txt python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1 "C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH - python.exe build\build.py ` + python.exe build\build.py build --wheels=jaxlib ` --bazel_options=--color=yes ` --bazel_options=--config=win_clang ` --verbose diff --git a/.github/workflows/windows_ci.yml b/.github/workflows/windows_ci.yml index 3173b81e6819..4c404ef4cb75 100644 --- a/.github/workflows/windows_ci.yml +++ b/.github/workflows/windows_ci.yml @@ -49,9 +49,10 @@ jobs: python -m pip install -r build/test-requirements.txt python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1 "C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH - python.exe build\build.py ` + python.exe build\build.py build --wheels=jaxlib ` --bazel_options=--color=yes ` - --bazel_options=--config=win_clang + --bazel_options=--config=win_clang ` + --verbose - uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 with: diff --git a/CHANGELOG.md b/CHANGELOG.md index be9aaebcd615..ce8b040439c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -55,6 +55,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. `platforms` instead. 
* Hashing of tracers, which has been deprecated since version 0.4.30, now results in a `TypeError`. + * Refactor: JAX build CLI (build/build.py) now uses a subcommand structure and + replaces previous build.py usage. Run `python build/build.py --help` for + more details. Brief overview of the new subcommand options: + * `build`: Builds JAX wheel packages. For e.g., `python build/build.py build --wheels=jaxlib,jax-cuda-pjrt` + * `requirements_update`: Updates requirements_lock.txt files. * {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel` on the function inputs. diff --git a/build/build.py b/build/build.py index 62e4217c10a2..12ad0fa3b011 100755 --- a/build/build.py +++ b/build/build.py @@ -14,94 +14,25 @@ # See the License for the specific language governing permissions and # limitations under the License. # -# Helper script for building JAX's libjax easily. +# CLI for building JAX wheel packages from source and for updating the +# requirements_lock.txt files import argparse +import asyncio import logging import os import platform -import textwrap +import sys +import copy -from tools import utils +from tools import command, utils +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) logger = logging.getLogger(__name__) -def write_bazelrc(*, remote_build, - cuda_version, cudnn_version, rocm_toolkit_path, - cpu, cuda_compute_capabilities, - rocm_amdgpu_targets, target_cpu_features, - wheel_cpu, enable_mkl_dnn, use_clang, clang_path, - clang_major_version, python_version, - enable_cuda, enable_nccl, enable_rocm, - use_cuda_nvcc): - - with open("../.jax_configure.bazelrc", "w") as f: - if not remote_build: - f.write(textwrap.dedent("""\ - build --strategy=Genrule=standalone - """)) - - if use_clang: - f.write(f'build --action_env CLANG_COMPILER_PATH="{clang_path}"\n') - f.write(f'build --repo_env CC="{clang_path}"\n') - f.write(f'build --repo_env BAZEL_COMPILER="{clang_path}"\n') - f.write('build --copt=-Wno-error=unused-command-line-argument\n') - if clang_major_version in (16, 17, 18): - # Necessary due to XLA's old version of upb. 
See: - # https://github.com/openxla/xla/blob/c4277a076e249f5b97c8e45c8cb9d1f554089d76/.bazelrc#L505 - f.write("build --copt=-Wno-gnu-offsetof-extensions\n") - - if rocm_toolkit_path: - f.write("build --action_env ROCM_PATH=\"{rocm_toolkit_path}\"\n" - .format(rocm_toolkit_path=rocm_toolkit_path)) - if rocm_amdgpu_targets: - f.write( - f'build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="{rocm_amdgpu_targets}"\n') - if cpu is not None: - f.write(f"build --cpu={cpu}\n") - - if target_cpu_features == "release": - if wheel_cpu == "x86_64": - f.write("build --config=avx_windows\n" if utils.is_windows() - else "build --config=avx_posix\n") - elif target_cpu_features == "native": - if utils.is_windows(): - print("--target_cpu_features=native is not supported on Windows; ignoring.") - else: - f.write("build --config=native_arch_posix\n") - - if enable_mkl_dnn: - f.write("build --config=mkl_open_source_only\n") - if enable_cuda: - f.write("build --config=cuda\n") - if use_cuda_nvcc: - f.write("build --config=build_cuda_with_nvcc\n") - else: - f.write("build --config=build_cuda_with_clang\n") - f.write(f"build --action_env=CLANG_CUDA_COMPILER_PATH={clang_path}\n") - if not enable_nccl: - f.write("build --config=nonccl\n") - if cuda_version: - f.write("build --repo_env HERMETIC_CUDA_VERSION=\"{cuda_version}\"\n" - .format(cuda_version=cuda_version)) - if cudnn_version: - f.write("build --repo_env HERMETIC_CUDNN_VERSION=\"{cudnn_version}\"\n" - .format(cudnn_version=cudnn_version)) - if cuda_compute_capabilities: - f.write( - f'build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}"\n') - if enable_rocm: - f.write("build --config=rocm_base\n") - if not enable_nccl: - f.write("build --config=nonccl\n") - if use_clang: - f.write("build --config=rocm\n") - f.write(f"build --action_env=CLANG_COMPILER_PATH={clang_path}\n") - if python_version: - f.write( - "build --repo_env HERMETIC_PYTHON_VERSION=\"{python_version}\"".format( - python_version=python_version)) BANNER = r""" _ _ __ __ | | / \ \ \/ / @@ -112,421 +43,559 @@ def write_bazelrc(*, remote_build, """ EPILOG = """ +From the root directory of the JAX repository, run + `python build/build.py build --wheels=` to build JAX + artifacts. -From the 'build' directory in the JAX repository, run - python build.py -or - python3 build.py -to download and build JAX's XLA (jaxlib) dependency. -""" + Multiple wheels can be built with a single invocation of the CLI. + E.g. python build/build.py build --wheels=jaxlib,jax-cuda-plugin + To update the requirements_lock.txt files, run + `python build/build.py requirements_update` +""" -def _parse_string_as_bool(s): - """Parses a string as a boolean argument.""" - lower = s.lower() - if lower == "true": - return True - elif lower == "false": - return False - else: - raise ValueError(f"Expected either 'true' or 'false'; got {s}") +# Define the build target for each wheel. 
+WHEEL_BUILD_TARGET_DICT = { + "jaxlib": "//jaxlib/tools:build_wheel", + "jax-cuda-plugin": "//jaxlib/tools:build_gpu_kernels_wheel", + "jax-cuda-pjrt": "//jaxlib/tools:build_gpu_plugin_wheel", + "jax-rocm-plugin": "//jaxlib/tools:build_gpu_kernels_wheel", + "jax-rocm-pjrt": "//jaxlib/tools:build_gpu_plugin_wheel", +} -def add_boolean_argument(parser, name, default=False, help_str=None): - """Creates a boolean flag.""" - group = parser.add_mutually_exclusive_group() - group.add_argument( - "--" + name, - nargs="?", - default=default, - const=True, - type=_parse_string_as_bool, - help=help_str) - group.add_argument("--no" + name, dest=name, action="store_false") +def add_global_arguments(parser: argparse.ArgumentParser): + """Adds all the global arguments that applies to all the CLI subcommands.""" + parser.add_argument( + "--python_version", + type=str, + choices=["3.10", "3.11", "3.12", "3.13"], + default=f"{sys.version_info.major}.{sys.version_info.minor}", + help= + """ + Hermetic Python version to use. Default is to use the version of the + Python binary that executed the CLI. + """, + ) + bazel_group = parser.add_argument_group('Bazel Options') + bazel_group.add_argument( + "--bazel_path", + type=str, + default="", + help=""" + Path to the Bazel binary to use. The default is to find bazel via the + PATH; if none is found, downloads a fresh copy of Bazel from GitHub. + """, + ) -def _get_editable_output_paths(output_path): - """Returns the paths to the editable wheels.""" - return ( - os.path.join(output_path, "jaxlib"), - os.path.join(output_path, "jax_gpu_pjrt"), - os.path.join(output_path, "jax_gpu_plugin"), + bazel_group.add_argument( + "--bazel_startup_options", + action="append", + default=[], + help=""" + Additional startup options to pass to Bazel, can be specified multiple + times to pass multiple options. + E.g. --bazel_startup_options='--nobatch' + """, ) + bazel_group.add_argument( + "--bazel_options", + action="append", + default=[], + help=""" + Additional build options to pass to Bazel, can be specified multiple + times to pass multiple options. + E.g. --bazel_options='--local_resources=HOST_CPUS' + """, + ) -def main(): - cwd = os.getcwd() - parser = argparse.ArgumentParser( - description="Builds jaxlib from source.", epilog=EPILOG) - add_boolean_argument( - parser, - "verbose", - default=False, - help_str="Should we produce verbose debugging output?") - parser.add_argument( - "--bazel_path", - help="Path to the Bazel binary to use. The default is to find bazel via " - "the PATH; if none is found, downloads a fresh copy of bazel from " - "GitHub.") - parser.add_argument( - "--python_bin_path", - help="Path to Python binary whose version to match while building with " - "hermetic python. The default is the Python interpreter used to run the " - "build script. DEPRECATED: use --python_version instead.") parser.add_argument( - "--target_cpu_features", - choices=["release", "native", "default"], - default="release", - help="What CPU features should we target? 'release' enables CPU " - "features that should be enabled for a release build, which on " - "x86-64 architectures enables AVX. 'native' enables " - "-march=native, which generates code targeted to use all " - "features of the current machine. 
'default' means don't opt-in " - "to any architectural features and use whatever the C compiler " - "generates by default.") - add_boolean_argument( - parser, - "use_clang", - default = "true", - help_str=( - "DEPRECATED: This flag is redundant because clang is " - "always used as default compiler." - ), + "--dry_run", + action="store_true", + help="Prints the Bazel command that is going to be executed.", ) + parser.add_argument( - "--clang_path", - help=( - "Path to clang binary to use. The default is " - "to find clang via the PATH." - ), - ) - add_boolean_argument( - parser, - "enable_mkl_dnn", - default=True, - help_str="Should we build with MKL-DNN enabled?", + "--verbose", + action="store_true", + help="Produce verbose output for debugging.", ) - add_boolean_argument( - parser, - "enable_cuda", - help_str="Should we build with CUDA enabled? Requires CUDA and CuDNN." + + +def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser): + """Adds all the arguments that applies to the artifact subcommands.""" + parser.add_argument( + "--wheels", + type=str, + default="jaxlib", + help= + """ + A comma separated list of JAX wheels to build. E.g: --wheels="jaxlib", + --wheels="jaxlib,jax-cuda-plugin", etc. + Valid options are: jaxlib, jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt, + jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt + """, ) - add_boolean_argument( - parser, - "use_cuda_nvcc", - default=True, - help_str=( - "Should we build CUDA code using NVCC compiler driver? The default value " - "is true. If --nouse_cuda_nvcc flag is used then CUDA code is built " - "by clang compiler." - ), + + parser.add_argument( + "--editable", + action="store_true", + help="Create an 'editable' build instead of a wheel.", ) - add_boolean_argument( - parser, - "build_gpu_plugin", - default=False, - help_str=( - "Are we building the gpu plugin in addition to jaxlib? The GPU " - "plugin is still experimental and is not ready for use yet." - ), + + parser.add_argument( + "--output_path", + type=str, + default=os.path.join(os.getcwd(), "dist"), + help="Directory to which the JAX wheel packages should be written.", ) + parser.add_argument( - "--build_gpu_kernel_plugin", - choices=["cuda", "rocm"], - default="", - help=( - "Specify 'cuda' or 'rocm' to build the respective kernel plugin." - " When this flag is set, jaxlib will not be built." - ), + "--configure_only", + action="store_true", + help=""" + If true, writes the Bazel options to the .jax_configure.bazelrc file but + does not build the artifacts. + """, ) - add_boolean_argument( - parser, - "build_gpu_pjrt_plugin", - default=False, - help_str=( - "Are we building the cuda/rocm pjrt plugin? jaxlib will not be built " - "when this flag is True." - ), + + # CUDA Options + cuda_group = parser.add_argument_group('CUDA Options') + cuda_group.add_argument( + "--cuda_version", + type=str, + help= + """ + Hermetic CUDA version to use. Default is to use the version specified + in the .bazelrc. 
+ """, ) - parser.add_argument( - "--gpu_plugin_cuda_version", - choices=["12"], + + cuda_group.add_argument( + "--cuda_major_version", + type=str, default="12", - help="Which CUDA major version the gpu plugin is for.") - parser.add_argument( - "--gpu_plugin_rocm_version", - choices=["60"], - default="60", - help="Which ROCM major version the gpu plugin is for.") - add_boolean_argument( - parser, - "enable_rocm", - help_str="Should we build with ROCm enabled?") - add_boolean_argument( - parser, - "enable_nccl", - default=True, - help_str="Should we build with NCCL enabled? Has no effect for non-CUDA " - "builds.") - add_boolean_argument( - parser, - "remote_build", - default=False, - help_str="Should we build with RBE (Remote Build Environment)?") - parser.add_argument( - "--cuda_version", - default=None, - help="CUDA toolkit version, e.g., 12.3.2") - parser.add_argument( + help= + """ + Which CUDA major version should the wheel be tagged as? Auto-detected if + --cuda_version is set. When --cuda_version is not set, the default is to + set the major version to 12 to match the default in .bazelrc. + """, + ) + + cuda_group.add_argument( "--cudnn_version", - default=None, - help="CUDNN version, e.g., 8.9.7.29") - # Caution: if changing the default list of CUDA capabilities, you should also - # update the list in .bazelrc, which is used for wheel builds. - parser.add_argument( + type=str, + help= + """ + Hermetic cuDNN version to use. Default is to use the version specified + in the .bazelrc. + """, + ) + + cuda_group.add_argument( + "--disable_nccl", + action="store_true", + help="Should NCCL be disabled?", + ) + + cuda_group.add_argument( "--cuda_compute_capabilities", + type=str, default=None, - help="A comma-separated list of CUDA compute capabilities to support.") - parser.add_argument( + help= + """ + A comma-separated list of CUDA compute capabilities to support. Default + is to use the values specified in the .bazelrc. + """, + ) + + cuda_group.add_argument( + "--build_cuda_with_clang", + action="store_true", + help=""" + Should CUDA code be compiled using Clang? The default behavior is to + compile CUDA with NVCC. + """, + ) + + # ROCm Options + rocm_group = parser.add_argument_group('ROCm Options') + rocm_group.add_argument( + "--rocm_version", + type=str, + default="60", + help="ROCm version to use", + ) + + rocm_group.add_argument( "--rocm_amdgpu_targets", + type=str, default="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100", - help="A comma-separated list of ROCm amdgpu targets to support.") - parser.add_argument( + help="A comma-separated list of ROCm amdgpu targets to support.", + ) + + rocm_group.add_argument( "--rocm_path", - default=None, - help="Path to the ROCm toolkit.") - parser.add_argument( - "--bazel_startup_options", - action="append", default=[], - help="Additional startup options to pass to bazel.") - parser.add_argument( - "--bazel_options", - action="append", default=[], - help="Additional options to pass to the main Bazel command to be " - "executed, e.g. `run`.") - parser.add_argument( - "--output_path", - default=os.path.join(cwd, "dist"), - help="Directory to which the jaxlib wheel should be written") - parser.add_argument( - "--target_cpu", - default=None, - help="CPU platform to target. Default is the same as the host machine. 
" - "Currently supported values are 'darwin_arm64' and 'darwin_x86_64'.") - parser.add_argument( - "--editable", + type=str, + default="", + help="Path to the ROCm toolkit.", + ) + + # Compile Options + compile_group = parser.add_argument_group('Compile Options') + + compile_group.add_argument( + "--use_clang", + type=utils._parse_string_as_bool, + default="true", + const=True, + nargs="?", + help=""" + Whether to use Clang as the compiler. Not recommended to set this to + False as JAX uses Clang as the default compiler. + """, + ) + + compile_group.add_argument( + "--clang_path", + type=str, + default="", + help=""" + Path to the Clang binary to use. + """, + ) + + compile_group.add_argument( + "--disable_mkl_dnn", action="store_true", - help="Create an 'editable' jaxlib build instead of a wheel.") - parser.add_argument( - "--python_version", + help=""" + Disables MKL-DNN. + """, + ) + + compile_group.add_argument( + "--target_cpu_features", + choices=["release", "native", "default"], + default="release", + help=""" + What CPU features should we target? Release enables CPU features that + should be enabled for a release build, which on x86-64 architectures + enables AVX. Native enables -march=native, which generates code targeted + to use all features of the current machine. Default means don't opt-in + to any architectural features and use whatever the C compiler generates + by default. + """, + ) + + compile_group.add_argument( + "--target_cpu", default=None, - help="hermetic python version, e.g., 3.10") - add_boolean_argument( - parser, - "configure_only", - default=False, - help_str="If true, writes a .bazelrc file but does not build jaxlib.") - add_boolean_argument( - parser, - "requirements_update", - default=False, - help_str="If true, writes a .bazelrc and updates requirements_lock.txt " - "for a corresponding version of Python but does not build " - "jaxlib.") - add_boolean_argument( - parser, - "requirements_nightly_update", - default=False, - help_str="Same as update_requirements, but will consider dev, nightly " - "and pre-release versions of packages.") + help="CPU platform to target. Default is the same as the host machine.", + ) + + compile_group.add_argument( + "--local_xla_path", + type=str, + default=os.environ.get("JAXCI_XLA_GIT_DIR", ""), + help=""" + Path to local XLA repository to use. If not set, Bazel uses the XLA at + the pinned version in workspace.bzl. + """, + ) + +async def main(): + parser = argparse.ArgumentParser( + description=r""" + CLI for building JAX wheel packages from source and for updating the + requirements_lock.txt files + """, + epilog=EPILOG, + formatter_class=argparse.RawDescriptionHelpFormatter + ) + + # Create subparsers for build and requirements_update + subparsers = parser.add_subparsers(dest="command", required=True) + + # requirements_update subcommand + requirements_update_parser = subparsers.add_parser( + "requirements_update", help="Updates the requirements_lock.txt files" + ) + requirements_update_parser.add_argument( + "--nightly_update", + action="store_true", + help=""" + If true, updates requirements_lock.txt for a corresponding version of + Python and will consider dev, nightly and pre-release versions of + packages. 
+ """, + ) + add_global_arguments(requirements_update_parser) + + # Artifact build subcommand + build_artifact_parser = subparsers.add_parser( + "build", help="Builds the jaxlib, plugin, and pjrt artifact" + ) + add_artifact_subcommand_arguments(build_artifact_parser) + add_global_arguments(build_artifact_parser) + + arch = platform.machine() + os_name = platform.system().lower() args = parser.parse_args() - logging.basicConfig() + logger.info("%s", BANNER) + if args.verbose: - logger.setLevel(logging.DEBUG) + logging.getLogger().setLevel(logging.DEBUG) + logger.info("Verbose logging enabled") + + bazel_path, bazel_version = utils.get_bazel_path(args.bazel_path) + + logging.debug("Bazel path: %s", bazel_path) + logging.debug("Bazel version: %s", bazel_version) + + executor = command.SubprocessExecutor() - if args.enable_cuda and args.enable_rocm: - parser.error("--enable_cuda and --enable_rocm cannot be enabled at the same time.") + # Start constructing the Bazel command + bazel_command_base = command.CommandBuilder(bazel_path) + + if args.bazel_startup_options: + logging.debug( + "Additional Bazel startup options: %s", args.bazel_startup_options + ) + for option in args.bazel_startup_options: + bazel_command_base.append(option) + + bazel_command_base.append("run") + + if args.python_version: + logging.debug("Hermetic Python version: %s", args.python_version) + bazel_command_base.append( + f"--repo_env=HERMETIC_PYTHON_VERSION={args.python_version}" + ) - print(BANNER) + # Enable verbose failures. + bazel_command_base.append("--verbose_failures=true") + + # Requirements update subcommand execution + if args.command == "requirements_update": + requirements_command = copy.deepcopy(bazel_command_base) + if args.bazel_options: + logging.debug( + "Using additional build options: %s", args.bazel_options + ) + for option in args.bazel_options: + requirements_command.append(option) + + if args.nightly_update: + logging.info( + "--nightly_update is set. Bazel will run" + " //build:requirements_nightly.update" + ) + requirements_command.append("//build:requirements_nightly.update") + else: + requirements_command.append("//build:requirements.update") - output_path = os.path.abspath(args.output_path) - os.chdir(os.path.dirname(__file__ or args.prog) or '.') + await executor.run(requirements_command.get_command_as_string(), args.dry_run) + sys.exit(0) - host_cpu = platform.machine() wheel_cpus = { "darwin_arm64": "arm64", "darwin_x86_64": "x86_64", "ppc": "ppc64le", "aarch64": "aarch64", } - # TODO(phawkins): support other bazel cpu overrides. - wheel_cpu = (wheel_cpus[args.target_cpu] if args.target_cpu is not None - else host_cpu) - - # Find a working Bazel. 
- bazel_path, bazel_version = utils.get_bazel_path(args.bazel_path) - print(f"Bazel binary path: {bazel_path}") - print(f"Bazel version: {bazel_version}") - - if args.python_version: - python_version = args.python_version - else: - python_bin_path = utils.get_python_bin_path(args.python_bin_path) - print(f"Python binary path: {python_bin_path}") - python_version = utils.get_python_version(python_bin_path) - print("Python version: {}".format(".".join(map(str, python_version)))) - utils.check_python_version(python_version) - python_version = ".".join(map(str, python_version)) - - print("Use clang: {}".format("yes" if args.use_clang else "no")) - clang_path = args.clang_path - clang_major_version = None - if args.use_clang: - if not clang_path: - clang_path = utils.get_clang_path_or_exit() - print(f"clang path: {clang_path}") - clang_major_version = utils.get_clang_major_version(clang_path) - - print("MKL-DNN enabled: {}".format("yes" if args.enable_mkl_dnn else "no")) - print(f"Target CPU: {wheel_cpu}") - print(f"Target CPU features: {args.target_cpu_features}") - - rocm_toolkit_path = args.rocm_path - print("CUDA enabled: {}".format("yes" if args.enable_cuda else "no")) - if args.enable_cuda: - if args.cuda_compute_capabilities is not None: - print(f"CUDA compute capabilities: {args.cuda_compute_capabilities}") - if args.cuda_version: - print(f"CUDA version: {args.cuda_version}") - if args.cudnn_version: - print(f"CUDNN version: {args.cudnn_version}") - print("NCCL enabled: {}".format("yes" if args.enable_nccl else "no")) - - print("ROCm enabled: {}".format("yes" if args.enable_rocm else "no")) - if args.enable_rocm: - if rocm_toolkit_path: - print(f"ROCm toolkit path: {rocm_toolkit_path}") - print(f"ROCm amdgpu targets: {args.rocm_amdgpu_targets}") - - write_bazelrc( - remote_build=args.remote_build, - cuda_version=args.cuda_version, - cudnn_version=args.cudnn_version, - rocm_toolkit_path=rocm_toolkit_path, - cpu=args.target_cpu, - cuda_compute_capabilities=args.cuda_compute_capabilities, - rocm_amdgpu_targets=args.rocm_amdgpu_targets, - target_cpu_features=args.target_cpu_features, - wheel_cpu=wheel_cpu, - enable_mkl_dnn=args.enable_mkl_dnn, - use_clang=args.use_clang, - clang_path=clang_path, - clang_major_version=clang_major_version, - python_version=python_version, - enable_cuda=args.enable_cuda, - enable_nccl=args.enable_nccl, - enable_rocm=args.enable_rocm, - use_cuda_nvcc=args.use_cuda_nvcc, + target_cpu = ( + wheel_cpus[args.target_cpu] if args.target_cpu is not None else arch ) - if args.requirements_update or args.requirements_nightly_update: - if args.requirements_update: - task = "//build:requirements.update" - else: # args.requirements_nightly_update - task = "//build:requirements_nightly.update" - update_command = ([bazel_path] + args.bazel_startup_options + - ["run", "--verbose_failures=true", task, *args.bazel_options]) - print(" ".join(update_command)) - utils.shell(update_command) - return - - if args.configure_only: - return - - print("\nBuilding XLA and installing it in the jaxlib source tree...") - - command_base = ( - bazel_path, - *args.bazel_startup_options, - "run", - "--verbose_failures=true", - *args.bazel_options, - ) - - if args.build_gpu_plugin and args.editable: - output_path_jaxlib, output_path_jax_pjrt, output_path_jax_kernel = ( - _get_editable_output_paths(output_path) + if args.local_xla_path: + logging.debug("Local XLA path: %s", args.local_xla_path) + bazel_command_base.append(f"--override_repository=xla=\"{args.local_xla_path}\"") + + if 
args.target_cpu: + logging.debug("Target CPU: %s", args.target_cpu) + bazel_command_base.append(f"--cpu={args.target_cpu}") + + if args.disable_nccl: + logging.debug("Disabling NCCL") + bazel_command_base.append("--config=nonccl") + + git_hash = utils.get_githash() + + # Wheel build command execution + for wheel in args.wheels.split(","): + # Allow CUDA/ROCm wheels without the "jax-" prefix. + if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel: + wheel = "jax-" + wheel + + if wheel not in WHEEL_BUILD_TARGET_DICT.keys(): + logging.error( + "Incorrect wheel name provided, valid choices are jaxlib," + " jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt," + " jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt" + ) + sys.exit(1) + + wheel_build_command = copy.deepcopy(bazel_command_base) + print("\n") + logger.info( + "Building %s for %s %s...", + wheel, + os_name, + arch, ) - else: - output_path_jaxlib = output_path - output_path_jax_pjrt = output_path - output_path_jax_kernel = output_path - - if args.build_gpu_kernel_plugin == "" and not args.build_gpu_pjrt_plugin: - build_cpu_wheel_command = [ - *command_base, - "//jaxlib/tools:build_wheel", - "--", - f"--output_path={output_path_jaxlib}", - f"--jaxlib_git_hash={utils.get_githash()}", - f"--cpu={wheel_cpu}", - ] - if args.build_gpu_plugin: - build_cpu_wheel_command.append("--skip_gpu_kernels") - if args.editable: - build_cpu_wheel_command.append("--editable") - print(" ".join(build_cpu_wheel_command)) - utils.shell(build_cpu_wheel_command) - - if args.build_gpu_plugin or (args.build_gpu_kernel_plugin == "cuda") or \ - (args.build_gpu_kernel_plugin == "rocm"): - build_gpu_kernels_command = [ - *command_base, - "//jaxlib/tools:build_gpu_kernels_wheel", - "--", - f"--output_path={output_path_jax_kernel}", - f"--jaxlib_git_hash={utils.get_githash()}", - f"--cpu={wheel_cpu}", - ] - if args.enable_cuda: - build_gpu_kernels_command.append(f"--enable-cuda={args.enable_cuda}") - build_gpu_kernels_command.append(f"--platform_version={args.gpu_plugin_cuda_version}") - elif args.enable_rocm: - build_gpu_kernels_command.append(f"--enable-rocm={args.enable_rocm}") - build_gpu_kernels_command.append(f"--platform_version={args.gpu_plugin_rocm_version}") + + clang_path = "" + if args.use_clang: + clang_path = args.clang_path or utils.get_clang_path_or_exit() + clang_major_version = utils.get_clang_major_version(clang_path) + logging.debug( + "Using Clang as the compiler, clang path: %s, clang version: %s", + clang_path, + clang_major_version, + ) + + # Use double quotes around clang path to avoid path issues on Windows. + wheel_build_command.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") + wheel_build_command.append(f"--repo_env=CC=\"{clang_path}\"") + wheel_build_command.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"") + else: + logging.debug("Use Clang: False") + + # Do not apply --config=clang on Mac as these settings do not apply to + # Apple Clang. 
+    if os_name != "darwin":
+      wheel_build_command.append("--config=clang")
+
+    if not args.disable_mkl_dnn:
+      logging.debug("Enabling MKL DNN")
+      wheel_build_command.append("--config=mkl_open_source_only")
+
+    if args.target_cpu_features == "release":
+      if arch in ["x86_64", "AMD64"]:
+        logging.debug(
+            "Using release cpu features: --config=avx_%s",
+            "windows" if os_name == "windows" else "posix",
+        )
+        wheel_build_command.append(
+            "--config=avx_windows"
+            if os_name == "windows"
+            else "--config=avx_posix"
+        )
+    elif args.target_cpu_features == "native":
+      if os_name == "windows":
+        logger.warning(
+            "--target_cpu_features=native is not supported on Windows;"
+            " ignoring."
+        )
+      else:
+        logging.debug("Using native cpu features: --config=native_arch_posix")
+        wheel_build_command.append("--config=native_arch_posix")
+    else:
+      logging.debug("Using default cpu features")
+
+    if "cuda" in wheel:
+      wheel_build_command.append("--config=cuda")
+      wheel_build_command.append(
+          f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\""
+      )
+      if args.build_cuda_with_clang:
+        logging.debug("Building CUDA with Clang")
+        wheel_build_command.append("--config=build_cuda_with_clang")
+      else:
+        logging.debug("Building CUDA with NVCC")
+        wheel_build_command.append("--config=build_cuda_with_nvcc")
+
+      if args.cuda_version:
+        logging.debug("Hermetic CUDA version: %s", args.cuda_version)
+        wheel_build_command.append(
+            f"--repo_env=HERMETIC_CUDA_VERSION={args.cuda_version}"
+        )
+      if args.cudnn_version:
+        logging.debug("Hermetic cuDNN version: %s", args.cudnn_version)
+        wheel_build_command.append(
+            f"--repo_env=HERMETIC_CUDNN_VERSION={args.cudnn_version}"
+        )
+      if args.cuda_compute_capabilities:
+        logging.debug(
+            "Hermetic CUDA compute capabilities: %s",
+            args.cuda_compute_capabilities,
+        )
+        wheel_build_command.append(
+            f"--repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES={args.cuda_compute_capabilities}"
+        )
+
+    if "rocm" in wheel:
+      wheel_build_command.append("--config=rocm_base")
+      if args.use_clang:
+        wheel_build_command.append("--config=rocm")
+        wheel_build_command.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"")
+      if args.rocm_path:
+        logging.debug("ROCm toolkit path: %s", args.rocm_path)
+        wheel_build_command.append(f"--action_env=ROCM_PATH=\"{args.rocm_path}\"")
+      if args.rocm_amdgpu_targets:
+        logging.debug("ROCm AMD GPU targets: %s", args.rocm_amdgpu_targets)
+        wheel_build_command.append(
+            f"--action_env=TF_ROCM_AMDGPU_TARGETS={args.rocm_amdgpu_targets}"
+        )
+
+    # Append additional build options at the end to override any options set in
+    # .bazelrc or above.
+ if args.bazel_options: + logging.debug( + "Additional Bazel build options: %s", args.bazel_options + ) + for option in args.bazel_options: + wheel_build_command.append(option) + + with open(".jax_configure.bazelrc", "w") as f: + jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command.get_command_as_list()) + if not jax_configure_options: + logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.") + sys.exit(1) + f.write(jax_configure_options) + logging.info("Bazel options written to .jax_configure.bazelrc") + + if args.configure_only: + logging.info("--configure_only is set so not running any Bazel commands.") else: - raise ValueError("Unsupported GPU plugin backend. Choose either 'cuda' or 'rocm'.") - if args.editable: - build_pjrt_plugin_command.append("--editable") - print(" ".join(build_pjrt_plugin_command)) - utils.shell(build_pjrt_plugin_command) + # Append the build target to the Bazel command. + build_target = WHEEL_BUILD_TARGET_DICT[wheel] + wheel_build_command.append(build_target) + + wheel_build_command.append("--") + + output_path = args.output_path + logger.debug("Artifacts output directory: %s", output_path) + + if args.editable: + logger.info("Building an editable build") + output_path = os.path.join(output_path, wheel) + wheel_build_command.append("--editable") + + wheel_build_command.append(f'--output_path="{output_path}"') + wheel_build_command.append(f"--cpu={target_cpu}") + + if "cuda" in wheel: + wheel_build_command.append("--enable-cuda=True") + if args.cuda_version: + cuda_major_version = args.cuda_version.split(".")[0] + else: + cuda_major_version = args.cuda_major_version + wheel_build_command.append(f"--platform_version={cuda_major_version}") + + if "rocm" in wheel: + wheel_build_command.append("--enable-rocm=True") + wheel_build_command.append(f"--platform_version={args.rocm_version}") + + wheel_build_command.append(f"--jaxlib_git_hash={git_hash}") - utils.shell([bazel_path] + args.bazel_startup_options + ["shutdown"]) + await executor.run(wheel_build_command.get_command_as_string(), args.dry_run) if __name__ == "__main__": - main() + asyncio.run(main()) diff --git a/build/tools/command.py b/build/tools/command.py new file mode 100644 index 000000000000..48a9bfc1c0d6 --- /dev/null +++ b/build/tools/command.py @@ -0,0 +1,111 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Helper script for the JAX build CLI for running subprocess commands. 
+import asyncio +import dataclasses +import datetime +import os +import logging +from typing import Dict, Optional + +logger = logging.getLogger(__name__) + +class CommandBuilder: + def __init__(self, base_command: str): + self.command = [base_command] + + def append(self, parameter: str): + self.command.append(parameter) + return self + + def get_command_as_string(self) -> str: + return " ".join(self.command) + + def get_command_as_list(self) -> list[str]: + return self.command + +@dataclasses.dataclass +class CommandResult: + """ + Represents the result of executing a subprocess command. + """ + + command: str + return_code: int = 2 # Defaults to not successful + logs: str = "" + start_time: datetime.datetime = dataclasses.field( + default_factory=datetime.datetime.now + ) + end_time: Optional[datetime.datetime] = None + + +async def _process_log_stream(stream, result: CommandResult): + """Logs the output of a subprocess stream.""" + while True: + line_bytes = await stream.readline() + if not line_bytes: + break + line = line_bytes.decode().rstrip() + result.logs += line + logger.info("%s", line) + + +class SubprocessExecutor: + """ + Manages execution of subprocess commands with reusable environment and logging. + """ + + def __init__(self, environment: Dict[str, str] = None): + """ + + Args: + environment: + """ + self.environment = environment or dict(os.environ) + + async def run(self, cmd: str, dry_run: bool = False) -> CommandResult: + """ + Executes a subprocess command. + + Args: + cmd: The command to execute. + dry_run: If True, prints the command instead of executing it. + + Returns: + A CommandResult instance. + """ + result = CommandResult(command=cmd) + if dry_run: + logger.info("[DRY RUN] %s", cmd) + result.return_code = 0 # Dry run is a success + return result + + logger.info("[EXECUTING] %s", cmd) + + process = await asyncio.create_subprocess_shell( + cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=self.environment, + ) + + await asyncio.gather( + _process_log_stream(process.stdout, result), _process_log_stream(process.stderr, result) + ) + + result.return_code = await process.wait() + result.end_time = datetime.datetime.now() + logger.debug("Command finished with return code %s", result.return_code) + return result diff --git a/build/tools/utils.py b/build/tools/utils.py index 4c8765371316..5d7c8e0f20b2 100644 --- a/build/tools/utils.py +++ b/build/tools/utils.py @@ -28,25 +28,6 @@ logger = logging.getLogger(__name__) -def is_windows(): - return sys.platform.startswith("win32") - -def shell(cmd): - try: - logger.info("shell(): %s", cmd) - output = subprocess.check_output(cmd) - except subprocess.CalledProcessError as e: - logger.info("subprocess raised: %s", e) - if e.output: - print(e.output) - raise - except Exception as e: - logger.info("subprocess raised: %s", e) - raise - return output.decode("UTF-8").strip() - - -# Bazel BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/6.5.0/" BazelPackage = collections.namedtuple( "BazelPackage", ["base_uri", "file", "sha256"] @@ -89,7 +70,6 @@ def shell(cmd): ), } - def download_and_verify_bazel(): """Downloads a bazel binary from GitHub, verifying its SHA256 hash.""" package = bazel_packages.get((platform.system(), platform.machine())) @@ -144,7 +124,6 @@ def progress(block_count, block_size, total_size): return os.path.join(".", package.file) - def get_bazel_paths(bazel_path_flag): """Yields a sequence of guesses about bazel path. 
@@ -155,7 +134,6 @@ def get_bazel_paths(bazel_path_flag): yield shutil.which("bazel") yield download_and_verify_bazel() - def get_bazel_path(bazel_path_flag): """Returns the path to a Bazel binary, downloading Bazel if not found. @@ -177,10 +155,14 @@ def get_bazel_path(bazel_path_flag): ) sys.exit(-1) - def get_bazel_version(bazel_path): try: - version_output = shell([bazel_path, "--version"]) + version_output = subprocess.run( + [bazel_path, "--version"], + encoding="utf-8", + capture_output=True, + check=True, + ).stdout.strip() except (subprocess.CalledProcessError, OSError): return None match = re.search(r"bazel *([0-9\\.]+)", version_output) @@ -188,7 +170,6 @@ def get_bazel_version(bazel_path): return None return tuple(int(x) for x in match.group(1).split(".")) - def get_clang_path_or_exit(): which_clang_output = shutil.which("clang") if which_clang_output: @@ -202,7 +183,6 @@ def get_clang_path_or_exit(): ) sys.exit(-1) - def get_clang_major_version(clang_path): clang_version_proc = subprocess.run( [clang_path, "-E", "-P", "-"], @@ -215,35 +195,42 @@ def get_clang_major_version(clang_path): return major_version - -# Python -def get_python_bin_path(python_bin_path_flag): - """Returns the path to the Python interpreter to use.""" - path = python_bin_path_flag or sys.executable - return path.replace(os.sep, "/") - - -def get_python_version(python_bin_path): - version_output = shell([ - python_bin_path, - "-c", - ( - 'import sys; print("{}.{}".format(sys.version_info[0], ' - "sys.version_info[1]))" - ), - ]) - major, minor = map(int, version_output.split(".")) - return major, minor - -def check_python_version(python_version): - if python_version < (3, 10): - print("ERROR: JAX requires Python 3.10 or newer, found ", python_version) - sys.exit(-1) +def get_jax_configure_bazel_options(bazel_command: list[str]): + """Returns the bazel options to be written to .jax_configure.bazelrc.""" + # Get the index of the "run" parameter. Build options will come after "run" so + # we find the index of "run" and filter everything after it. + start = bazel_command.index("run") + jax_configure_bazel_options = "" + try: + for i in range(start + 1, len(bazel_command)): + bazel_flag = bazel_command[i] + # On Windows, replace all backslashes with double backslashes to avoid + # unintended escape sequences. 
+ if platform.system() == "Windows": + bazel_flag = bazel_flag.replace("\\", "\\\\") + jax_configure_bazel_options += f"build {bazel_flag}\n" + return jax_configure_bazel_options + except ValueError: + logging.error("Unable to find index for 'run' in the Bazel command") + return "" def get_githash(): try: return subprocess.run( - ["git", "rev-parse", "HEAD"], encoding="utf-8", capture_output=True + ["git", "rev-parse", "HEAD"], + encoding="utf-8", + capture_output=True, + check=True, ).stdout.strip() except OSError: return "" + +def _parse_string_as_bool(s): + """Parses a string as a boolean value.""" + lower = s.lower() + if lower == "true": + return True + elif lower == "false": + return False + else: + raise ValueError(f"Expected either 'true' or 'false'; got {s}") diff --git a/docs/developer.md b/docs/developer.md index cbb60382b7f1..29a3cb6068ac 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -63,7 +63,7 @@ To build `jaxlib` from source, you must also install some prerequisites: To build `jaxlib` for CPU or TPU, you can run: ``` -python build/build.py +python build/build.py build --wheels=jaxlib --verbose pip install dist/*.whl # installs jaxlib (includes XLA) ``` @@ -71,7 +71,7 @@ To build a wheel for a version of Python different from your current system installation pass `--python_version` flag to the build command: ``` -python build/build.py --python_version=3.12 +python build/build.py build --wheels=jaxlib --python_version=3.12 --verbose ``` The rest of this document assumes that you are building for Python version @@ -81,13 +81,13 @@ version, simply append `--python_version=` flag every time you call installation regardless of whether the `--python_version` parameter is passed or not. -There are two ways to build `jaxlib` with CUDA support: (1) use -`python build/build.py --enable_cuda` to generate a jaxlib wheel with cuda -support, or (2) use -`python build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12` +If you would like to build `jaxlib` and the CUDA plugins: Run +``` +python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt +``` to generate three wheels (jaxlib without cuda, jax-cuda-plugin, and -jax-cuda-pjrt). By default all CUDA compilation steps performed by NVCC and -clang, but it can be restricted to clang via the `--nouse_cuda_nvcc` flag. +jax-cuda-pjrt). By default all CUDA compilation steps performed by NVCC and +clang, but it can be restricted to clang via the `--build_cuda_with_clang` flag. See `python build/build.py --help` for configuration options. Here `python` should be the name of your Python 3 interpreter; on some systems, you @@ -102,11 +102,16 @@ current directory. target dependencies. To download the specific versions of CUDA/CUDNN redistributions, you can use - the following command: + the `--cuda_version` and `--cudnn_version` flags: ```bash - python build/build.py --enable_cuda \ - --cuda_version=12.3.2 --cudnn_version=9.1.1 + python build/build.py build --wheels=jax-cuda-plugin --cuda_version=12.3.2 \ + --cudnn_version=9.1.1 + ``` + or + ```bash + python build/build.py build --wheels=jax-cuda-pjrt --cuda_version=12.3.2 \ + --cudnn_version=9.1.1 ``` Please note that these parameters are optional: by default Bazel will @@ -118,7 +123,7 @@ current directory. 
the following command: ```bash - python build/build.py --enable_cuda \ + python build/build.py build --wheels=jax-cuda-plugin \ --bazel_options=--repo_env=LOCAL_CUDA_PATH="/foo/bar/nvidia/cuda" \ --bazel_options=--repo_env=LOCAL_CUDNN_PATH="/foo/bar/nvidia/cudnn" \ --bazel_options=--repo_env=LOCAL_NCCL_PATH="/foo/bar/nvidia/nccl" @@ -141,7 +146,7 @@ ways to do this: line flag to `build.py` as follows: ``` - python build/build.py --bazel_options=--override_repository=xla=/path/to/xla + python build/build.py build --wheels=jaxlib --local_xla_path=/path/to/xla ``` - modify the `WORKSPACE` file in the root of the JAX source tree to point to @@ -183,7 +188,7 @@ path of the current session. Ensure `bazel`, `patch` and `realpath` are accessible. Activate the conda environment. ``` -python .\build\build.py +python .\build\build.py build --wheels=jaxlib ``` To build with debug information, add the flag `--bazel_options='--copt=/Z7'`. @@ -203,12 +208,14 @@ sudo apt install miopen-hip hipfft-dev rocrand-dev hipsparse-dev hipsolver-dev \ The recommended way to install these dependencies is by running our script, `jax/build/rocm/tools/get_rocm.py`, and selecting the appropriate options. -To build jaxlib with ROCM support, you can run the following build command, +To build jaxlib with ROCM support, you can run the following build commands, suitably adjusted for your paths and ROCM version. ``` -python3 ./build/build.py --use_clang=true --clang_path=/usr/lib/llvm-18/bin/clang-18 --enable_rocm --build_gpu_plugin --gpu_plugin_rocm_version=60 --rocm_path=/opt/rocm-6.2.3 +python3 ./build/build.py build --wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt --rocm_version=60 --rocm_path=/opt/rocm-6.2.3 ``` +to generate three wheels (jaxlib without rocm, jax-rocm-plugin, and +jax-rocm-pjrt) AMD's fork of the XLA repository may include fixes not present in the upstream XLA repository. If you experience problems with the upstream repository, you can @@ -221,7 +228,7 @@ git clone https://github.com/ROCm/xla.git and override the XLA repository with which JAX is built: ``` -python3 ./build/build.py --use_clang=true --clang_path=/usr/lib/llvm-18/bin/clang-18 --enable_rocm --build_gpu_plugin --gpu_plugin_rocm_version=60 --bazel_options=--override_repository=xla=/rel/xla/ --rocm_path=/opt/rocm-6.2.3 +python3 ./build/build.py build --wheels=jax-rocm-plugin --rocm_version=60 --rocm_path=/opt/rocm-6.2.3 --local_xla_path=/rel/xla/ ``` For a simplified installation process, we also recommend checking out the `jax/build/rocm/dev_build_rocm.py script`. @@ -246,7 +253,7 @@ run `build/build.py` script. 
 To choose a specific version explicitly you may pass `--python_version` argument to the tool:
 
 ```
-python build/build.py --python_version=3.12
+python build/build.py build --python_version=3.12
 ```
 
 Under the hood, the hermetic Python version is controlled
@@ -284,7 +291,7 @@ direct dependencies list and then execute the following command (which will
 call [pip-compile](https://pypi.org/project/pip-tools/) under the hood):
 
 ```
-python build/build.py --requirements_update --python_version=3.12
+python build/build.py requirements_update --python_version=3.12
 ```
 
 Alternatively, if you need more control, you may run the bazel command
@@ -328,7 +335,7 @@ For example:
 
 ```
 echo -e "\n$(realpath jaxlib-0.4.27.dev20240416-cp312-cp312-manylinux2014_x86_64.whl)" >> build/requirements.in
-python build/build.py --requirements_update --python_version=3.12
+python build/build.py requirements_update --python_version=3.12
 ```
 
 ### Specifying dependencies on nightly wheels
@@ -338,7 +345,7 @@ dependencies we provide a special version of the dependency updater command
 as follows:
 
 ```
-python build/build.py --requirements_nightly_update --python_version=3.12
+python build/build.py requirements_update --python_version=3.12 --nightly_update
 ```
 
 Or, if you run `bazel` directly (the two commands are equivalent):
@@ -469,10 +476,13 @@ or using pytest.
 
 ### Using Bazel
 
-First, configure the JAX build by running:
+First, configure the JAX build by using the `--configure_only` flag. Pass
+`--wheels=jaxlib` for CPU tests, or the CUDA/ROCm plugin wheels for GPU tests:
 
 ```
-python build/build.py --configure_only
+python build/build.py build --wheels=jaxlib --configure_only
+python build/build.py build --wheels=jax-cuda-plugin --configure_only
+python build/build.py build --wheels=jax-rocm-plugin --configure_only
 ```
 
 You may pass additional options to `build.py` to configure the build; see the
@@ -494,14 +504,14 @@ make it available in the hermetic Python. To install a specific version of
 
 ```
 echo -e "\njaxlib >= 0.4.26" >> build/requirements.in
-python build/build.py --requirements_update
+python build/build.py requirements_update
 ```
 
 Alternatively, to install `jaxlib` from a local wheel (assuming Python 3.12):
 
 ```
 echo -e "\n$(realpath jaxlib-0.4.26-cp312-cp312-manylinux2014_x86_64.whl)" >> build/requirements.in
-python build/build.py --requirements_update --python_version=3.12
+python build/build.py requirements_update --python_version=3.12
 ```
 
 Once you have `jaxlib` installed hermetically, run:
diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index 32508f8310bb..135d02aecd8a 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -37,7 +37,7 @@ def repo():
 # local checkout by either:
 # a) overriding the TF repository on the build.py command line by passing a flag
 # like:
-# python build/build.py --bazel_options=--override_repository=xla=/path/to/xla
+# python build/build.py build --local_xla_path=/path/to/xla
 # or
 # b) by commenting out the http_archive above and uncommenting the following:
 # local_repository(