[Mosaic] Remove hardcoded TARGET_SHAPE and align Python/C++ APIs.
PiperOrigin-RevId: 716383597
WindQAQ authored and Google-ML-Automation committed Jan 21, 2025
1 parent 70a5175 commit b43b1b6
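
In short: the bindings no longer bake in the hardcoded 8x128 TARGET_SHAPE (a DEFAULT_TARGET_SHAPE constant remains only as a default argument value). Layout queries that used to be read-only properties (complete, tiles_per_vreg, sublanes_per_tile, vreg_slice, ...) become methods that take the target shape explicitly, generalizes and equivalent_to take shape and target_shape as keyword-only arguments, and a new nanobind type_caster carries the shape across the language boundary as a layout_defs.TargetTuple. A minimal sketch of the resulting Python call sites, assuming layout and other are instances of the layout class bound by this extension (the call sites themselves are hypothetical; only the signatures follow this diff):

    from jax.jaxlib.mosaic.python.layout_defs import TargetTuple

    target_shape = TargetTuple(8, 128)  # the value the old hardcoded TARGET_SHAPE supplied

    # Former read-only properties are now methods taking the target shape:
    n_tiles = layout.tiles_per_vreg(target_shape)
    vs = layout.vreg_slice(target_shape)

    # shape and target_shape are keyword-only here; shape still defaults to None:
    ok = layout.generalizes(other, target_shape=target_shape)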
Showing 1 changed file with 98 additions and 54 deletions.
jaxlib/mlir/_mlir_libs/tpu_ext.cc
@@ -64,8 +64,7 @@ namespace {
constexpr const char LAYOUT_DEFS_MODULE[] =
"jax.jaxlib.mosaic.python.layout_defs";
constexpr const char IR_MODULE[] = "jaxlib.mlir.ir";
// TODO(tlongeri): Get rid of this somehow
constexpr MlirTpuI64TargetTuple TARGET_SHAPE{8, 128};
constexpr MlirTpuI64TargetTuple DEFAULT_TARGET_SHAPE{8, 128};

// TODO(tlongeri): Add type annotations via nanobind once there is
// a release for it (and maybe add a custom Sequence<T> one as well).
@@ -163,6 +162,29 @@ struct nb::detail::type_caster<MlirTpuDirection> {
}
};

template <>
struct nb::detail::type_caster<MlirTpuI64TargetTuple> {
NB_TYPE_CASTER(MlirTpuI64TargetTuple, const_name("TargetTuple"));

bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept {
auto target_tuple_cls =
nb::module_::import_(LAYOUT_DEFS_MODULE).attr("TargetTuple");
if (!nb::isinstance(src, target_tuple_cls)) {
return false;
}
value = {nb::cast<int64_t>(src.attr("sublanes")),
nb::cast<int64_t>(src.attr("lanes"))};
return true;
}

static handle from_cpp(MlirTpuI64TargetTuple target_tuple, rv_policy policy,
cleanup_list* cleanup) noexcept {
nb::object target_tuple_cls =
nb::module_::import_(LAYOUT_DEFS_MODULE).attr("TargetTuple");
return target_tuple_cls(target_tuple.sublane, target_tuple.lane).release();
}
};

namespace {
// Handler for use with MLIR C API print functions. The 2nd parameter is an
// opaque pointer to "user data" that should always be a string.
@@ -335,35 +357,38 @@ NB_MODULE(_tpu_ext, m) {
.max_sublanes_in_scratch = max_sublanes_in_scratch};
},
nb::arg("hardware_generation") = -1,
nb::arg("target_shape") = toPyTuple(TARGET_SHAPE),
nb::arg("target_shape") = toPyTuple(DEFAULT_TARGET_SHAPE),
nb::arg("mxu_shape") = nb::make_tuple(128, 128),
nb::arg("max_sublanes_in_scratch") = 0);

nb::class_<MlirTpuVregDataBounds>(m, "VRegDataBounds")
.def("mask_varies_along",
[](MlirTpuVregDataBounds self, MlirTpuDirection direction) {
[](MlirTpuVregDataBounds self, MlirTpuDirection direction,
MlirTpuI64TargetTuple target_shape) {
return mlirTpuVregDataBoundsMaskVariesAlong(self, direction,
TARGET_SHAPE);
target_shape);
})
.def("complete",
[](MlirTpuVregDataBounds self, MlirTpuI64TargetTuple target_shape) {
return mlirTpuVregDataBoundsIsComplete(self, target_shape);
})
.def_prop_ro("complete",
[](MlirTpuVregDataBounds self) {
return mlirTpuVregDataBoundsIsComplete(self, TARGET_SHAPE);
})
.def("get_vector_mask",
[](MlirTpuVregDataBounds self, int generation) {
[](MlirTpuVregDataBounds self, int generation,
MlirTpuI64TargetTuple target_shape) {
// TODO: Does this work? Test in Python
MlirValue mask = mlirTpuVregDataBoundsGetVectorMask(
self, getDefaultInsertionPoint(), getDefaultLocation(),
generation, TARGET_SHAPE);
generation, target_shape);
if (mask.ptr == nullptr) {
throw std::runtime_error("getVectorMask failed");
}
return mask;
})
.def("get_sublane_mask", [](MlirTpuVregDataBounds self) {
return mlirTpuVregDataBoundsGetSublaneMask(self, getDefaultContext(),
TARGET_SHAPE);
});
.def("get_sublane_mask",
[](MlirTpuVregDataBounds self, MlirTpuI64TargetTuple target_shape) {
return mlirTpuVregDataBoundsGetSublaneMask(
self, getDefaultContext(), target_shape);
});

// TODO(tlongeri): More precise argument type annotations. There currently
// seems to be no way to define your own?
Expand All @@ -384,9 +409,6 @@ NB_MODULE(_tpu_ext, m) {
offsetFromPyOffset(offsets[1])},
{nb::cast<int64_t>(tiling[0]), nb::cast<int64_t>(tiling[1])},
implicit_dim);
if (!mlirTpuVectorLayoutIsValid(layout, TARGET_SHAPE)) {
throw nb::value_error("Layout not valid for target shape");
}
new (self) PyTpuVectorLayout(layout);
},
nb::arg("bitwidth"), nb::arg("offsets"), nb::arg("tiling"),
@@ -435,49 +457,59 @@ NB_MODULE(_tpu_ext, m) {
return mlirTpuVectorLayoutGetLayoutRank(self.layout);
},
"The number of minormost dimensions tiled by this layout.")
.def_prop_ro(
.def(
"has_natural_topology",
[](const PyTpuVectorLayout& self) {
[](const PyTpuVectorLayout& self,
MlirTpuI64TargetTuple target_shape) {
return mlirTpuVectorLayoutHasNaturalTopology(self.layout,
TARGET_SHAPE);
target_shape);
},
nb::arg("target_shape"),
"True, if every vector register has a layout without jumps.\n"
"\n"
"By without jumps we mean that traversing vregs over (sub)lanes "
"always leads to a contiguous traversal of the (second) minormost "
"dimension of data. This is only true for 32-bit types, since "
"narrower types use two level tiling.")
.def_prop_ro(
.def(
"has_native_tiling",
[](const PyTpuVectorLayout& self) {
[](const PyTpuVectorLayout& self,
MlirTpuI64TargetTuple target_shape) {
return mlirTpuVectorLayoutHasNativeTiling(self.layout,
TARGET_SHAPE);
target_shape);
},
nb::arg("target_shape"),
"True, if every vector register has a natural \"packed\" topology.\n"
"\n"
"This is equivalent to has_natural_topology for 32-bit types, but "
"generalizes it to narrower values with packed layouts too.")
.def_prop_ro(
.def(
"tiles_per_vreg",
[](const PyTpuVectorLayout& self) {
return mlirTpuVectorLayoutTilesPerVreg(self.layout, TARGET_SHAPE);
[](const PyTpuVectorLayout& self,
MlirTpuI64TargetTuple target_shape) {
return mlirTpuVectorLayoutTilesPerVreg(self.layout, target_shape);
},
nb::arg("target_shape"),
"How many tiles fit in each vector register.")
.def_prop_ro(
.def(
"sublanes_per_tile",
[](const PyTpuVectorLayout& self) {
[](const PyTpuVectorLayout& self,
MlirTpuI64TargetTuple target_shape) {
return mlirTpuVectorLayoutSublanesPerTile(self.layout,
TARGET_SHAPE);
target_shape);
},
nb::arg("target_shape"),
"The number of sublanes necessary to store each tile.")
.def_prop_ro(
.def(
"vreg_slice",
[](const PyTpuVectorLayout& self) {
[](const PyTpuVectorLayout& self,
MlirTpuI64TargetTuple target_shape) {
MlirTpuI64TargetTuple vreg_slice =
mlirTpuVectorLayoutVregSlice(self.layout, TARGET_SHAPE);
mlirTpuVectorLayoutVregSlice(self.layout, target_shape);
return nb::module_::import_(LAYOUT_DEFS_MODULE)
.attr("TargetTuple")(vreg_slice.sublane, vreg_slice.lane);
},
nb::arg("target_shape"),
"Returns the size of a window contained in a single vreg.\n"
"\n"
"We never reuse the same vector register to store data of multiple "
@@ -498,20 +530,21 @@ NB_MODULE(_tpu_ext, m) {
nb::arg("shape"))
.def(
"tile_array_shape",
[](const PyTpuVectorLayout& self, nb::sequence shape) {
[](const PyTpuVectorLayout& self, nb::sequence shape,
MlirTpuI64TargetTuple target_shape) {
llvm::SmallVector<int64_t> tile_array_shape_vec =
sequenceToSmallVector<int64_t>(shape);
MlirTpuI64ArrayRef tile_array_shape =
mlirTpuVectorLayoutTileArrayShape(
self.layout,
{tile_array_shape_vec.data(), tile_array_shape_vec.size()},
TARGET_SHAPE);
target_shape);
nb::tuple ret =
toPyTuple(tile_array_shape.ptr, tile_array_shape.size);
free(tile_array_shape.ptr);
return ret;
},
nb::arg("shape"),
nb::arg("shape"), nb::arg("target_shape"),
"Returns the shape of an ndarray of vregs needed to represent a "
"value.\n"
"\n"
@@ -527,7 +560,8 @@ NB_MODULE(_tpu_ext, m) {
.def(
"tile_data_bounds",
[](const PyTpuVectorLayout& self, nb::sequence shape,
nb::sequence ixs, std::variant<bool, nb::tuple> allow_replicated) {
nb::sequence ixs, MlirTpuI64TargetTuple target_shape,
std::variant<bool, nb::tuple> allow_replicated) {
llvm::SmallVector<int64_t> shape_vec =
sequenceToSmallVector<int64_t>(shape);
llvm::SmallVector<int64_t> ixs_vec =
@@ -541,18 +575,19 @@ NB_MODULE(_tpu_ext, m) {
if constexpr (std::is_same_v<decltype(ar), bool>) {
return mlirTpuVectorLayoutTileDataBounds(
self.layout, getDefaultContext(), shape_vec.data(),
ixs_vec.data(), shape_vec.size(), TARGET_SHAPE,
ixs_vec.data(), shape_vec.size(), target_shape,
{ar, ar});
} else {
return mlirTpuVectorLayoutTileDataBounds(
self.layout, getDefaultContext(), shape_vec.data(),
ixs_vec.data(), shape_vec.size(), TARGET_SHAPE,
ixs_vec.data(), shape_vec.size(), target_shape,
{nb::cast<bool>(ar[0]), nb::cast<bool>(ar[1])});
}
},
allow_replicated);
},
nb::arg("shape"), nb::arg("ixs"), nb::arg("allow_replicated") = false,
nb::arg("shape"), nb::arg("ixs"), nb::arg("target_shape"),
nb::arg("allow_replicated") = false,
"Returns the bounds of the given tile that hold useful data.\n"
"\n"
"Arguments:\n"
@@ -564,25 +599,28 @@ NB_MODULE(_tpu_ext, m) {
"REPLICATED. If True, offsets are allowed to be REPLICATED, but the "
"bounds will span the full dimension of the tile (i.e. potentially "
"multiple repeats of the actual data).\n"
" target_shape: The target shape of the TPU.\n"
"\n"
"Returns:\n"
" A TargetTuple of slices, indicating the span of useful data "
"within the tile selected by idx.")
.def(
"generalizes",
[](const PyTpuVectorLayout& self, const PyTpuVectorLayout& other,
std::optional<nb::sequence> shape) {
std::optional<nb::sequence> shape,
MlirTpuI64TargetTuple target_shape) {
if (shape) {
llvm::SmallVector<int64_t> shape_vec =
sequenceToSmallVector<int64_t>(*shape);
return mlirTpuVectorLayoutGeneralizes(
self.layout, other.layout,
{shape_vec.data(), shape_vec.size()}, TARGET_SHAPE);
{shape_vec.data(), shape_vec.size()}, target_shape);
}
return mlirTpuVectorLayoutGeneralizes(self.layout, other.layout,
{nullptr, 0}, TARGET_SHAPE);
{nullptr, 0}, target_shape);
},
nb::arg("other"), nb::arg("shape").none() = std::nullopt,
nb::arg("other"), nb::kw_only(),
nb::arg("shape").none() = std::nullopt, nb::arg("target_shape"),
"Returns True if the other layout is a special case of this one.\n"
"\n"
"In here, other is considered \"a special case\" when the set of "
@@ -606,22 +644,25 @@ NB_MODULE(_tpu_ext, m) {
" The generalization relation is larger than usual for some "
"shapes. That is, if self.generalizes(other) then also "
"self.generalizes(other, shape) for any shape, but that implication "
"does not hold the other way around for some shapes.")
"does not hold the other way around for some shapes.\n"
" target_shape: The target shape of the TPU.")
.def(
"equivalent_to",
[](const PyTpuVectorLayout& self, const PyTpuVectorLayout& other,
std::optional<nb::sequence> shape) {
std::optional<nb::sequence> shape,
MlirTpuI64TargetTuple target_shape) {
if (shape) {
llvm::SmallVector<int64_t> shape_vec =
sequenceToSmallVector<int64_t>(*shape);
return mlirTpuVectorLayoutEquivalentTo(
self.layout, other.layout,
{shape_vec.data(), shape_vec.size()}, TARGET_SHAPE);
{shape_vec.data(), shape_vec.size()}, target_shape);
}
return mlirTpuVectorLayoutEquivalentTo(self.layout, other.layout,
{nullptr, 0}, TARGET_SHAPE);
{nullptr, 0}, target_shape);
},
nb::arg("other"), nb::arg("shape").none() = std::nullopt,
nb::arg("other"), nb::kw_only(),
nb::arg("shape").none() = std::nullopt, nb::arg("target_shape"),
"Returns True if the two layouts are equivalent.\n"
"\n"
"That is, when all potential vector entries where the value can be "
@@ -632,7 +673,8 @@ NB_MODULE(_tpu_ext, m) {
" other: The layout compared against self.\n"
" shape: An optional shape of the vector to which both layouts "
"apply. More layouts are considered equivalent when the shape is "
"specified. Also see the docstring of the generalizes method.")
"specified. Also see the docstring of the generalizes method.\n"
" target_shape: The target shape of the TPU.")
.def("__eq__",
[](const PyTpuVectorLayout& self, const PyTpuVectorLayout& other) {
return mlirTpuVectorLayoutEquals(self.layout, other.layout);
@@ -646,7 +688,8 @@ NB_MODULE(_tpu_ext, m) {
// TODO(tlongeri): Can we make the first parameter a VectorType?
m.def("assemble",
[](const MlirType ty, const PyTpuVectorLayout& layout,
nb::object np_arr_obj) -> MlirOperation {
nb::object np_arr_obj,
MlirTpuI64TargetTuple target_shape) -> MlirOperation {
// TODO(tlongeri): Remove nb::array::c_style, I only added it because
// I couldn't find a simple way to iterate over array data, but it
// causes yet another unnecessary copy.
@@ -668,12 +711,13 @@ NB_MODULE(_tpu_ext, m) {
getDefaultInsertionPoint(), ty, layout.layout,
MlirTpuValueArray{MlirTpuI64ArrayRef{shape.data(), shape.size()},
vals.data()},
TARGET_SHAPE);
target_shape);
});
m.def("disassemble", [](const PyTpuVectorLayout& layout, MlirValue val) {
m.def("disassemble", [](const PyTpuVectorLayout& layout, MlirValue val,
MlirTpuI64TargetTuple target_shape) {
DiagnosticCapture diag_capture(getDefaultContext());
MlirTpuValueArray val_arr = mlirTpuDisassemble(
getDefaultInsertionPoint(), layout.layout, val, TARGET_SHAPE);
getDefaultInsertionPoint(), layout.layout, val, target_shape);
if (val_arr.vals == nullptr) {
diag_capture.throwIfError();
throw nb::value_error("Failed to disassemble");
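A usage note on the new TargetTuple type_caster added above: from_python reads the sublanes and lanes attributes of the incoming object, and from_cpp constructs a TargetTuple on the way back out, so shapes round-trip as proper TargetTuple values rather than plain tuples. A short sketch under the same assumptions as the earlier snippet (layout is hypothetical; the attribute names follow the caster):

    ts = TargetTuple(8, 128)
    assert (ts.sublanes, ts.lanes) == (8, 128)  # the attributes from_python reads

    vs = layout.vreg_slice(ts)    # accepts a TargetTuple on the way in...
    print(vs.sublanes, vs.lanes)  # ...and returns one on the way out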
