From 098d434ccafbb95e137f1c6447536f82f9400220 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 6 Nov 2024 08:20:41 -0800 Subject: [PATCH] Set clang compiler options for the compiler that has "clang" in the path. `cc.endswith("clang")` ddidn't work for the cases when the clang compiler path is like `/usr/bin/clang-18`. This change addresses [Github issue](https://github.com/jax-ml/jax/issues/23689). PiperOrigin-RevId: 693735256 --- third_party/gpus/cuda/hermetic/cuda_configure.bzl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/third_party/gpus/cuda/hermetic/cuda_configure.bzl b/third_party/gpus/cuda/hermetic/cuda_configure.bzl index 0201bfecb..710cc27e2 100644 --- a/third_party/gpus/cuda/hermetic/cuda_configure.bzl +++ b/third_party/gpus/cuda/hermetic/cuda_configure.bzl @@ -290,12 +290,13 @@ def _setup_toolchains(repository_ctx, cc, cuda_version): }) cuda_defines["%{builtin_sysroot}"] = tf_sysroot + is_clang_compiler = "clang" in cc if not enable_cuda(repository_ctx): cuda_defines["%{cuda_toolkit_path}"] = "" cuda_defines["%{cuda_nvcc_files}"] = "[]" nvcc_relative_path = "" else: - if cc.endswith("clang"): + if is_clang_compiler: cuda_defines["%{cuda_toolkit_path}"] = repository_ctx.attr.nvcc_binary.workspace_root else: cuda_defines["%{cuda_toolkit_path}"] = "" @@ -306,7 +307,7 @@ def _setup_toolchains(repository_ctx, cc, cuda_version): repository_ctx.attr.nvcc_binary.workspace_root, repository_ctx.attr.nvcc_binary.name, ) - if cc.endswith("clang"): + if is_clang_compiler: cuda_defines["%{compiler}"] = "clang" cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "" cuda_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings( @@ -348,7 +349,7 @@ def _setup_toolchains(repository_ctx, cc, cuda_version): "%{cuda_version}": cuda_version, "%{nvcc_path}": nvcc_relative_path, "%{host_compiler_path}": str(cc), - "%{use_clang_compiler}": str(cc.endswith("clang")), + "%{use_clang_compiler}": str(is_clang_compiler), "%{tmpdir}": get_host_environ( repository_ctx, _TMPDIR, @@ -454,7 +455,7 @@ def _create_local_cuda_repository(repository_ctx): "%{cuda_is_configured}": "True", "%{cuda_extra_copts}": _compute_cuda_extra_copts( cuda_config.compute_capabilities, - cc.endswith("clang"), + "clang" in cc, ), "%{cuda_gpu_architectures}": str(cuda_config.compute_capabilities), "%{cuda_version}": cuda_config.cuda_version,