Skip to content

Commit

Permalink
Refactor JAX wheel build rules to control the wheel filename and main…
Browse files Browse the repository at this point in the history
…tain reproducible wheel content and filename results.

This change is a part of the initiative to test the JAX wheels in the presubmit properly.

The list of the changes:
1. JAX wheel build rule verifies that `--@local_config_cuda//cuda:include_cuda_libs=false` during the wheel build. There is a way to pass the restriction by providing `--@local_config_cuda//cuda:override_include_cuda_libs=true`.

2. The version suffix of the wheel in the build rule output depends on the environment variables.

3. Environment variables combinations for creating wheels with different versions:
  * `0.5.0-0` (snapshot, default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
  * `0.5.0` (release): `--repo_env=ML_WHEEL_TYPE=release`
  * `0.5.0.dev20250101` (nightly): `--repo_env=ML_WHEEL_TYPE=nightly --repo_env=ML_WHEEL_BUILD_DATE=20250101`
  * `0.5.0.dev20241122+677cd8ebf` (custom): `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=$(git show -s --format=%as HEAD) --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`
  * `0.5.0.dev20241122+677cd8ebfalpha` (custom): `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=$(git show -s --format=%as HEAD) --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD) --repo_env=ML_WHEEL_VERSION_SUFFIX=-alpha`

PiperOrigin-RevId: 699315679
  • Loading branch information
Google-ML-Automation committed Jan 21, 2025
1 parent 70a5175 commit 1e94f4d
Show file tree
Hide file tree
Showing 12 changed files with 265 additions and 94 deletions.
5 changes: 3 additions & 2 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,10 @@ build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true
# Default hermetic CUDA and CUDNN versions.
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
build:cuda --@local_config_cuda//cuda:include_cuda_libs=true

# This flag is needed to include CUDA libraries for bazel tests.
test:cuda --@local_config_cuda//cuda:include_cuda_libs=true
# This configuration is used for building the wheels.
build:cuda_wheel --@local_config_cuda//cuda:include_cuda_libs=false

# Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries,
# ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to
Expand Down
15 changes: 15 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,21 @@ xla_workspace0()
load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
flatbuffers()

load("//jaxlib:jax_python_wheel_version.bzl", "jax_python_wheel_version_repository")
jax_python_wheel_version_repository(
name = "jax_wheel_version",
file_with_version = "//jax:version.py",
version_key = "_version",
)

load(
"@tsl//third_party/py:python_wheel_version_suffix.bzl",
"python_wheel_version_suffix_repository",
)
python_wheel_version_suffix_repository(
name = "jax_wheel_version_suffix",
)

load(
"@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
"cuda_json_init_repository",
Expand Down
1 change: 1 addition & 0 deletions build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,7 @@ async def main():

if "cuda" in args.wheels:
wheel_build_command_base.append("--config=cuda")
wheel_build_command_base.append("--config=cuda_wheel")
if args.use_clang:
wheel_build_command_base.append(
f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\""
Expand Down
5 changes: 0 additions & 5 deletions jax/tools/build_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,6 @@ def platform_tag(cpu: str) -> str:
}[(platform.system(), cpu)]
return f"{platform_name}_{cpu_name}"

def get_githash(jaxlib_git_hash):
if jaxlib_git_hash != "" and os.path.isfile(jaxlib_git_hash):
with open(jaxlib_git_hash, "r") as f:
return f.readline().strip()
return jaxlib_git_hash

def build_wheel(
sources_path: str, output_path: str, package_name: str, git_hash: str = ""
Expand Down
19 changes: 15 additions & 4 deletions jax/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def _get_version_string() -> str:
# In this case we return it directly.
if _release_version is not None:
return _release_version
if os.getenv("WHEEL_BUILD_TAG"):
return _version
if os.getenv("WHEEL_VERSION_SUFFIX"):
return _version + os.getenv("WHEEL_VERSION_SUFFIX", "")
return _version_from_git_tree(_version) or _version_from_todays_date(_version)


Expand Down Expand Up @@ -71,16 +75,23 @@ def _get_version_for_build() -> str:
"""Determine the version at build time.
The returned version string depends on which environment variables are set:
- if JAX_RELEASE or JAXLIB_RELEASE are set: version looks like "0.4.16"
- if WHEEL_VERSION_SUFFIX is set: version looks like "0.5.1.dev20230906+ge58560fdc"
- if JAX_RELEASE, WHEEL_BUILD_TAG, or JAXLIB_RELEASE are set: version looks like "0.4.16"
- if JAX_NIGHTLY or JAXLIB_NIGHTLY are set: version looks like "0.4.16.dev20230906"
- if none are set: version looks like "0.4.16.dev20230906+ge58560fdc
"""
if _release_version is not None:
return _release_version
if os.environ.get('JAX_NIGHTLY') or os.environ.get('JAXLIB_NIGHTLY'):
return _version_from_todays_date(_version)
if os.environ.get('JAX_RELEASE') or os.environ.get('JAXLIB_RELEASE'):
if os.getenv("WHEEL_VERSION_SUFFIX"):
return _version + os.getenv("WHEEL_VERSION_SUFFIX", "")
if (
os.getenv("JAX_RELEASE")
or os.getenv("JAXLIB_RELEASE")
or os.getenv("WHEEL_BUILD_TAG")
):
return _version
if os.getenv("JAX_NIGHTLY") or os.getenv("JAXLIB_NIGHTLY"):
return _version_from_todays_date(_version)
return _version_from_git_tree(_version) or _version_from_todays_date(_version)


Expand Down
124 changes: 104 additions & 20 deletions jaxlib/jax.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

"""Bazel macros used by the JAX build."""

load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo")
load("@com_github_google_flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library")
load("@jax_wheel_version//:wheel_version.bzl", "WHEEL_VERSION")
load("@jax_wheel_version_suffix//:wheel_version_suffix.bzl", "BUILD_TAG", "WHEEL_VERSION_SUFFIX")
load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured")
load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library")
load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION")
Expand Down Expand Up @@ -267,7 +270,7 @@ def jax_multiplatform_test(
]
test_tags = list(tags) + ["jax_test_%s" % backend] + backend_tags.get(backend, [])
if enable_backends != None and backend not in enable_backends and not any([config.startswith(backend) for config in enable_configs]):
test_tags += ["manual"]
test_tags.append("manual")
if backend == "gpu":
test_tags += tf_cuda_tests_tags()
native.py_test(
Expand Down Expand Up @@ -308,15 +311,93 @@ def jax_generate_backend_suites(backends = []):
tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"],
)

def _get_wheel_platform_name(platform_name, cpu_name):
platform = ""
cpu = ""
if platform_name == "linux":
platform = "manylinux2014"
cpu = cpu_name
elif platform_name == "macosx":
if cpu_name == "arm64":
cpu = "arm64"
platform = "macosx_11_0"
else:
cpu = "x86_64"
platform = "macosx_10_14"
elif platform_name == "win":
platform = "win"
cpu = "amd64"
return "{platform}_{cpu}".format(
platform = platform,
cpu = cpu,
)

def _get_cpu(platform_name, platform_tag):
# Following the convention in jax/tools/build_utils.py.
if platform_name == "macosx" and platform_tag == "arm64":
return "arm64"
if platform_name == "win" and platform_tag == "x86_64":
return "AMD64"
return "aarch64" if platform_tag == "arm64" else platform_tag

def _get_full_wheel_name(rule_name, platform_name, cpu_name, major_cuda_version, wheel_version):
if "pjrt" in rule_name:
wheel_name_template = "{package_name}-{wheel_version}-py{major_python_version}-none-{wheel_platform_tag}.whl"
else:
wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}-{wheel_platform_tag}.whl"
python_version = HERMETIC_PYTHON_VERSION.replace(".", "")
package_name = rule_name.replace("_wheel", "").replace(
"cuda",
"cuda{}".format(major_cuda_version),
)
return wheel_name_template.format(
package_name = package_name,
python_version = python_version,
major_python_version = python_version[0],
wheel_version = wheel_version,
wheel_platform_tag = _get_wheel_platform_name(platform_name, cpu_name),
)

def _jax_wheel_impl(ctx):
include_cuda_libs = ctx.attr.include_cuda_libs[BuildSettingInfo].value
override_include_cuda_libs = ctx.attr.override_include_cuda_libs[BuildSettingInfo].value
git_hash = ctx.attr.git_hash[BuildSettingInfo].value
output_path = ctx.attr.output_path[BuildSettingInfo].value
executable = ctx.executable.wheel_binary

output = ctx.actions.declare_directory(ctx.label.name)
if include_cuda_libs and not override_include_cuda_libs:
fail("JAX wheel shouldn't be built with CUDA dependencies." +
" Please provide `--config=cuda_wheel` for bazel build command." +
" If you absolutely need to add CUDA dependencies, provide" +
" `--@local_config_cuda//cuda:override_include_cuda_libs=true`.")

env = {}
args = ctx.actions.args()
args.add("--output_path", output.path) # required argument
args.add("--cpu", ctx.attr.platform_tag) # required argument
jaxlib_git_hash = "" if ctx.file.git_hash == None else ctx.file.git_hash.path
args.add("--jaxlib_git_hash", jaxlib_git_hash) # required argument

full_wheel_version = (WHEEL_VERSION + WHEEL_VERSION_SUFFIX)
env["WHEEL_VERSION_SUFFIX"] = WHEEL_VERSION_SUFFIX
if BUILD_TAG:
env["WHEEL_BUILD_TAG"] = BUILD_TAG
args.add("--build-tag", BUILD_TAG)
full_wheel_version += "-{}".format(BUILD_TAG)
if not WHEEL_VERSION_SUFFIX and not BUILD_TAG:
env["JAX_RELEASE"] = "1"

cpu = _get_cpu(ctx.attr.platform_name, ctx.attr.platform_tag)
wheel_name = _get_full_wheel_name(
ctx.label.name,
ctx.attr.platform_name,
cpu,
ctx.attr.platform_version,
full_wheel_version,
)
output_file = ctx.actions.declare_file(output_path +
"/" + wheel_name)
wheel_dir = output_file.path[:output_file.path.rfind("/")]

args.add("--output_path", wheel_dir) # required argument
args.add("--cpu", cpu) # required argument
args.add("--jaxlib_git_hash", "\"{}\"".format(git_hash)) # required argument

if ctx.attr.enable_cuda:
args.add("--enable-cuda", "True")
Expand All @@ -335,11 +416,13 @@ def _jax_wheel_impl(ctx):
args.use_param_file("@%s", use_always = False)
ctx.actions.run(
arguments = [args],
inputs = [ctx.file.git_hash] if ctx.file.git_hash != None else [],
outputs = [output],
inputs = [],
outputs = [output_file],
executable = executable,
env = env,
)
return [DefaultInfo(files = depset(direct = [output]))]

return [DefaultInfo(files = depset(direct = [output_file]))]

_jax_wheel = rule(
attrs = {
Expand All @@ -350,12 +433,16 @@ _jax_wheel = rule(
cfg = "target",
),
"platform_tag": attr.string(mandatory = True),
"git_hash": attr.label(allow_single_file = True),
"platform_name": attr.string(mandatory = True),
"git_hash": attr.string(),
"output_path": attr.label(default = Label("//jaxlib/tools:output_path")),
"enable_cuda": attr.bool(default = False),
# A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string.
"platform_version": attr.string(mandatory = True, default = ""),
"skip_gpu_kernels": attr.bool(default = False),
"enable_rocm": attr.bool(default = False),
"include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")),
"override_include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:override_include_cuda_libs")),
},
implementation = _jax_wheel_impl,
executable = False,
Expand All @@ -380,19 +467,16 @@ def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""):
wheel_binary = wheel_binary,
enable_cuda = enable_cuda,
platform_version = platform_version,
# Empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=nightly` flag in bazel command to
# pass the git hash for nightly or release builds. Note that the symlink git_hash_symlink to
# the git hash file needs to be created first.
git_hash = select({
"//jaxlib/tools:jaxlib_git_hash_nightly_or_release": "git_hash_symlink",
"//conditions:default": None,
platform_name = select({
"@platforms//os:osx": "macosx",
"@platforms//os:macos": "macosx",
"@platforms//os:windows": "win",
"@platforms//os:linux": "linux",
}),
# Following the convention in jax/tools/build_utils.py.
# TODO(kanglan) Add @platforms//cpu:ppc64le once JAX Bazel is upgraded > 6.5.0.
platform_tag = select({
"//jaxlib/tools:macos_arm64": "arm64",
"//jaxlib/tools:win_amd64": "AMD64",
"//jaxlib/tools:arm64": "aarch64",
"@platforms//cpu:aarch64": "arm64",
"@platforms//cpu:arm64": "arm64",
"@platforms//cpu:x86_64": "x86_64",
}),
)
Expand Down
38 changes: 38 additions & 0 deletions jaxlib/jax_python_wheel_version.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Repository rule to generate a file with JAX python wheel version. """

def _jax_python_wheel_version_repository_impl(repository_ctx):
file_content = repository_ctx.read(
repository_ctx.path(repository_ctx.attr.file_with_version),
)
version_line_start_index = file_content.find(repository_ctx.attr.version_key)
version_line_end_index = version_line_start_index + file_content[version_line_start_index:].find("\n")
repository_ctx.file(
"wheel_version.bzl",
file_content[version_line_start_index:version_line_end_index].replace(
repository_ctx.attr.version_key,
"WHEEL_VERSION",
),
)
repository_ctx.file("BUILD", "")

jax_python_wheel_version_repository = repository_rule(
implementation = _jax_python_wheel_version_repository_impl,
attrs = {
"file_with_version": attr.label(mandatory = True, allow_single_file = True),
"version_key": attr.string(mandatory = True),
},
)
Loading

0 comments on commit 1e94f4d

Please sign in to comment.