From 7c4ccacffca0c55a43a9b1982086419a8831c5d3 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Wed, 6 Mar 2024 15:29:40 -0600 Subject: [PATCH] Add patterns in ShapeLegalizeToStablehlo for tensor::ExtractOp (#2075) Add patterns in ShapeLegalizeToStablehlo for tensor::ExtractOp 1. Add support for tensor.Extract, which is lowered to mhlo.slice 2. Support a special case from TF graph: tensor.extract tensor -> i32 arith.index_cast i32 -> index This is lower to: mhlo.slice tensor -> tensor<1xi32> mhlo.reshape tensor -> tensor unrealized_conversion_cast tensor -> i32 unrealized_conversion_cast i32 -> tensor unrealized_conversion_cast tensor -> index Co-authored-by: Andy Wan --- .../tests/shape_legalize_to_stablehlo.mlir | 67 +++++++++++++++--- .../transforms/ShapeLegalizeToStablehlo.cpp | 68 +++++++++++++++++++ 2 files changed, 127 insertions(+), 8 deletions(-) diff --git a/stablehlo/tests/shape_legalize_to_stablehlo.mlir b/stablehlo/tests/shape_legalize_to_stablehlo.mlir index f891ddb92ea..4869f2d233e 100644 --- a/stablehlo/tests/shape_legalize_to_stablehlo.mlir +++ b/stablehlo/tests/shape_legalize_to_stablehlo.mlir @@ -279,6 +279,16 @@ func.func @index_cast_scalar_index_to_i64(%arg0: index) -> i64 { // ----- +func.func @index_cast_scalar_i32_to_index(%arg0: i32) -> index { + // CHECK: %[[CAST_I32:.*]] = builtin.unrealized_conversion_cast %arg0 : i32 to tensor + // CHECK-NEXT: %[[CAST_INDEX:.*]] = builtin.unrealized_conversion_cast %[[CAST_I32]] : tensor to index + // CHECK-NEXT: return %[[CAST_INDEX]] : index + %0 = arith.index_cast %arg0 : i32 to index + return %0 : index +} + +// ----- + func.func @index_cast_index_to_i8(%arg0: tensor<2xindex>) -> tensor<2xi8> { // expected-error@+1 {{failed to legalize operation 'arith.index_cast' that was explicitly marked illegal}} %0 = arith.index_cast %arg0 : tensor<2xindex> to tensor<2xi8> @@ -295,14 +305,6 @@ func.func @index_cast_i8_to_index(%arg0: tensor<2xi8>) -> tensor<2xindex> { // ----- -func.func @index_cast_scalar_i32_to_index(%arg0: i32) -> index { - // expected-error@+1 {{failed to legalize operation 'arith.index_cast' that was explicitly marked illegal}} - %0 = arith.index_cast %arg0 : i32 to index - return %0 : index -} - -// ----- - // CHECK-LABEL: func @muli func.func @muli(%arg0: index, %arg1: index) -> index { %0 = arith.muli %arg0, %arg1 : index @@ -336,3 +338,52 @@ func.func @muli_i32(%arg0: i32, %arg1: i32) -> i32 { %0 = arith.muli %arg0, %arg1 : i32 return %0 : i32 } + +// ----- + +// CHECK-LABEL: func @tensor_extract +func.func @tensor_extract(%arg0: tensor<3x3xindex>) -> index { + %c1 = arith.constant 0 : index + %c2 = arith.constant 1 : index + %0 = tensor.extract %arg0[%c1, %c2] : tensor<3x3xindex> + return %0 : index + // CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<3x3xindex> to tensor<3x3xi32> + // CHECK: %[[SLICE:.*]] = stablehlo.slice %[[CAST]] [0:1, 1:2] : (tensor<3x3xi32>) -> tensor<1x1xi32> + // CHECK-NEXT: %[[RESHAPE:.*]] = stablehlo.reshape %[[SLICE]] : (tensor<1x1xi32>) -> tensor + // CHECK-NEXT: %[[RES_INDEX:.*]] = builtin.unrealized_conversion_cast %[[RESHAPE]] : tensor to index + // CHECK-NEXT: return %[[RES_INDEX]] : index +} + +// ----- + +// CHECK-LABEL: func @tensor_extract_i32 +func.func @tensor_extract_i32(%arg0: tensor<3x3xi32>) -> i32 { + %c1 = arith.constant 0 : index + %c2 = arith.constant 1 : index + %0 = tensor.extract %arg0[%c1, %c2] : tensor<3x3xi32> + return %0 : i32 + // CHECK: %[[SLICE:.*]] = stablehlo.slice %arg0 [0:1, 1:2] : (tensor<3x3xi32>) -> tensor<1x1xi32> + // CHECK-NEXT: %[[RESHAPE:.*]] = stablehlo.reshape %[[SLICE]] : (tensor<1x1xi32>) -> tensor + // CHECK-NEXT: %[[RES_I32:.*]] = builtin.unrealized_conversion_cast %[[RESHAPE]] : tensor to i32 + // CHECK-NEXT: return %[[RES_I32]] : i32 +} + +// ----- + +func.func @tensor_extract_out_of_range(%arg0: tensor<3x3xindex>) -> index { + %c1 = arith.constant 4 : index + %c2 = arith.constant 4 : index + // expected-error@+1 {{failed to legalize operation 'tensor.extract' that was explicitly marked illegal}} + %0 = tensor.extract %arg0[%c1, %c2] : tensor<3x3xindex> + return %0 : index +} + +// ----- + +func.func @tensor_extract_dynamic(%arg0: tensor) -> index { + %c1 = arith.constant 0 : index + %c2 = arith.constant 2 : index + // expected-error@+1 {{failed to legalize operation 'tensor.extract' that was explicitly marked illegal}} + %0 = tensor.extract %arg0[%c1, %c2] : tensor + return %0 : index +} diff --git a/stablehlo/transforms/ShapeLegalizeToStablehlo.cpp b/stablehlo/transforms/ShapeLegalizeToStablehlo.cpp index ccf3b2fed2d..d8b54428b09 100644 --- a/stablehlo/transforms/ShapeLegalizeToStablehlo.cpp +++ b/stablehlo/transforms/ShapeLegalizeToStablehlo.cpp @@ -324,6 +324,21 @@ struct ConvertIndexCastOpPattern : public OpRewritePattern { op.getLoc(), op.getOut().getType(), result)); return success(); } + if (!op.getIn().getType().isa() && + isIndexOrShapedOfIndex(op.getOut())) { + // Handle a special case of i32 -> index. + // This is converted to the following sequence: + // unrealized_conversion_cast i32 -> tensor + // unrealized_conversion_cast tensor -> index + result = rewriter + .create( + op.getLoc(), RankedTensorType::get({}, result.getType()), + result) + .getResult(0); + rewriter.replaceOp(op, rewriter.create( + op.getLoc(), op.getOut().getType(), result)); + return success(); + } if (isIndexOrShapedOfIndex(result)) { result = castToI32(rewriter, op.getLoc(), result); @@ -439,6 +454,58 @@ struct ConvertTensorDimPattern : public OpRewritePattern { } }; +struct ConvertTensorExtractPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensor::ExtractOp op, + PatternRewriter& rewriter) const override { + SmallVector indices; + auto tensorType = op.getTensor().getType(); + // We only support getting static indices. + for (auto index : op.getIndices()) { + auto constIndex = + dyn_cast_or_null(index.getDefiningOp()); + if (!constIndex) + return rewriter.notifyMatchFailure(op, "expected constant index op"); + + // Check if the index is out of range. + int idx = indices.size(); + if (tensorType.isDynamicDim(idx) || + constIndex.value() >= tensorType.getDimSize(idx)) + return rewriter.notifyMatchFailure(op, "index out of range"); + + indices.push_back(constIndex.value()); + } + auto input = castToI32(rewriter, op.getLoc(), op.getTensor()); + auto startIndices = rewriter.getDenseI64ArrayAttr(indices); + for (auto& index : indices) { + index += 1; + } + auto limitIndices = rewriter.getDenseI64ArrayAttr(indices); + + Value extractedTensor = rewriter.create( + op.getLoc(), input, startIndices, limitIndices, + /*strides=*/ + rewriter.getDenseI64ArrayAttr(SmallVector(indices.size(), 1))); + Value extractedScalarTensor = rewriter.create( + op.getLoc(), RankedTensorType::get({}, rewriter.getI32Type()), + extractedTensor); + if (getElementTypeOrSelf(op.getResult().getType()).isIndex()) { + auto extractedIndex = + castToIndex(rewriter, op.getLoc(), extractedScalarTensor); + rewriter.replaceOp(op, extractedIndex); + } else { + // For the special case when the input is a i32 tensor and output is i32, + // convert the result back to i32 to be consistent: + // unrealized_conversion_cast tensor -> i32 + rewriter.replaceOp(op, rewriter.create( + op.getLoc(), op.getResult().getType(), + extractedScalarTensor)); + } + return success(); + } +}; + struct ConvertTensorFromElementsPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -599,6 +666,7 @@ void populateShapeToStablehloPatterns(MLIRContext* context, patterns->add>(context); patterns->add>(context); patterns->add(context); + patterns->add(context); patterns->add(context); }