Skip to content

Commit

Permalink
Support all ranked tensor types in EvalSlice folder
Browse files Browse the repository at this point in the history
  • Loading branch information
mvpant committed Sep 23, 2024
1 parent 9bb28f8 commit 87b6fee
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 42 deletions.
44 changes: 33 additions & 11 deletions stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -433,17 +433,25 @@ func.func @eval_sign() -> tensor<3xi64> {
// -----

// CHECK-LABEL: func @eval_slice
func.func @eval_slice() -> tensor<2xi64> {
func.func @eval_slice() -> (tensor<2xi64>, tensor<1x2x1xi64>) {
// CHECK-NOT: stablehlo.slice
// CHECK: [[RESULT:%.*]] = stablehlo.constant dense<[1, 2]> : tensor<2xi64>
// CHECK: return [[RESULT]]
// CHECK: [[RESULT1:%.*]] = stablehlo.constant dense<[1, 2]> : tensor<2xi64>
// CHECK: [[RESULT2:%.*]] = stablehlo.constant dense<{{\[\[}}[15], [19]]]> : tensor<1x2x1xi64>
// CHECK: return [[RESULT1]], [[RESULT2]]
%0 = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
%1 = "stablehlo.slice"(%0) {
start_indices = array<i64: 0>,
limit_indices = array<i64: 2>,
strides = array<i64: 1>
} : (tensor<4xi64>) -> tensor<2xi64>
func.return %1 : tensor<2xi64>
%2 = stablehlo.constant dense<[[[10, 11, 12, 13], [14, 15, 16, 17], [18, 19, 20, 21]],
[[22, 23, 24, 25], [26, 27, 28, 29], [30, 31, 32, 33]]]> : tensor<2x3x4xi64>
%3 = "stablehlo.slice"(%2) {
start_indices = array<i64: 0, 1, 1>,
limit_indices = array<i64: 2, 3, 3>,
strides = array<i64: 3, 1, 2>
} : (tensor<2x3x4xi64>) -> tensor<1x2x1xi64>
func.return %1, %3 : tensor<2xi64>, tensor<1x2x1xi64>
}

// -----
Expand Down Expand Up @@ -496,18 +504,32 @@ func.func @eval_slice_unit_prefix() -> (tensor<1x1x1x2xi64>, tensor<1x1x1x2xi64>

// -----

// CHECK-LABEL: func @eval_slice_non_unit_prefix
func.func @eval_slice_non_unit_prefix() -> tensor<1x2x1xi64> {
// CHECK: stablehlo.constant {{.*}} : tensor<1x2x2xi64>
// CHECK: [[RESULT:%.*]] = stablehlo.slice{{.*}}
// CHECK-LABEL: func @eval_slice_zerodim
func.func @eval_slice_zerodim() -> tensor<0x2x1xi64> {
// CHECK: [[RESULT:%.*]] = stablehlo.constant dense<> : tensor<0x2x1xi64>
// CHECK: return [[RESULT]]
%0 = stablehlo.constant dense<[[[1, 2], [3, 4]]]> : tensor<1x2x2xi64>
%1 = "stablehlo.slice"(%0) {
start_indices = array<i64: 0, 0, 1>,
start_indices = array<i64: 1, 0, 1>,
limit_indices = array<i64: 1, 2, 2>,
strides = array<i64: 1, 1, 1>
} : (tensor<1x2x2xi64>) -> tensor<1x2x1xi64>
func.return %1 : tensor<1x2x1xi64>
} : (tensor<1x2x2xi64>) -> tensor<0x2x1xi64>
func.return %1 : tensor<0x2x1xi64>
}

// -----

// CHECK-LABEL: func @eval_slice_zerorank
func.func @eval_slice_zerorank() -> tensor<f32> {
// CHECK: [[RESULT:%.*]] = stablehlo.constant dense<3.300000e+01> : tensor<f32>
// CHECK: return [[RESULT]]
%0 = stablehlo.constant dense<33.0> : tensor<f32>
%1 = "stablehlo.slice"(%0) {
start_indices = array<i64>,
limit_indices = array<i64>,
strides = array<i64>
} : (tensor<f32>) -> tensor<f32>
func.return %1 : tensor<f32>
}

// -----
Expand Down
79 changes: 48 additions & 31 deletions stablehlo/transforms/StablehloAggressiveFolder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,41 @@ struct EvalSignOpPattern : public OpRewritePattern<SignOp> {
}
};

template <typename RangeType>
DenseElementsAttr sliceType(SliceOp& op, const RangeType& data) {
using ElementType = std::decay_t<decltype(*std::begin(data))>;

RankedTensorType operandType = op.getOperand().getType();
RankedTensorType resultType = op.getResult().getType();

const auto dimOffsets = computeStrides(operandType.getShape());
auto startIndices = op.getStartIndices();
auto limitIndices = op.getLimitIndices();
auto strides = op.getStrides();

const SmallVector<int64_t> startIndex(startIndices);
const SmallVector<int64_t> endIndex(limitIndices);

SmallVector<ElementType> result;
result.reserve(resultType.getNumElements());

SmallVector<int64_t> srcIndex(startIndex);
for (int64_t i = 0; i < resultType.getNumElements(); ++i) {
auto srcLinearIndex = linearize(srcIndex, dimOffsets);
result.push_back(data[srcLinearIndex]);
for (int64_t dim = srcIndex.size() - 1; dim >= 0; --dim) {
srcIndex[dim] += strides[dim];
if (srcIndex[dim] >= endIndex[dim])
srcIndex[dim] = startIndex[dim];
else
break;
}
}

return DenseElementsAttr::get(op.getResult().getType(),
ArrayRef<ElementType>(result));
}

struct EvalSliceOpPattern : public OpRewritePattern<SliceOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(SliceOp op,
Expand All @@ -529,45 +564,27 @@ struct EvalSliceOpPattern : public OpRewritePattern<SliceOp> {
if (failed(validateResultTypeForEval(rewriter, op, resultType)))
return failure();

if (resultType.getRank() < 1)
return rewriter.notifyMatchFailure(
op, "expected non-0 ranked tensor result type");

auto operand = cast<TypedValue<RankedTensorType>>(op.getOperand());
auto operand = op.getOperand();
RankedTensorType operandType = operand.getType();
if (!operandType.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "expected operand with static ranked tensor type");

// A ranked tensor type with unit dimension prefix of R-1 size is physically
// compatible with 1-dimensional type.
if (!llvm::all_of(resultType.getShape().drop_back(),
[](int64_t s) { return s == 1; }))
ElementsAttr els;
if (!matchPattern(operand, m_Constant(&els)))
return rewriter.notifyMatchFailure(
op, "expected 1-dimensional compatible result type");

SmallVector<APSInt> operandData;
if (failed(hlo::matchInts(operand, operandData)))
return rewriter.notifyMatchFailure(op, "expected constant operand");

const auto dimOffsets = computeSuffixProduct(operandType.getShape());
auto startIndices = op.getStartIndices();
auto limitIndices = op.getLimitIndices();
auto strides = op.getStrides();

int64_t start = 0;
for (size_t i = 0; i < startIndices.size(); ++i)
start += startIndices[i] * dimOffsets[i];
op, "expected constant integer or float operand");

auto slicedDim = operandType.getRank() - 1;
int64_t limit = start + limitIndices[slicedDim] - startIndices[slicedDim];
int64_t stride = strides[slicedDim];
SmallVector<APSInt> result;
for (auto i = start; i < limit; i += stride)
result.push_back(operandData[i]);
DenseElementsAttr resAttr;
if (auto data = els.tryGetValues<APInt>())
resAttr = sliceType(op, *data);
else if (auto data = els.tryGetValues<APFloat>())
resAttr = sliceType(op, *data);
else
return rewriter.notifyMatchFailure(op.getLoc(),
"unsupported element type");

rewriter.replaceOpWithNewOp<ConstantOp>(op,
getTensorAttr(resultType, result));
rewriter.replaceOpWithNewOp<ConstantOp>(op, resAttr);
return success();
}
};
Expand Down

0 comments on commit 87b6fee

Please sign in to comment.