From cc8a7f8a0f3199fec0b720b8c84a88c9f8631008 Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung
Date: Thu, 16 Jan 2025 14:38:51 -0800
Subject: [PATCH] [Mosaic] Remove hardcoded TARGET_SHAPE and align Python/C++
 APIs.

PiperOrigin-RevId: 716383597
---
 jaxlib/mlir/_mlir_libs/tpu_ext.cc | 152 +++++++++++++++++++-----------
 1 file changed, 98 insertions(+), 54 deletions(-)

diff --git a/jaxlib/mlir/_mlir_libs/tpu_ext.cc b/jaxlib/mlir/_mlir_libs/tpu_ext.cc
index c3883a97aad4..2b5ec898ad3e 100644
--- a/jaxlib/mlir/_mlir_libs/tpu_ext.cc
+++ b/jaxlib/mlir/_mlir_libs/tpu_ext.cc
@@ -64,8 +64,7 @@ namespace {
 constexpr const char LAYOUT_DEFS_MODULE[] =
     "jax.jaxlib.mosaic.python.layout_defs";
 constexpr const char IR_MODULE[] = "jaxlib.mlir.ir";
-// TODO(tlongeri): Get rid of this somehow
-constexpr MlirTpuI64TargetTuple TARGET_SHAPE{8, 128};
+constexpr MlirTpuI64TargetTuple DEFAULT_TARGET_SHAPE{8, 128};
 
 // TODO(tlongeri): Add type annotations via nanobind once there is
 // a release for it (and maybe add a custom Sequence one as well).
@@ -163,6 +162,29 @@ struct nb::detail::type_caster<MlirTpuImplicitDim> {
   }
 };
 
+template <>
+struct nb::detail::type_caster<MlirTpuI64TargetTuple> {
+  NB_TYPE_CASTER(MlirTpuI64TargetTuple, const_name("TargetTuple"));
+
+  bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept {
+    auto target_tuple_cls =
+        nb::module_::import_(LAYOUT_DEFS_MODULE).attr("TargetTuple");
+    if (!nb::isinstance(src, target_tuple_cls)) {
+      return false;
+    }
+    value = {nb::cast<int64_t>(src.attr("sublanes")),
+             nb::cast<int64_t>(src.attr("lanes"))};
+    return true;
+  }
+
+  static handle from_cpp(MlirTpuI64TargetTuple target_tuple, rv_policy policy,
+                         cleanup_list* cleanup) noexcept {
+    nb::object target_tuple_cls =
+        nb::module_::import_(LAYOUT_DEFS_MODULE).attr("TargetTuple");
+    return target_tuple_cls(target_tuple.sublane, target_tuple.lane).release();
+  }
+};
+
 namespace {
 // Handler for use with MLIR C API print functions. The 2nd parameter is an
 // opaque pointer to "user data" that should always be a string.
@@ -335,35 +357,38 @@ NB_MODULE(_tpu_ext, m) {
             .max_sublanes_in_scratch = max_sublanes_in_scratch};
       },
       nb::arg("hardware_generation") = -1,
-      nb::arg("target_shape") = toPyTuple(TARGET_SHAPE),
+      nb::arg("target_shape") = toPyTuple(DEFAULT_TARGET_SHAPE),
       nb::arg("mxu_shape") = nb::make_tuple(128, 128),
       nb::arg("max_sublanes_in_scratch") = 0);
 
   nb::class_<MlirTpuVregDataBounds>(m, "VRegDataBounds")
       .def("mask_varies_along",
-           [](MlirTpuVregDataBounds self, MlirTpuDirection direction) {
+           [](MlirTpuVregDataBounds self, MlirTpuDirection direction,
+              MlirTpuI64TargetTuple target_shape) {
             return mlirTpuVregDataBoundsMaskVariesAlong(self, direction,
-                                                        TARGET_SHAPE);
+                                                        target_shape);
+           })
+      .def("complete",
+           [](MlirTpuVregDataBounds self, MlirTpuI64TargetTuple target_shape) {
+             return mlirTpuVregDataBoundsIsComplete(self, target_shape);
            })
-      .def_prop_ro("complete",
-                   [](MlirTpuVregDataBounds self) {
-                     return mlirTpuVregDataBoundsIsComplete(self, TARGET_SHAPE);
-                   })
       .def("get_vector_mask",
-           [](MlirTpuVregDataBounds self, int generation) {
+           [](MlirTpuVregDataBounds self, int generation,
+              MlirTpuI64TargetTuple target_shape) {
             // TODO: Does this work? Test in Python
             MlirValue mask = mlirTpuVregDataBoundsGetVectorMask(
                 self, getDefaultInsertionPoint(), getDefaultLocation(),
-                generation, TARGET_SHAPE);
+                generation, target_shape);
             if (mask.ptr == nullptr) {
               throw std::runtime_error("getVectorMask failed");
             }
             return mask;
           })
-      .def("get_sublane_mask", [](MlirTpuVregDataBounds self) {
-        return mlirTpuVregDataBoundsGetSublaneMask(self, getDefaultContext(),
-                                                   TARGET_SHAPE);
-      });
+      .def("get_sublane_mask",
+           [](MlirTpuVregDataBounds self, MlirTpuI64TargetTuple target_shape) {
+             return mlirTpuVregDataBoundsGetSublaneMask(
+                 self, getDefaultContext(), target_shape);
+           });
 
   // TODO(tlongeri): More precise argument type annotations. There currently
   // seems to be no way to define your own?
@@ -384,9 +409,6 @@ NB_MODULE(_tpu_ext, m) {
                  offsetFromPyOffset(offsets[1])},
                 {nb::cast<int64_t>(tiling[0]), nb::cast<int64_t>(tiling[1])},
                 implicit_dim);
-            if (!mlirTpuVectorLayoutIsValid(layout, TARGET_SHAPE)) {
-              throw nb::value_error("Layout not valid for target shape");
-            }
             new (self) PyTpuVectorLayout(layout);
           },
           nb::arg("bitwidth"), nb::arg("offsets"), nb::arg("tiling"),
@@ -435,49 +457,59 @@ NB_MODULE(_tpu_ext, m) {
             return mlirTpuVectorLayoutGetLayoutRank(self.layout);
           },
           "The number of minormost dimensions tiled by this layout.")
-      .def_prop_ro(
+      .def(
           "has_natural_topology",
-          [](const PyTpuVectorLayout& self) {
+          [](const PyTpuVectorLayout& self,
+             MlirTpuI64TargetTuple target_shape) {
             return mlirTpuVectorLayoutHasNaturalTopology(self.layout,
-                                                         TARGET_SHAPE);
+                                                         target_shape);
           },
+          nb::arg("target_shape"),
           "True, if every vector register has a layout without jumps.\n"
           "\n"
          "By without jumps we mean that traversing vregs over (sub)lanes "
          "always leads to a contiguous traversal of the (second) minormost "
          "dimension of data. This is only true for 32-bit types, since "
          "narrower types use two level tiling.")
-      .def_prop_ro(
+      .def(
           "has_native_tiling",
-          [](const PyTpuVectorLayout& self) {
+          [](const PyTpuVectorLayout& self,
+             MlirTpuI64TargetTuple target_shape) {
             return mlirTpuVectorLayoutHasNativeTiling(self.layout,
-                                                      TARGET_SHAPE);
+                                                      target_shape);
           },
+          nb::arg("target_shape"),
          "True, if every vector register has a natural \"packed\" topology.\n"
          "\n"
          "This is equivalent to has_natural_topology for 32-bit types, but "
          "generalizes it to narrower values with packed layouts too.")
-      .def_prop_ro(
+      .def(
           "tiles_per_vreg",
-          [](const PyTpuVectorLayout& self) {
-            return mlirTpuVectorLayoutTilesPerVreg(self.layout, TARGET_SHAPE);
+          [](const PyTpuVectorLayout& self,
+             MlirTpuI64TargetTuple target_shape) {
+            return mlirTpuVectorLayoutTilesPerVreg(self.layout, target_shape);
           },
+          nb::arg("target_shape"),
          "How many tiles fit in each vector register.")
-      .def_prop_ro(
+      .def(
           "sublanes_per_tile",
-          [](const PyTpuVectorLayout& self) {
+          [](const PyTpuVectorLayout& self,
+             MlirTpuI64TargetTuple target_shape) {
             return mlirTpuVectorLayoutSublanesPerTile(self.layout,
-                                                      TARGET_SHAPE);
+                                                      target_shape);
           },
+          nb::arg("target_shape"),
          "The number of sublanes necessary to store each tile.")
-      .def_prop_ro(
+      .def(
           "vreg_slice",
-          [](const PyTpuVectorLayout& self) {
+          [](const PyTpuVectorLayout& self,
+             MlirTpuI64TargetTuple target_shape) {
             MlirTpuI64TargetTuple vreg_slice =
-                mlirTpuVectorLayoutVregSlice(self.layout, TARGET_SHAPE);
+                mlirTpuVectorLayoutVregSlice(self.layout, target_shape);
             return nb::module_::import_(LAYOUT_DEFS_MODULE)
                 .attr("TargetTuple")(vreg_slice.sublane, vreg_slice.lane);
           },
+          nb::arg("target_shape"),
          "Returns the size of a window contained in a single vreg.\n"
"\n" "We never reuse the same vector register to store data of multiple " @@ -498,20 +530,21 @@ NB_MODULE(_tpu_ext, m) { nb::arg("shape")) .def( "tile_array_shape", - [](const PyTpuVectorLayout& self, nb::sequence shape) { + [](const PyTpuVectorLayout& self, nb::sequence shape, + MlirTpuI64TargetTuple target_shape) { llvm::SmallVector tile_array_shape_vec = sequenceToSmallVector(shape); MlirTpuI64ArrayRef tile_array_shape = mlirTpuVectorLayoutTileArrayShape( self.layout, {tile_array_shape_vec.data(), tile_array_shape_vec.size()}, - TARGET_SHAPE); + target_shape); nb::tuple ret = toPyTuple(tile_array_shape.ptr, tile_array_shape.size); free(tile_array_shape.ptr); return ret; }, - nb::arg("shape"), + nb::arg("shape"), nb::arg("target_shape"), "Returns the shape of an ndarray of vregs needed to represent a " "value.\n" "\n" @@ -527,7 +560,8 @@ NB_MODULE(_tpu_ext, m) { .def( "tile_data_bounds", [](const PyTpuVectorLayout& self, nb::sequence shape, - nb::sequence ixs, std::variant allow_replicated) { + nb::sequence ixs, MlirTpuI64TargetTuple target_shape, + std::variant allow_replicated) { llvm::SmallVector shape_vec = sequenceToSmallVector(shape); llvm::SmallVector ixs_vec = @@ -541,18 +575,19 @@ NB_MODULE(_tpu_ext, m) { if constexpr (std::is_same_v) { return mlirTpuVectorLayoutTileDataBounds( self.layout, getDefaultContext(), shape_vec.data(), - ixs_vec.data(), shape_vec.size(), TARGET_SHAPE, + ixs_vec.data(), shape_vec.size(), target_shape, {ar, ar}); } else { return mlirTpuVectorLayoutTileDataBounds( self.layout, getDefaultContext(), shape_vec.data(), - ixs_vec.data(), shape_vec.size(), TARGET_SHAPE, + ixs_vec.data(), shape_vec.size(), target_shape, {nb::cast(ar[0]), nb::cast(ar[1])}); } }, allow_replicated); }, - nb::arg("shape"), nb::arg("ixs"), nb::arg("allow_replicated") = false, + nb::arg("shape"), nb::arg("ixs"), nb::arg("target_shape"), + nb::arg("allow_replicated") = false, "Returns the bounds of the given tile that hold useful data.\n" "\n" "Arguments:\n" @@ -564,6 +599,7 @@ NB_MODULE(_tpu_ext, m) { "REPLICATED. If True, offsets are allowed to be REPLICATED, but the " "bounds will span the full dimension of the tile (i.e. potentially " "multiple repeats of the actual data).\n" + " target_shape: The target shape of the TPU.\n" "\n" "Returns:\n" " A TargetTuple of slices, indicating the span of useful data " @@ -571,18 +607,20 @@ NB_MODULE(_tpu_ext, m) { .def( "generalizes", [](const PyTpuVectorLayout& self, const PyTpuVectorLayout& other, - std::optional shape) { + std::optional shape, + MlirTpuI64TargetTuple target_shape) { if (shape) { llvm::SmallVector shape_vec = sequenceToSmallVector(*shape); return mlirTpuVectorLayoutGeneralizes( self.layout, other.layout, - {shape_vec.data(), shape_vec.size()}, TARGET_SHAPE); + {shape_vec.data(), shape_vec.size()}, target_shape); } return mlirTpuVectorLayoutGeneralizes(self.layout, other.layout, - {nullptr, 0}, TARGET_SHAPE); + {nullptr, 0}, target_shape); }, - nb::arg("other"), nb::arg("shape").none() = std::nullopt, + nb::arg("other"), nb::kw_only(), + nb::arg("shape").none() = std::nullopt, nb::arg("target_shape"), "Returns True if the other layout is a special case of this one.\n" "\n" "In here, other is considered \"a special case\" when the set of " @@ -606,22 +644,25 @@ NB_MODULE(_tpu_ext, m) { " The generalization relation is larger than usual for some " "shapes. 
         "self.generalizes(other, shape) for any shape, but that implication "
-          "does not hold the other way around for some shapes.")
+          "does not hold the other way around for some shapes.\n"
+          " target_shape: The target shape of the TPU.")
       .def(
           "equivalent_to",
           [](const PyTpuVectorLayout& self, const PyTpuVectorLayout& other,
-             std::optional<nb::sequence> shape) {
+             std::optional<nb::sequence> shape,
+             MlirTpuI64TargetTuple target_shape) {
             if (shape) {
               llvm::SmallVector<int64_t> shape_vec =
                   sequenceToSmallVector<int64_t>(*shape);
               return mlirTpuVectorLayoutEquivalentTo(
                   self.layout, other.layout,
-                  {shape_vec.data(), shape_vec.size()}, TARGET_SHAPE);
+                  {shape_vec.data(), shape_vec.size()}, target_shape);
             }
             return mlirTpuVectorLayoutEquivalentTo(self.layout, other.layout,
-                                                  {nullptr, 0}, TARGET_SHAPE);
+                                                  {nullptr, 0}, target_shape);
           },
-          nb::arg("other"), nb::arg("shape").none() = std::nullopt,
+          nb::arg("other"), nb::kw_only(),
+          nb::arg("shape").none() = std::nullopt, nb::arg("target_shape"),
          "Returns True if the two layouts are equivalent.\n"
          "\n"
          "That is, when all potential vector entries where the value can be "
@@ -632,7 +673,8 @@ NB_MODULE(_tpu_ext, m) {
          " other: The layout compared against self.\n"
          " shape: An optional shape of the vector to which both layouts "
          "apply. More layouts are considered equivalent when the shape is "
-          "specified. Also see the docstring of the generalizes method.")
+          "specified. Also see the docstring of the generalizes method.\n"
+          " target_shape: The target shape of the TPU.")
       .def("__eq__",
            [](const PyTpuVectorLayout& self, const PyTpuVectorLayout& other) {
              return mlirTpuVectorLayoutEquals(self.layout, other.layout);
@@ -646,7 +688,8 @@ NB_MODULE(_tpu_ext, m) {
   // TODO(tlongeri): Can we make the first parameter a VectorType?
   m.def("assemble",
         [](const MlirType ty, const PyTpuVectorLayout& layout,
-           nb::object np_arr_obj) -> MlirOperation {
+           nb::object np_arr_obj,
+           MlirTpuI64TargetTuple target_shape) -> MlirOperation {
          // TODO(tlongeri): Remove nb::array::c_style, I only added it because
          // I couldn't find a simple way to iterate over array data, but it
          // causes yet another unnecessary copy.
@@ -668,12 +711,13 @@ NB_MODULE(_tpu_ext, m) {
              getDefaultInsertionPoint(), ty, layout.layout,
              MlirTpuValueArray{MlirTpuI64ArrayRef{shape.data(), shape.size()},
                                vals.data()},
-              TARGET_SHAPE);
+              target_shape);
        });
-  m.def("disassemble", [](const PyTpuVectorLayout& layout, MlirValue val) {
+  m.def("disassemble", [](const PyTpuVectorLayout& layout, MlirValue val,
+                          MlirTpuI64TargetTuple target_shape) {
    DiagnosticCapture diag_capture(getDefaultContext());
    MlirTpuValueArray val_arr = mlirTpuDisassemble(
-        getDefaultInsertionPoint(), layout.layout, val, TARGET_SHAPE);
+        getDefaultInsertionPoint(), layout.layout, val, target_shape);
    if (val_arr.vals == nullptr) {
      diag_capture.throwIfError();
      throw nb::value_error("Failed to disassemble");
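
For reviewers, a minimal sketch of how the rebound API would be driven from
Python after this change. The module import path, the VectorLayout class
name, and the implicit_dim=None construction are assumptions for
illustration (they do not appear in the visible hunks); only the explicit
target_shape plumbing is taken from the patch.

  # Hypothetical usage sketch; paths and VectorLayout are assumptions.
  from jax.jaxlib.mosaic.python.layout_defs import TargetTuple
  from jaxlib.mlir._mlir_libs import _tpu_ext as tpu

  # The new type_caster only accepts a layout_defs.TargetTuple (it reads
  # .sublanes/.lanes); a plain Python tuple would be rejected.
  target_shape = TargetTuple(8, 128)

  # Constructor args per nb::arg("bitwidth"/"offsets"/"tiling") above;
  # implicit_dim=None is an assumption.
  layout = tpu.VectorLayout(
      bitwidth=32, offsets=(0, 0), tiling=(8, 128), implicit_dim=None)
  other = tpu.VectorLayout(
      bitwidth=32, offsets=(0, 0), tiling=(8, 128), implicit_dim=None)

  # Former read-only properties are now methods taking target_shape:
  layout.tiles_per_vreg(target_shape)
  layout.sublanes_per_tile(target_shape)
  layout.vreg_slice(target_shape)  # returns a layout_defs.TargetTuple

  # tile_array_shape gains a positional target_shape argument:
  layout.tile_array_shape((512, 512), target_shape)

  # shape and target_shape are keyword-only here (nb::kw_only()):
  layout.generalizes(other, shape=None, target_shape=target_shape)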