Skip to content

Commit

Permalink
[Mosaic] Allow passing ApplyVectorLayoutCtx to tpu.apply_layout_op.
Browse files Browse the repository at this point in the history
To make it the same with C++ API. While I'm here, fix a bug in test_concatenate.

PiperOrigin-RevId: 716016244
  • Loading branch information
WindQAQ authored and Google-ML-Automation committed Jan 16, 2025
1 parent d3ba1eb commit 4a9cc9f
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 12 deletions.
5 changes: 2 additions & 3 deletions jaxlib/mlir/_mlir_libs/tpu_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -693,10 +693,9 @@ NB_MODULE(_tpu_ext, m) {
});

m.def("apply_layout_op",
[](int hardware_generation, const MlirOperation c_op) {
[](MlirTpuApplyVectorLayoutContext ctx, const MlirOperation c_op) {
DiagnosticCapture diag_capture(getDefaultContext());
MlirLogicalResult res =
mlirTpuApplyLayoutOp(hardware_generation, c_op, TARGET_SHAPE);
MlirLogicalResult res = mlirTpuApplyLayoutOp(ctx, c_op);
if (mlirLogicalResultIsFailure(res)) {
diag_capture.throwIfError();
throw std::runtime_error("applyLayoutOp failed");
Expand Down
11 changes: 4 additions & 7 deletions jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -383,13 +383,10 @@ MlirTpuValueArray mlirTpuDisassemble(MlirTpuInsertionPoint insertion_point,
return MlirTpuValueArrayFromXlaArray(std::move(failure_or_vals).value());
}

MlirLogicalResult mlirTpuApplyLayoutOp(int hardware_generation,
MlirOperation op,
MlirTpuI64TargetTuple target_shape) {
mlir::tpu::ApplyVectorLayoutContext ctx{
.hardware_generation = hardware_generation,
.target_shape = unwrap(target_shape)};
return wrap(mlir::tpu::applyLayoutOp(ctx, *unwrap(op)));
MlirLogicalResult mlirTpuApplyLayoutOp(MlirTpuApplyVectorLayoutContext ctx,
MlirOperation op) {
mlir::tpu::ApplyVectorLayoutContext unwrapped_ctx = unwrap(ctx);
return wrap(mlir::tpu::applyLayoutOp(unwrapped_ctx, *unwrap(op)));
}

MlirValue mlirTpuRelayout(MlirTpuInsertionPoint insertion_point, MlirValue val,
Expand Down
3 changes: 1 addition & 2 deletions jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,7 @@ MLIR_CAPI_EXPORTED MlirTpuValueArray mlirTpuDisassemble(
MlirValue val, MlirTpuI64TargetTuple target_shape);

MLIR_CAPI_EXPORTED MlirLogicalResult
mlirTpuApplyLayoutOp(int hardware_generation, MlirOperation op,
MlirTpuI64TargetTuple target_shape);
mlirTpuApplyLayoutOp(MlirTpuApplyVectorLayoutContext ctx, MlirOperation op);

// Returns null on failure
MLIR_CAPI_EXPORTED MlirValue
Expand Down

0 comments on commit 4a9cc9f

Please sign in to comment.