diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index db02eb8bbff1..1c45a4ce9463 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -257,6 +257,7 @@ py_extension( "@llvm-project//mlir:MLIRBindingsPythonHeaders", "@local_config_python//:headers", "@pybind11", + "@shardy//shardy/integrations/c:sdy_capi_headers", ], ) diff --git a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc index 2e10062945b5..bcc1ae56852c 100644 --- a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc +++ b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc @@ -14,6 +14,8 @@ #include "mlir-c/Dialect/Vector.h" #include "mlir-c/Transforms.h" #include "mlir/Bindings/Python/PybindAdaptors.h" +#include "shardy/integrations/c/passes.h" + namespace py = pybind11; @@ -37,6 +39,7 @@ PYBIND11_MODULE(register_jax_dialects, m, py::mod_gil_not_used()) { REGISTER_DIALECT(nvvm); REGISTER_DIALECT(llvm); mlirRegisterTransformsPasses(); + mlirRegisterAllSdyPassesAndPipelines(); // Transforms used by JAX. mlirRegisterTransformsStripDebugInfo(); });