Skip to content

Commit

Permalink
#sdy Run inlined mesh lifter pass at the end of JAX lowering.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 685728692
  • Loading branch information
bartchr808 authored and Google-ML-Automation committed Oct 14, 2024
1 parent d15d70d commit 75e22f2
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
1 change: 1 addition & 0 deletions jaxlib/mlir/_mlir_libs/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ py_extension(
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
"@local_config_python//:headers",
"@pybind11",
"@shardy//shardy/integrations/c:sdy_capi_headers",
],
)

Expand Down
3 changes: 3 additions & 0 deletions jaxlib/mlir/_mlir_libs/register_jax_dialects.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include "mlir-c/Dialect/Vector.h"
#include "mlir-c/Transforms.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "shardy/integrations/c/passes.h"


namespace py = pybind11;

Expand All @@ -37,6 +39,7 @@ PYBIND11_MODULE(register_jax_dialects, m, py::mod_gil_not_used()) {
REGISTER_DIALECT(nvvm);
REGISTER_DIALECT(llvm);
mlirRegisterTransformsPasses();
mlirRegisterAllSdyPassesAndPipelines();
// Transforms used by JAX.
mlirRegisterTransformsStripDebugInfo();
});
Expand Down

0 comments on commit 75e22f2

Please sign in to comment.