diff --git a/.bazelrc b/.bazelrc index 000619d838f1..e81ea68b0052 100644 --- a/.bazelrc +++ b/.bazelrc @@ -125,9 +125,10 @@ build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true # Default hermetic CUDA and CUDNN versions. build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" +build:cuda --@local_config_cuda//cuda:include_cuda_libs=true -# This flag is needed to include CUDA libraries for bazel tests. -test:cuda --@local_config_cuda//cuda:include_cuda_libs=true +# This configuration is used for building the wheels. +build:cuda_wheel --@local_config_cuda//cuda:include_cuda_libs=false # Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries, # ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to diff --git a/WORKSPACE b/WORKSPACE index 130c9f804c93..b937b5d6f881 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -62,6 +62,21 @@ xla_workspace0() load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") flatbuffers() +load("//jaxlib:jax_python_wheel_version.bzl", "jax_python_wheel_version_repository") +jax_python_wheel_version_repository( + name = "jax_wheel_version", + file_with_version = "//jax:version.py", + version_key = "_version", +) + +load( + "@tsl//third_party/py:python_wheel_version_suffix.bzl", + "python_wheel_version_suffix_repository", +) +python_wheel_version_suffix_repository( + name = "jax_wheel_version_suffix", +) + load( "@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", "cuda_json_init_repository", diff --git a/build/build.py b/build/build.py index d2f68f80efc6..6da015e6485a 100755 --- a/build/build.py +++ b/build/build.py @@ -532,6 +532,7 @@ async def main(): if "cuda" in args.wheels: wheel_build_command_base.append("--config=cuda") + wheel_build_command_base.append("--config=cuda_wheel") if args.use_clang: wheel_build_command_base.append( f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\"" diff --git a/jax/tools/build_utils.py b/jax/tools/build_utils.py index 83d0b4b25923..84cc697d1894 100644 --- a/jax/tools/build_utils.py +++ b/jax/tools/build_utils.py @@ -62,11 +62,6 @@ def platform_tag(cpu: str) -> str: }[(platform.system(), cpu)] return f"{platform_name}_{cpu_name}" -def get_githash(jaxlib_git_hash): - if jaxlib_git_hash != "" and os.path.isfile(jaxlib_git_hash): - with open(jaxlib_git_hash, "r") as f: - return f.readline().strip() - return jaxlib_git_hash def build_wheel( sources_path: str, output_path: str, package_name: str, git_hash: str = "" diff --git a/jax/version.py b/jax/version.py index 484cd96acf41..8e1b5e81e1ff 100644 --- a/jax/version.py +++ b/jax/version.py @@ -35,6 +35,10 @@ def _get_version_string() -> str: # In this case we return it directly. if _release_version is not None: return _release_version + if os.getenv("WHEEL_BUILD_TAG"): + return _version + if os.getenv("WHEEL_VERSION_SUFFIX"): + return _version + os.getenv("WHEEL_VERSION_SUFFIX", "") return _version_from_git_tree(_version) or _version_from_todays_date(_version) @@ -71,16 +75,23 @@ def _get_version_for_build() -> str: """Determine the version at build time. The returned version string depends on which environment variables are set: - - if JAX_RELEASE or JAXLIB_RELEASE are set: version looks like "0.4.16" + - if WHEEL_VERSION_SUFFIX is set: version looks like "0.5.1.dev20230906+ge58560fdc" + - if JAX_RELEASE, WHEEL_BUILD_TAG, or JAXLIB_RELEASE are set: version looks like "0.4.16" - if JAX_NIGHTLY or JAXLIB_NIGHTLY are set: version looks like "0.4.16.dev20230906" - if none are set: version looks like "0.4.16.dev20230906+ge58560fdc """ if _release_version is not None: return _release_version - if os.environ.get('JAX_NIGHTLY') or os.environ.get('JAXLIB_NIGHTLY'): - return _version_from_todays_date(_version) - if os.environ.get('JAX_RELEASE') or os.environ.get('JAXLIB_RELEASE'): + if os.getenv("WHEEL_VERSION_SUFFIX"): + return _version + os.getenv("WHEEL_VERSION_SUFFIX", "") + if ( + os.getenv("JAX_RELEASE") + or os.getenv("JAXLIB_RELEASE") + or os.getenv("WHEEL_BUILD_TAG") + ): return _version + if os.getenv("JAX_NIGHTLY") or os.getenv("JAXLIB_NIGHTLY"): + return _version_from_todays_date(_version) return _version_from_git_tree(_version) or _version_from_todays_date(_version) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 49062c7283fd..b60430aaff0b 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -14,7 +14,10 @@ """Bazel macros used by the JAX build.""" +load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo") load("@com_github_google_flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library") +load("@jax_wheel_version//:wheel_version.bzl", "WHEEL_VERSION") +load("@jax_wheel_version_suffix//:wheel_version_suffix.bzl", "BUILD_TAG", "WHEEL_VERSION_SUFFIX") load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured") load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library") load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION") @@ -267,7 +270,7 @@ def jax_multiplatform_test( ] test_tags = list(tags) + ["jax_test_%s" % backend] + backend_tags.get(backend, []) if enable_backends != None and backend not in enable_backends and not any([config.startswith(backend) for config in enable_configs]): - test_tags += ["manual"] + test_tags.append("manual") if backend == "gpu": test_tags += tf_cuda_tests_tags() native.py_test( @@ -308,15 +311,93 @@ def jax_generate_backend_suites(backends = []): tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"], ) +def _get_wheel_platform_name(platform_name, cpu_name): + platform = "" + cpu = "" + if platform_name == "linux": + platform = "manylinux2014" + cpu = cpu_name + elif platform_name == "macosx": + if cpu_name == "arm64": + cpu = "arm64" + platform = "macosx_11_0" + else: + cpu = "x86_64" + platform = "macosx_10_14" + elif platform_name == "win": + platform = "win" + cpu = "amd64" + return "{platform}_{cpu}".format( + platform = platform, + cpu = cpu, + ) + +def _get_cpu(platform_name, platform_tag): + # Following the convention in jax/tools/build_utils.py. + if platform_name == "macosx" and platform_tag == "arm64": + return "arm64" + if platform_name == "win" and platform_tag == "x86_64": + return "AMD64" + return "aarch64" if platform_tag == "arm64" else platform_tag + +def _get_full_wheel_name(rule_name, platform_name, cpu_name, major_cuda_version, wheel_version): + if "pjrt" in rule_name: + wheel_name_template = "{package_name}-{wheel_version}-py{major_python_version}-none-{wheel_platform_tag}.whl" + else: + wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}-{wheel_platform_tag}.whl" + python_version = HERMETIC_PYTHON_VERSION.replace(".", "") + package_name = rule_name.replace("_wheel", "").replace( + "cuda", + "cuda{}".format(major_cuda_version), + ) + return wheel_name_template.format( + package_name = package_name, + python_version = python_version, + major_python_version = python_version[0], + wheel_version = wheel_version, + wheel_platform_tag = _get_wheel_platform_name(platform_name, cpu_name), + ) + def _jax_wheel_impl(ctx): + include_cuda_libs = ctx.attr.include_cuda_libs[BuildSettingInfo].value + override_include_cuda_libs = ctx.attr.override_include_cuda_libs[BuildSettingInfo].value + git_hash = ctx.attr.git_hash[BuildSettingInfo].value + output_path = ctx.attr.output_path[BuildSettingInfo].value executable = ctx.executable.wheel_binary - output = ctx.actions.declare_directory(ctx.label.name) + if include_cuda_libs and not override_include_cuda_libs: + fail("JAX wheel shouldn't be built with CUDA dependencies." + + " Please provide `--config=cuda_wheel` for bazel build command." + + " If you absolutely need to add CUDA dependencies, provide" + + " `--@local_config_cuda//cuda:override_include_cuda_libs=true`.") + + env = {} args = ctx.actions.args() - args.add("--output_path", output.path) # required argument - args.add("--cpu", ctx.attr.platform_tag) # required argument - jaxlib_git_hash = "" if ctx.file.git_hash == None else ctx.file.git_hash.path - args.add("--jaxlib_git_hash", jaxlib_git_hash) # required argument + + full_wheel_version = (WHEEL_VERSION + WHEEL_VERSION_SUFFIX) + env["WHEEL_VERSION_SUFFIX"] = WHEEL_VERSION_SUFFIX + if BUILD_TAG: + env["WHEEL_BUILD_TAG"] = BUILD_TAG + args.add("--build-tag", BUILD_TAG) + full_wheel_version += "-{}".format(BUILD_TAG) + if not WHEEL_VERSION_SUFFIX and not BUILD_TAG: + env["JAX_RELEASE"] = "1" + + cpu = _get_cpu(ctx.attr.platform_name, ctx.attr.platform_tag) + wheel_name = _get_full_wheel_name( + ctx.label.name, + ctx.attr.platform_name, + cpu, + ctx.attr.platform_version, + full_wheel_version, + ) + output_file = ctx.actions.declare_file(output_path + + "/" + wheel_name) + wheel_dir = output_file.path[:output_file.path.rfind("/")] + + args.add("--output_path", wheel_dir) # required argument + args.add("--cpu", cpu) # required argument + args.add("--jaxlib_git_hash", "\"{}\"".format(git_hash)) # required argument if ctx.attr.enable_cuda: args.add("--enable-cuda", "True") @@ -335,11 +416,13 @@ def _jax_wheel_impl(ctx): args.use_param_file("@%s", use_always = False) ctx.actions.run( arguments = [args], - inputs = [ctx.file.git_hash] if ctx.file.git_hash != None else [], - outputs = [output], + inputs = [], + outputs = [output_file], executable = executable, + env = env, ) - return [DefaultInfo(files = depset(direct = [output]))] + + return [DefaultInfo(files = depset(direct = [output_file]))] _jax_wheel = rule( attrs = { @@ -350,12 +433,16 @@ _jax_wheel = rule( cfg = "target", ), "platform_tag": attr.string(mandatory = True), - "git_hash": attr.label(allow_single_file = True), + "platform_name": attr.string(mandatory = True), + "git_hash": attr.string(), + "output_path": attr.label(default = Label("//jaxlib/tools:output_path")), "enable_cuda": attr.bool(default = False), # A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string. "platform_version": attr.string(mandatory = True, default = ""), "skip_gpu_kernels": attr.bool(default = False), "enable_rocm": attr.bool(default = False), + "include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")), + "override_include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:override_include_cuda_libs")), }, implementation = _jax_wheel_impl, executable = False, @@ -380,19 +467,16 @@ def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""): wheel_binary = wheel_binary, enable_cuda = enable_cuda, platform_version = platform_version, - # Empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=nightly` flag in bazel command to - # pass the git hash for nightly or release builds. Note that the symlink git_hash_symlink to - # the git hash file needs to be created first. - git_hash = select({ - "//jaxlib/tools:jaxlib_git_hash_nightly_or_release": "git_hash_symlink", - "//conditions:default": None, + platform_name = select({ + "@platforms//os:osx": "macosx", + "@platforms//os:macos": "macosx", + "@platforms//os:windows": "win", + "@platforms//os:linux": "linux", }), - # Following the convention in jax/tools/build_utils.py. # TODO(kanglan) Add @platforms//cpu:ppc64le once JAX Bazel is upgraded > 6.5.0. platform_tag = select({ - "//jaxlib/tools:macos_arm64": "arm64", - "//jaxlib/tools:win_amd64": "AMD64", - "//jaxlib/tools:arm64": "aarch64", + "@platforms//cpu:aarch64": "arm64", + "@platforms//cpu:arm64": "arm64", "@platforms//cpu:x86_64": "x86_64", }), ) diff --git a/jaxlib/jax_python_wheel_version.bzl b/jaxlib/jax_python_wheel_version.bzl new file mode 100644 index 000000000000..9eba463cbf93 --- /dev/null +++ b/jaxlib/jax_python_wheel_version.bzl @@ -0,0 +1,38 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Repository rule to generate a file with JAX python wheel version. """ + +def _jax_python_wheel_version_repository_impl(repository_ctx): + file_content = repository_ctx.read( + repository_ctx.path(repository_ctx.attr.file_with_version), + ) + version_line_start_index = file_content.find(repository_ctx.attr.version_key) + version_line_end_index = version_line_start_index + file_content[version_line_start_index:].find("\n") + repository_ctx.file( + "wheel_version.bzl", + file_content[version_line_start_index:version_line_end_index].replace( + repository_ctx.attr.version_key, + "WHEEL_VERSION", + ), + ) + repository_ctx.file("BUILD", "") + +jax_python_wheel_version_repository = repository_rule( + implementation = _jax_python_wheel_version_repository_impl, + attrs = { + "file_with_version": attr.label(mandatory = True, allow_single_file = True), + "version_key": attr.string(mandatory = True), + }, +) diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 63f2643fe230..9529188fe339 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -14,10 +14,13 @@ # JAX is Autograd and XLA -load("@bazel_skylib//lib:selects.bzl", "selects") load("@bazel_skylib//rules:common_settings.bzl", "string_flag") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") +load( + "@tsl//third_party/py:py_manylinux_compliance_test.bzl", + "verify_manylinux_compliance_test", +) load("//jaxlib:jax.bzl", "if_windows", "jax_py_test", "jax_wheel") licenses(["notice"]) # Apache 2 @@ -136,48 +139,9 @@ py_binary( ], ) -selects.config_setting_group( - name = "macos", - match_any = [ - "@platforms//os:osx", - "@platforms//os:macos", - ], -) - -selects.config_setting_group( - name = "arm64", - match_any = [ - "@platforms//cpu:aarch64", - "@platforms//cpu:arm64", - ], -) - -selects.config_setting_group( - name = "macos_arm64", - match_all = [ - ":arm64", - ":macos", - ], -) - -selects.config_setting_group( - name = "win_amd64", - match_all = [ - "@platforms//cpu:x86_64", - "@platforms//os:windows", - ], -) - string_flag( - name = "jaxlib_git_hash", - build_setting_default = "", -) - -config_setting( - name = "jaxlib_git_hash_nightly_or_release", - flag_values = { - ":jaxlib_git_hash": "nightly", - }, + name = "output_path", + build_setting_default = "dist", ) jax_wheel( @@ -200,3 +164,36 @@ jax_wheel( platform_version = "12", wheel_binary = ":build_gpu_plugin_wheel", ) + +verify_manylinux_compliance_test( + name = "jaxlib_manylinux_compliance_test", + aarch64_compliance_tag = "manylinux_2_17_aarch64", + test_tags = [ + "mac_excluded", + "windows_excluded", + ], + wheel = ":jaxlib_wheel", + x86_64_compliance_tag = "manylinux_2_17_x86_64", +) + +verify_manylinux_compliance_test( + name = "jax_cuda_plugin_manylinux_compliance_test", + aarch64_compliance_tag = "manylinux_2_17_aarch64", + test_tags = [ + "mac_excluded", + "windows_excluded", + ], + wheel = ":jax_cuda_plugin_wheel", + x86_64_compliance_tag = "manylinux_2_17_x86_64", +) + +verify_manylinux_compliance_test( + name = "jax_cuda_pjrt_manylinux_compliance_test", + aarch64_compliance_tag = "manylinux_2_17_aarch64", + test_tags = [ + "mac_excluded", + "windows_excluded", + ], + wheel = ":jax_cuda_pjrt_wheel", + x86_64_compliance_tag = "manylinux_2_17_x86_64", +) diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 65412f0365dc..2016ccdf2342 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -61,6 +61,11 @@ "--enable-rocm", default=False, help="Should we build with ROCM enabled?") +parser.add_argument( + "--build-tag", + default=None, + required=False, + help="Wheel build tag. Optional.") args = parser.parse_args() r = runfiles.Create() @@ -68,14 +73,14 @@ def write_setup_cfg(sources_path, cpu): - tag = build_utils.platform_tag(cpu) + plat_tag = build_utils.platform_tag(cpu) with open(sources_path / "setup.cfg", "w") as f: f.write(f"""[metadata] license_files = LICENSE.txt [bdist_wheel] -plat_name={tag} -""") +plat_name={plat_tag} +""" + (f"build_number={args.build_tag}\n" if args.build_tag else "")) def prepare_wheel_cuda( @@ -174,12 +179,11 @@ def prepare_wheel_rocm( if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: - git_hash = build_utils.get_githash(args.jaxlib_git_hash) build_utils.build_wheel( sources_path, args.output_path, package_name, - git_hash=git_hash, + git_hash=args.jaxlib_git_hash, ) finally: tmpdir.cleanup() diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index 08c2389c292a..3aafaafc3b5d 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -67,23 +67,26 @@ "--enable-rocm", default=False, help="Should we build with ROCM enabled?") +parser.add_argument( + "--build-tag", + default=None, + required=False, + help="Wheel build tag. Optional.") args = parser.parse_args() r = runfiles.Create() def write_setup_cfg(sources_path, cpu): - tag = build_utils.platform_tag(cpu) + plat_tag = build_utils.platform_tag(cpu) with open(sources_path / "setup.cfg", "w") as f: - f.write( - f"""[metadata] + f.write(f"""[metadata] license_files = LICENSE.txt [bdist_wheel] -plat_name={tag} +plat_name={plat_tag} python-tag=py3 -""" - ) +""" + (f"build_number={args.build_tag}\n" if args.build_tag else "")) def prepare_cuda_plugin_wheel(sources_path: pathlib.Path, *, cpu, cuda_version): @@ -167,12 +170,11 @@ def prepare_rocm_plugin_wheel(sources_path: pathlib.Path, *, cpu, rocm_version): if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: - git_hash = build_utils.get_githash(args.jaxlib_git_hash) build_utils.build_wheel( sources_path, args.output_path, package_name, - git_hash=git_hash, + git_hash=args.jaxlib_git_hash, ) finally: if tmpdir: diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 4b71bd5de2d8..fc11e2060f5c 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -56,6 +56,11 @@ action="store_true", help="Create an 'editable' jaxlib build instead of a wheel.", ) +parser.add_argument( + "--build-tag", + default=None, + required=False, + help="Wheel build tag. Optional.") args = parser.parse_args() r = runfiles.Create() @@ -150,16 +155,14 @@ def verify_mac_libraries_dont_reference_chkstack(): def write_setup_cfg(sources_path, cpu): - tag = build_utils.platform_tag(cpu) + plat_tag = build_utils.platform_tag(cpu) with open(sources_path / "setup.cfg", "w") as f: - f.write( - f"""[metadata] + f.write(f"""[metadata] license_files = LICENSE.txt [bdist_wheel] -plat_name={tag} -""" - ) +plat_name={plat_tag} +""" + (f"build_number={args.build_tag}\n" if args.build_tag else "")) def prepare_wheel(sources_path: pathlib.Path, *, cpu): @@ -387,8 +390,12 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu): if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: - git_hash = build_utils.get_githash(args.jaxlib_git_hash) - build_utils.build_wheel(sources_path, args.output_path, package_name, git_hash=git_hash) + build_utils.build_wheel( + sources_path, + args.output_path, + package_name, + git_hash=args.jaxlib_git_hash, + ) finally: if tmpdir: tmpdir.cleanup() diff --git a/tests/version_test.py b/tests/version_test.py index 51297a9716b1..034e49787ca4 100644 --- a/tests/version_test.py +++ b/tests/version_test.py @@ -157,6 +157,22 @@ def testBuildVersionFromEnvironment(self): self.assertTrue(version.endswith("test")) self.assertValidVersion(version) + with jtu.set_env(JAX_RELEASE=None, JAXLIB_RELEASE="1", + JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None, + WHEEL_BUILD_TAG ="0"): + with assert_no_subprocess_call(): + version = jax.version._get_version_for_build() + self.assertEqual(version, base_version) + self.assertValidVersion(version) + + with jtu.set_env(JAX_RELEASE=None, JAXLIB_RELEASE=None, + JAX_NIGHTLY=None, JAXLIB_NIGHTLY="1", + WHEEL_VERSION_SUFFIX=".dev20250101+1c0f1076erc1"): + with assert_no_subprocess_call(): + version = jax.version._get_version_for_build() + self.assertEqual(version, f"{base_version}.dev20250101+1c0f1076erc1") + self.assertValidVersion(version) + def testVersions(self): check_jaxlib_version(jax_version="1.2.3", jaxlib_version="1.2.3", minimum_jaxlib_version="1.2.3")