Skip to content

Commit

Permalink
Add patterns in ShapeLegalizeToStablehlo for tensor::ExtractOp (#2075)
Browse files Browse the repository at this point in the history
Add patterns in ShapeLegalizeToStablehlo for tensor::ExtractOp

1. Add support for tensor.Extract, which is lowered to mhlo.slice
2. Support a special case from TF graph:
   tensor.extract tensor<?xi32> -> i32
   arith.index_cast i32 -> index
   This is lower to:
   mhlo.slice tensor<?xi32> -> tensor<1xi32>
   mhlo.reshape tensor<ixi32> -> tensor<i32>
   unrealized_conversion_cast tensor<i32> -> i32
   unrealized_conversion_cast i32 -> tensor<i32>
   unrealized_conversion_cast tensor<i32> -> index

Co-authored-by: Andy Wan <[email protected]>
  • Loading branch information
GleasonK and quanwanandy authored Mar 6, 2024
1 parent 2cdab42 commit 7c4ccac
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 8 deletions.
67 changes: 59 additions & 8 deletions stablehlo/tests/shape_legalize_to_stablehlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,16 @@ func.func @index_cast_scalar_index_to_i64(%arg0: index) -> i64 {

// -----

func.func @index_cast_scalar_i32_to_index(%arg0: i32) -> index {
// CHECK: %[[CAST_I32:.*]] = builtin.unrealized_conversion_cast %arg0 : i32 to tensor<i32>
// CHECK-NEXT: %[[CAST_INDEX:.*]] = builtin.unrealized_conversion_cast %[[CAST_I32]] : tensor<i32> to index
// CHECK-NEXT: return %[[CAST_INDEX]] : index
%0 = arith.index_cast %arg0 : i32 to index
return %0 : index
}

// -----

func.func @index_cast_index_to_i8(%arg0: tensor<2xindex>) -> tensor<2xi8> {
// expected-error@+1 {{failed to legalize operation 'arith.index_cast' that was explicitly marked illegal}}
%0 = arith.index_cast %arg0 : tensor<2xindex> to tensor<2xi8>
Expand All @@ -295,14 +305,6 @@ func.func @index_cast_i8_to_index(%arg0: tensor<2xi8>) -> tensor<2xindex> {

// -----

func.func @index_cast_scalar_i32_to_index(%arg0: i32) -> index {
// expected-error@+1 {{failed to legalize operation 'arith.index_cast' that was explicitly marked illegal}}
%0 = arith.index_cast %arg0 : i32 to index
return %0 : index
}

// -----

// CHECK-LABEL: func @muli
func.func @muli(%arg0: index, %arg1: index) -> index {
%0 = arith.muli %arg0, %arg1 : index
Expand Down Expand Up @@ -336,3 +338,52 @@ func.func @muli_i32(%arg0: i32, %arg1: i32) -> i32 {
%0 = arith.muli %arg0, %arg1 : i32
return %0 : i32
}

// -----

// CHECK-LABEL: func @tensor_extract
func.func @tensor_extract(%arg0: tensor<3x3xindex>) -> index {
%c1 = arith.constant 0 : index
%c2 = arith.constant 1 : index
%0 = tensor.extract %arg0[%c1, %c2] : tensor<3x3xindex>
return %0 : index
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<3x3xindex> to tensor<3x3xi32>
// CHECK: %[[SLICE:.*]] = stablehlo.slice %[[CAST]] [0:1, 1:2] : (tensor<3x3xi32>) -> tensor<1x1xi32>
// CHECK-NEXT: %[[RESHAPE:.*]] = stablehlo.reshape %[[SLICE]] : (tensor<1x1xi32>) -> tensor<i32>
// CHECK-NEXT: %[[RES_INDEX:.*]] = builtin.unrealized_conversion_cast %[[RESHAPE]] : tensor<i32> to index
// CHECK-NEXT: return %[[RES_INDEX]] : index
}

// -----

// CHECK-LABEL: func @tensor_extract_i32
func.func @tensor_extract_i32(%arg0: tensor<3x3xi32>) -> i32 {
%c1 = arith.constant 0 : index
%c2 = arith.constant 1 : index
%0 = tensor.extract %arg0[%c1, %c2] : tensor<3x3xi32>
return %0 : i32
// CHECK: %[[SLICE:.*]] = stablehlo.slice %arg0 [0:1, 1:2] : (tensor<3x3xi32>) -> tensor<1x1xi32>
// CHECK-NEXT: %[[RESHAPE:.*]] = stablehlo.reshape %[[SLICE]] : (tensor<1x1xi32>) -> tensor<i32>
// CHECK-NEXT: %[[RES_I32:.*]] = builtin.unrealized_conversion_cast %[[RESHAPE]] : tensor<i32> to i32
// CHECK-NEXT: return %[[RES_I32]] : i32
}

// -----

func.func @tensor_extract_out_of_range(%arg0: tensor<3x3xindex>) -> index {
%c1 = arith.constant 4 : index
%c2 = arith.constant 4 : index
// expected-error@+1 {{failed to legalize operation 'tensor.extract' that was explicitly marked illegal}}
%0 = tensor.extract %arg0[%c1, %c2] : tensor<3x3xindex>
return %0 : index
}

// -----

func.func @tensor_extract_dynamic(%arg0: tensor<?x3xindex>) -> index {
%c1 = arith.constant 0 : index
%c2 = arith.constant 2 : index
// expected-error@+1 {{failed to legalize operation 'tensor.extract' that was explicitly marked illegal}}
%0 = tensor.extract %arg0[%c1, %c2] : tensor<?x3xindex>
return %0 : index
}
68 changes: 68 additions & 0 deletions stablehlo/transforms/ShapeLegalizeToStablehlo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,21 @@ struct ConvertIndexCastOpPattern : public OpRewritePattern<arith::IndexCastOp> {
op.getLoc(), op.getOut().getType(), result));
return success();
}
if (!op.getIn().getType().isa<ShapedType>() &&
isIndexOrShapedOfIndex(op.getOut())) {
// Handle a special case of i32 -> index.
// This is converted to the following sequence:
// unrealized_conversion_cast i32 -> tensor<i32>
// unrealized_conversion_cast tensor<i32> -> index
result = rewriter
.create<UnrealizedConversionCastOp>(
op.getLoc(), RankedTensorType::get({}, result.getType()),
result)
.getResult(0);
rewriter.replaceOp(op, rewriter.create<UnrealizedConversionCastOp>(
op.getLoc(), op.getOut().getType(), result));
return success();
}

if (isIndexOrShapedOfIndex(result)) {
result = castToI32(rewriter, op.getLoc(), result);
Expand Down Expand Up @@ -439,6 +454,58 @@ struct ConvertTensorDimPattern : public OpRewritePattern<tensor::DimOp> {
}
};

struct ConvertTensorExtractPattern
: public OpRewritePattern<tensor::ExtractOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::ExtractOp op,
PatternRewriter& rewriter) const override {
SmallVector<int64_t> indices;
auto tensorType = op.getTensor().getType();
// We only support getting static indices.
for (auto index : op.getIndices()) {
auto constIndex =
dyn_cast_or_null<arith::ConstantIndexOp>(index.getDefiningOp());
if (!constIndex)
return rewriter.notifyMatchFailure(op, "expected constant index op");

// Check if the index is out of range.
int idx = indices.size();
if (tensorType.isDynamicDim(idx) ||
constIndex.value() >= tensorType.getDimSize(idx))
return rewriter.notifyMatchFailure(op, "index out of range");

indices.push_back(constIndex.value());
}
auto input = castToI32(rewriter, op.getLoc(), op.getTensor());
auto startIndices = rewriter.getDenseI64ArrayAttr(indices);
for (auto& index : indices) {
index += 1;
}
auto limitIndices = rewriter.getDenseI64ArrayAttr(indices);

Value extractedTensor = rewriter.create<SliceOp>(
op.getLoc(), input, startIndices, limitIndices,
/*strides=*/
rewriter.getDenseI64ArrayAttr(SmallVector<int64_t>(indices.size(), 1)));
Value extractedScalarTensor = rewriter.create<ReshapeOp>(
op.getLoc(), RankedTensorType::get({}, rewriter.getI32Type()),
extractedTensor);
if (getElementTypeOrSelf(op.getResult().getType()).isIndex()) {
auto extractedIndex =
castToIndex(rewriter, op.getLoc(), extractedScalarTensor);
rewriter.replaceOp(op, extractedIndex);
} else {
// For the special case when the input is a i32 tensor and output is i32,
// convert the result back to i32 to be consistent:
// unrealized_conversion_cast tensor<i32> -> i32
rewriter.replaceOp(op, rewriter.create<UnrealizedConversionCastOp>(
op.getLoc(), op.getResult().getType(),
extractedScalarTensor));
}
return success();
}
};

struct ConvertTensorFromElementsPattern
: public OpRewritePattern<tensor::FromElementsOp> {
using OpRewritePattern::OpRewritePattern;
Expand Down Expand Up @@ -599,6 +666,7 @@ void populateShapeToStablehloPatterns(MLIRContext* context,
patterns->add<CastOperandsPattern<DynamicBroadcastInDimOp>>(context);
patterns->add<CastOperandsPattern<DynamicReshapeOp>>(context);
patterns->add<ConvertTensorDimPattern>(context);
patterns->add<ConvertTensorExtractPattern>(context);
patterns->add<ConvertTensorFromElementsPattern>(context);
}

Expand Down

0 comments on commit 7c4ccac

Please sign in to comment.