Skip to content

Commit

Permalink
[Mosaic] Fix handling of i1 splat constants
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 694248723
  • Loading branch information
tlongeri authored and Google-ML-Automation committed Nov 7, 2024
1 parent 3b2e4a1 commit 04a6652
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3190,8 +3190,8 @@ LogicalResult arith_constant_rule(RewriteContext &ctx, Operation &op,
}
const VectorLayout &layout_out = *layouts_out.front();
DenseElementsAttr value = cast<DenseElementsAttr>(constant_op.getValue());
const VectorType target_vty =
getNativeVregType(vty.getElementType(), ctx.target_shape);
const VectorType target_vty = getNativeVregOrVmaskType(
vty.getElementType(), layout_out.bitwidth(), ctx.target_shape);
if (value.isSplat()) {
if (layout_out.offsets() != LayoutOffsets{std::nullopt, std::nullopt}) {
return op.emitOpError(
Expand Down
5 changes: 5 additions & 0 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,11 @@ class VectorLayoutInferer {
TPU_CHECK_OP(ty.getRank() > 0, "rank 0 vectors unsupported");
TPU_CHECK_OP(elems, "expected vector constants to use DenseElementsAttr");
auto bitwidth = ty.getElementTypeBitWidth();
if (bitwidth == 1) {
// i1 is a special case where the layout bitwidth can be different from
// the element bitwidth, see comment in VectorLayout class
bitwidth = kNativeBitwidth;
}
if (elems.isSplat()) {
if (ty.getRank() == 1) {
// Here, we choose to lay out along lanes arbitrarily. It would be
Expand Down

0 comments on commit 04a6652

Please sign in to comment.