diff --git a/jaxlib/mlir/_mlir_libs/tpu_ext.cc b/jaxlib/mlir/_mlir_libs/tpu_ext.cc index 53faa45112d9..1f62e7954d06 100644 --- a/jaxlib/mlir/_mlir_libs/tpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/tpu_ext.cc @@ -26,6 +26,12 @@ limitations under the License. #include #include +// clang-format: off +// pybind11 must be included before mlir/Bindings/Python/PybindAdaptors.h, +// otherwise this code will not build on Windows. +#include "pybind11/pybind11.h" +// clang-format: on + #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVectorExtras.h" @@ -43,7 +49,6 @@ limitations under the License. #include "pybind11/cast.h" #include "pybind11/detail/common.h" #include "pybind11/numpy.h" -#include "pybind11/pybind11.h" #include "pybind11/pytypes.h" #include "absl/log/check.h" #include "jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h" @@ -577,11 +582,11 @@ PYBIND11_MODULE(_tpu_ext, m) { for (int64_t i = 0; i < np_arr.ndim(); ++i) { shape.data()[i] = np_arr.shape()[i]; } - return mlirTpuAssemble(getDefaultInsertionPoint(), ty, layout, - MlirTpuValueArray{ - MlirTpuI64ArrayRef{shape.data(), shape.size()}, - vals.data()}, - TARGET_SHAPE); + return mlirTpuAssemble( + getDefaultInsertionPoint(), ty, layout, + MlirTpuValueArray{MlirTpuI64ArrayRef{shape.data(), shape.size()}, + vals.data()}, + TARGET_SHAPE); }); m.def("disassemble", [](MlirTpuVectorLayout layout, MlirValue val) { NotImplementedDetector detector(getDefaultContext());