diff --git a/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir b/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
index c24b7d4699..5a414ec307 100644
--- a/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
+++ b/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
@@ -433,17 +433,25 @@ func.func @eval_sign() -> tensor<3xi64> {
 
 // -----
 
 // CHECK-LABEL: func @eval_slice
-func.func @eval_slice() -> tensor<2xi64> {
+func.func @eval_slice() -> (tensor<2xi64>, tensor<1x2x1xi64>) {
   // CHECK-NOT: stablehlo.slice
-  // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<[1, 2]> : tensor<2xi64>
-  // CHECK: return [[RESULT]]
+  // CHECK: [[RESULT1:%.*]] = stablehlo.constant dense<[1, 2]> : tensor<2xi64>
+  // CHECK: [[RESULT2:%.*]] = stablehlo.constant dense<{{\[\[}}[15], [19]]]> : tensor<1x2x1xi64>
+  // CHECK: return [[RESULT1]], [[RESULT2]]
   %0 = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
   %1 = "stablehlo.slice"(%0) {
     start_indices = array<i64: 0>,
     limit_indices = array<i64: 2>,
     strides = array<i64: 1>
   } : (tensor<4xi64>) -> tensor<2xi64>
-  func.return %1 : tensor<2xi64>
+  %2 = stablehlo.constant dense<[[[10, 11, 12, 13], [14, 15, 16, 17], [18, 19, 20, 21]],
+                                 [[22, 23, 24, 25], [26, 27, 28, 29], [30, 31, 32, 33]]]> : tensor<2x3x4xi64>
+  %3 = "stablehlo.slice"(%2) {
+    start_indices = array<i64: 0, 1, 1>,
+    limit_indices = array<i64: 1, 3, 2>,
+    strides = array<i64: 1, 1, 1>
+  } : (tensor<2x3x4xi64>) -> tensor<1x2x1xi64>
+  func.return %1, %3 : tensor<2xi64>, tensor<1x2x1xi64>
 }
 
@@ -496,18 +504,32 @@ func.func @eval_slice_unit_prefix() -> (tensor<1x1x1x2xi64>, tensor<1x1x1x2xi64>
 
 // -----
 
-// CHECK-LABEL: func @eval_slice_non_unit_prefix
-func.func @eval_slice_non_unit_prefix() -> tensor<1x2x1xi64> {
-  // CHECK: stablehlo.constant {{.*}} : tensor<1x2x2xi64>
-  // CHECK: [[RESULT:%.*]] = stablehlo.slice{{.*}}
+// CHECK-LABEL: func @eval_slice_zerodim
+func.func @eval_slice_zerodim() -> tensor<0x2x1xi64> {
+  // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<> : tensor<0x2x1xi64>
   // CHECK: return [[RESULT]]
   %0 = stablehlo.constant dense<[[[1, 2], [3, 4]]]> : tensor<1x2x2xi64>
   %1 = "stablehlo.slice"(%0) {
-    start_indices = array<i64: 0, 0, 1>,
+    start_indices = array<i64: 1, 0, 1>,
     limit_indices = array<i64: 1, 2, 2>,
     strides = array<i64: 1, 1, 1>
-  } : (tensor<1x2x2xi64>) -> tensor<1x2x1xi64>
-  func.return %1 : tensor<1x2x1xi64>
+  } : (tensor<1x2x2xi64>) -> tensor<0x2x1xi64>
+  func.return %1 : tensor<0x2x1xi64>
+}
+
+// -----
+
+// CHECK-LABEL: func @eval_slice_zerorank
+func.func @eval_slice_zerorank() -> tensor<f64> {
+  // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<3.300000e+01> : tensor<f64>
+  // CHECK: return [[RESULT]]
+  %0 = stablehlo.constant dense<33.0> : tensor<f64>
+  %1 = "stablehlo.slice"(%0) {
+    start_indices = array<i64>,
+    limit_indices = array<i64>,
+    strides = array<i64>
+  } : (tensor<f64>) -> tensor<f64>
+  func.return %1 : tensor<f64>
 }
 
 // -----
diff --git a/stablehlo/transforms/StablehloAggressiveFolder.cpp b/stablehlo/transforms/StablehloAggressiveFolder.cpp
index 4d62d12f68..a5768f4c9f 100644
--- a/stablehlo/transforms/StablehloAggressiveFolder.cpp
+++ b/stablehlo/transforms/StablehloAggressiveFolder.cpp
@@ -521,6 +521,41 @@ struct EvalSignOpPattern : public OpRewritePattern<SignOp> {
   }
 };
 
+template <typename RangeType>
+DenseElementsAttr sliceType(SliceOp& op, const RangeType& data) {
+  using ElementType = std::decay_t<decltype(*std::begin(data))>;
+
+  RankedTensorType operandType = op.getOperand().getType();
+  RankedTensorType resultType = op.getResult().getType();
+
+  const auto dimOffsets = computeStrides(operandType.getShape());
+  auto startIndices = op.getStartIndices();
+  auto limitIndices = op.getLimitIndices();
+  auto strides = op.getStrides();
+
+  const SmallVector<int64_t> startIndex(startIndices);
+  const SmallVector<int64_t> endIndex(limitIndices);
+
+  SmallVector<ElementType> result;
+  result.reserve(resultType.getNumElements());
+
+  SmallVector<int64_t> srcIndex(startIndex);
+  for (int64_t i = 0; i < resultType.getNumElements(); ++i) {
+    auto srcLinearIndex = linearize(srcIndex, dimOffsets);
+    result.push_back(data[srcLinearIndex]);
+    for (int64_t dim = srcIndex.size() - 1; dim >= 0; --dim) {
+      srcIndex[dim] += strides[dim];
+      if (srcIndex[dim] >= endIndex[dim])
+        srcIndex[dim] = startIndex[dim];
+      else
+        break;
+    }
+  }
+
+  return DenseElementsAttr::get(op.getResult().getType(),
+                                ArrayRef<ElementType>(result));
+}
+
 struct EvalSliceOpPattern : public OpRewritePattern<SliceOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(SliceOp op,
@@ -529,45 +564,27 @@ struct EvalSliceOpPattern : public OpRewritePattern<SliceOp> {
     if (failed(validateResultTypeForEval(rewriter, op, resultType)))
       return failure();
 
-    if (resultType.getRank() < 1)
-      return rewriter.notifyMatchFailure(
-          op, "expected non-0 ranked tensor result type");
-
-    auto operand = cast<TypedValue<RankedTensorType>>(op.getOperand());
+    auto operand = op.getOperand();
     RankedTensorType operandType = operand.getType();
     if (!operandType.hasStaticShape())
       return rewriter.notifyMatchFailure(
          op, "expected operand with static ranked tensor type");
 
-    // A ranked tensor type with unit dimension prefix of R-1 size is physically
-    // compatible with 1-dimensional type.
-    if (!llvm::all_of(resultType.getShape().drop_back(),
-                      [](int64_t s) { return s == 1; }))
+    ElementsAttr els;
+    if (!matchPattern(operand, m_Constant(&els)))
       return rewriter.notifyMatchFailure(
-          op, "expected 1-dimensional compatible result type");
-
-    SmallVector<APSInt> operandData;
-    if (failed(hlo::matchInts(operand, operandData)))
-      return rewriter.notifyMatchFailure(op, "expected constant operand");
-
-    const auto dimOffsets = computeSuffixProduct(operandType.getShape());
-    auto startIndices = op.getStartIndices();
-    auto limitIndices = op.getLimitIndices();
-    auto strides = op.getStrides();
-
-    int64_t start = 0;
-    for (size_t i = 0; i < startIndices.size(); ++i)
-      start += startIndices[i] * dimOffsets[i];
+          op, "expected constant integer or float operand");
 
-    auto slicedDim = operandType.getRank() - 1;
-    int64_t limit = start + limitIndices[slicedDim] - startIndices[slicedDim];
-    int64_t stride = strides[slicedDim];
-    SmallVector<APSInt> result;
-    for (auto i = start; i < limit; i += stride)
-      result.push_back(operandData[i]);
+    DenseElementsAttr resAttr;
+    if (auto data = els.tryGetValues<APInt>())
+      resAttr = sliceType(op, *data);
+    else if (auto data = els.tryGetValues<APFloat>())
+      resAttr = sliceType(op, *data);
+    else
+      return rewriter.notifyMatchFailure(op.getLoc(),
+                                         "unsupported element type");
 
-    rewriter.replaceOpWithNewOp<ConstantOp>(op,
-                                            getTensorAttr(resultType, result));
+    rewriter.replaceOpWithNewOp<ConstantOp>(op, resAttr);
     return success();
   }
 };