Skip to content

Commit

Permalink
[ROCM] Fix BUILD.bazel library source paths
Browse files Browse the repository at this point in the history
  • Loading branch information
Ruturaj4 committed Aug 7, 2024
1 parent dd958ad commit a2d7993
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
1 change: 0 additions & 1 deletion jaxlib/gpu/solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/status/statusor.h"
#include "third_party/gpus/cuda/include/cusolver_common.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/solver_handle_pool.h"
#include "jaxlib/gpu/solver_kernels.h"
Expand Down
1 change: 1 addition & 0 deletions jaxlib/gpu/vendor.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ limitations under the License.
#include "third_party/gpus/cuda/include/cuda_runtime_api.h" // IWYU pragma: export
#include "third_party/gpus/cuda/include/cufft.h" // IWYU pragma: export
#include "third_party/gpus/cuda/include/cusolverDn.h" // IWYU pragma: export
#include "third_party/gpus/cuda/include/cusolver_common.h" // IWYU pragma: export
#include "third_party/gpus/cuda/include/cusparse.h" // IWYU pragma: export
#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: export

Expand Down
13 changes: 7 additions & 6 deletions jaxlib/rocm/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ cc_library(

cc_library(
name = "hipblas_kernels_ffi",
srcs = ["//third_party/py/jax/jaxlib/gpu:blas_kernels_ffi.cc"],
hdrs = ["//third_party/py/jax/jaxlib/gpu:blas_kernels_ffi.h"],
srcs = ["//jaxlib/gpu:blas_kernels_ffi.cc"],
hdrs = ["//jaxlib/gpu:blas_kernels_ffi.h"],
deps = [
":hip_gpu_handle_pools",
":hip_blas_handle_pool",
":hip_gpu_kernel_helpers",
":hip_vendor",
"//jaxlib:ffi_helpers",
Expand Down Expand Up @@ -173,8 +173,8 @@ cc_library(

cc_library(
name = "hipsolver_kernels_ffi",
srcs = ["//third_party/py/jax/jaxlib/gpu:solver_kernels_ffi.cc"],
hdrs = ["//third_party/py/jax/jaxlib/gpu:solver_kernels_ffi.h"],
srcs = ["//jaxlib/gpu:solver_kernels_ffi.cc"],
hdrs = ["//jaxlib/gpu:solver_kernels_ffi.h"],
deps = [
":hip_gpu_kernel_helpers",
":hip_solver_handle_pool",
Expand All @@ -199,10 +199,11 @@ pybind_extension(
features = ["-use_header_modules"],
module_name = "_solver",
deps = [
":hip_gpu_handle_pools",
":hip_solver_handle_pool",
":hip_gpu_kernel_helpers",
":hip_vendor",
":hipsolver_kernels",
":hipsolver_kernels_ffi",
"//jaxlib:kernel_nanobind_helpers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings:str_format",
Expand Down

0 comments on commit a2d7993

Please sign in to comment.