From a2d79936df175a490f11bfa7837b8572ac361e5c Mon Sep 17 00:00:00 2001 From: Ruturaj4 Date: Wed, 7 Aug 2024 00:40:16 -0500 Subject: [PATCH] [ROCM] Fix BUILD.bazel library source paths --- jaxlib/gpu/solver.cc | 1 - jaxlib/gpu/vendor.h | 1 + jaxlib/rocm/BUILD.bazel | 13 +++++++------ 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index 8d2f23d6b09a..223c8a9798be 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -22,7 +22,6 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" #include "absl/status/statusor.h" -#include "third_party/gpus/cuda/include/cusolver_common.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/solver_handle_pool.h" #include "jaxlib/gpu/solver_kernels.h" diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index f1d4bfa86b15..96266ca93378 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -30,6 +30,7 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda_runtime_api.h" // IWYU pragma: export #include "third_party/gpus/cuda/include/cufft.h" // IWYU pragma: export #include "third_party/gpus/cuda/include/cusolverDn.h" // IWYU pragma: export +#include "third_party/gpus/cuda/include/cusolver_common.h" // IWYU pragma: export #include "third_party/gpus/cuda/include/cusparse.h" // IWYU pragma: export #include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: export diff --git a/jaxlib/rocm/BUILD.bazel b/jaxlib/rocm/BUILD.bazel index 007cfd018281..1ec36fd30c8e 100644 --- a/jaxlib/rocm/BUILD.bazel +++ b/jaxlib/rocm/BUILD.bazel @@ -100,10 +100,10 @@ cc_library( cc_library( name = "hipblas_kernels_ffi", - srcs = ["//third_party/py/jax/jaxlib/gpu:blas_kernels_ffi.cc"], - hdrs = ["//third_party/py/jax/jaxlib/gpu:blas_kernels_ffi.h"], + srcs = ["//jaxlib/gpu:blas_kernels_ffi.cc"], + hdrs = ["//jaxlib/gpu:blas_kernels_ffi.h"], deps = [ - ":hip_gpu_handle_pools", + ":hip_blas_handle_pool", ":hip_gpu_kernel_helpers", ":hip_vendor", "//jaxlib:ffi_helpers", @@ -173,8 +173,8 @@ cc_library( cc_library( name = "hipsolver_kernels_ffi", - srcs = ["//third_party/py/jax/jaxlib/gpu:solver_kernels_ffi.cc"], - hdrs = ["//third_party/py/jax/jaxlib/gpu:solver_kernels_ffi.h"], + srcs = ["//jaxlib/gpu:solver_kernels_ffi.cc"], + hdrs = ["//jaxlib/gpu:solver_kernels_ffi.h"], deps = [ ":hip_gpu_kernel_helpers", ":hip_solver_handle_pool", @@ -199,10 +199,11 @@ pybind_extension( features = ["-use_header_modules"], module_name = "_solver", deps = [ - ":hip_gpu_handle_pools", + ":hip_solver_handle_pool", ":hip_gpu_kernel_helpers", ":hip_vendor", ":hipsolver_kernels", + ":hipsolver_kernels_ffi", "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:str_format",