Skip to content

Commit

Permalink
[Mosaic] Strengthen overly lax checks in apply vector layout
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 576512236
  • Loading branch information
apaszke authored and jax authors committed Oct 25, 2023
1 parent 7325b75 commit aa80689
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
16 changes: 15 additions & 1 deletion jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,7 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op,
const VectorLayout &layout_out) {
ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation());
auto result_ty = cast<VectorType>(op.getResult().getType());
auto source_ty = cast<VectorType>(op.getIn().getType());
if (layout_out.bitwidth() != 32) {
return op.emitOpError(
"Not implemented: Only extensions to 32-bit supported");
Expand All @@ -596,6 +597,9 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op,
if (layout_in.implicit_dim() != layout_out.implicit_dim()) {
return op.emitOpError("Not implemented: Change of layout during the cast");
}
if (layout_in.offsets() != layout_out.offsets()) {
return op.emitOpError("Not implemented: Change of offsets during the cast");
}
switch (layout_in.implicit_dim()) {
case VectorLayout::ImplicitDim::kNone: {
if (layout_in.tiling() != layout_out.tiling()) {
Expand All @@ -620,8 +624,18 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op,
return op.emitOpError(
"Not implemented: Only casts of lane-oriented values supported");
case VectorLayout::ImplicitDim::kSecondMinor: {
auto is_one_tile = [](VectorType vty, VectorLayout layout) {
auto implicit_shape = layout.implicitShape(vty.getShape());
auto tiled_shape = ArrayRef<int64_t>(implicit_shape).take_back(2);
return (layout.offsets()[0].value_or(0) + tiled_shape[0] <=
layout.tiling()[0]) &&
(layout.offsets()[1].value_or(0) + tiled_shape[1] <=
layout.tiling()[1]);
};
if (input_vregs.dimensions() != absl::Span<const int64_t>{1} ||
output_vregs.dimensions() != absl::Span<const int64_t>{1}) {
output_vregs.dimensions() != absl::Span<const int64_t>{1} ||
!is_one_tile(source_ty, layout_in) ||
!is_one_tile(result_ty, layout_out)) {
return op.emitOpError("Not implemented");
}
if (layout_in.offsets()[0] >= ctx.target_shape[0]) {
Expand Down
16 changes: 15 additions & 1 deletion jaxlib/mosaic/python/apply_vector_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -1787,17 +1787,31 @@ def _ext_op_rule( # pylint: disable=missing-function-docstring
result_ty = ir.VectorType(op.result.type)
if layout_out.bitwidth != 32:
raise NotImplementedError("Only extensions to 32-bit supported")
source_ty = ir.VectorType(op.in_.type)
input_vregs = disassemble(layout_in, op.in_)
output_vregs = np.empty(
layout_out.tile_array_shape(result_ty.shape), dtype=object
)
res_vreg_ty = native_vreg_ty(result_ty.element_type)
if layout_in.implicit_dim != layout_out.implicit_dim:
raise NotImplementedError("Change of layout during the cast")
if layout_in.offsets != layout_out.offsets:
raise NotImplementedError("Change of offsets during the cast")
if layout_in.implicit_dim is not None:
if layout_in.implicit_dim != ImplicitDim.SECOND_MINOR:
raise NotImplementedError("Only casts of lane-oriented values supported")
if input_vregs.shape != (1,) or output_vregs.shape != (1,):
def is_one_tile(vty, layout):
ishape = layout.implicit_shape(vty.shape)
return all(
o + s <= t
for o, s, t in zip(layout.offsets, ishape[-2:], layout.tiling)
)
if (
input_vregs.size != 1
or output_vregs.size != 1
or not is_one_tile(source_ty, layout_in)
or not is_one_tile(result_ty, layout_out)
):
raise NotImplementedError
if layout_in.offsets[0] >= TARGET_SHAPE.sublanes:
raise NotImplementedError
Expand Down

0 comments on commit aa80689

Please sign in to comment.