diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 2b179542da5e..d2fb8c85b57c 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -582,6 +582,7 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op, const VectorLayout &layout_out) { ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation()); auto result_ty = cast(op.getResult().getType()); + auto source_ty = cast(op.getIn().getType()); if (layout_out.bitwidth() != 32) { return op.emitOpError( "Not implemented: Only extensions to 32-bit supported"); @@ -596,6 +597,9 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op, if (layout_in.implicit_dim() != layout_out.implicit_dim()) { return op.emitOpError("Not implemented: Change of layout during the cast"); } + if (layout_in.offsets() != layout_out.offsets()) { + return op.emitOpError("Not implemented: Change of offsets during the cast"); + } switch (layout_in.implicit_dim()) { case VectorLayout::ImplicitDim::kNone: { if (layout_in.tiling() != layout_out.tiling()) { @@ -620,8 +624,18 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op, return op.emitOpError( "Not implemented: Only casts of lane-oriented values supported"); case VectorLayout::ImplicitDim::kSecondMinor: { + auto is_one_tile = [](VectorType vty, VectorLayout layout) { + auto implicit_shape = layout.implicitShape(vty.getShape()); + auto tiled_shape = ArrayRef(implicit_shape).take_back(2); + return (layout.offsets()[0].value_or(0) + tiled_shape[0] <= + layout.tiling()[0]) && + (layout.offsets()[1].value_or(0) + tiled_shape[1] <= + layout.tiling()[1]); + }; if (input_vregs.dimensions() != absl::Span{1} || - output_vregs.dimensions() != absl::Span{1}) { + output_vregs.dimensions() != absl::Span{1} || + !is_one_tile(source_ty, layout_in) || + !is_one_tile(result_ty, layout_out)) { return op.emitOpError("Not implemented"); } if (layout_in.offsets()[0] >= ctx.target_shape[0]) { diff --git a/jaxlib/mosaic/python/apply_vector_layout.py b/jaxlib/mosaic/python/apply_vector_layout.py index b7a60f087db7..d1cb6426bcb2 100644 --- a/jaxlib/mosaic/python/apply_vector_layout.py +++ b/jaxlib/mosaic/python/apply_vector_layout.py @@ -1787,6 +1787,7 @@ def _ext_op_rule( # pylint: disable=missing-function-docstring result_ty = ir.VectorType(op.result.type) if layout_out.bitwidth != 32: raise NotImplementedError("Only extensions to 32-bit supported") + source_ty = ir.VectorType(op.in_.type) input_vregs = disassemble(layout_in, op.in_) output_vregs = np.empty( layout_out.tile_array_shape(result_ty.shape), dtype=object @@ -1794,10 +1795,23 @@ def _ext_op_rule( # pylint: disable=missing-function-docstring res_vreg_ty = native_vreg_ty(result_ty.element_type) if layout_in.implicit_dim != layout_out.implicit_dim: raise NotImplementedError("Change of layout during the cast") + if layout_in.offsets != layout_out.offsets: + raise NotImplementedError("Change of offsets during the cast") if layout_in.implicit_dim is not None: if layout_in.implicit_dim != ImplicitDim.SECOND_MINOR: raise NotImplementedError("Only casts of lane-oriented values supported") - if input_vregs.shape != (1,) or output_vregs.shape != (1,): + def is_one_tile(vty, layout): + ishape = layout.implicit_shape(vty.shape) + return all( + o + s <= t + for o, s, t in zip(layout.offsets, ishape[-2:], layout.tiling) + ) + if ( + input_vregs.size != 1 + or output_vregs.size != 1 + or not is_one_tile(source_ty, layout_in) + or not is_one_tile(result_ty, layout_out) + ): raise NotImplementedError if layout_in.offsets[0] >= TARGET_SHAPE.sublanes: raise NotImplementedError