Skip to content

Commit

Permalink
Wrap indices around max value for tosa.gather creation.
Browse files Browse the repository at this point in the history
  • Loading branch information
sahas3 committed Jan 5, 2025
1 parent 714b7fc commit 8fcdfbe
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 100 deletions.
39 changes: 0 additions & 39 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4006,41 +4006,6 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
return success();
}

Value wrapIndicesAroundMax(Value index, int maxIndex, Operation *op,
ConversionPatternRewriter &rewriter) {
// performs the operation : index = index % maxIndex to wrap index around
// maxIndex

auto maxIndexValue =
tosa::getConstTensor<int32_t>(rewriter, op, maxIndex, {}).value();
auto maxIndexValueMinusOne =
tosa::getConstTensor<int32_t>(rewriter, op, maxIndex - 1, {}).value();

auto indexType = dyn_cast<RankedTensorType>(index.getType());
auto boolType = indexType.clone(rewriter.getIntegerType(1));

auto isBeyondMaxIndices = tosa::CreateOpAndInfer<tosa::GreaterOp>(
rewriter, op->getLoc(), boolType, index, maxIndexValueMinusOne);
auto wrappedBeyondMaxIndicesQuotient =
tosa::CreateOpAndInfer<tosa::IntDivOp>(rewriter, op->getLoc(), indexType,
index, maxIndexValue)
.getResult();
auto wrappedBeyondMaxIndicesQuotientTimesIndices =
tosa::CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), indexType,
wrappedBeyondMaxIndicesQuotient,
maxIndexValue, /*shift=*/0)
.getResult();
auto wrappedBeyondMaxIndices =
tosa::CreateOpAndInfer<tosa::SubOp>(
rewriter, op->getLoc(), indexType, index,
wrappedBeyondMaxIndicesQuotientTimesIndices)
.getResult();

return tosa::CreateOpAndInfer<tosa::SelectOp>(rewriter, op->getLoc(),
indexType, isBeyondMaxIndices,
wrappedBeyondMaxIndices, index);
}

template <>
LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
AtenIndexSelectOp op, OpAdaptor adaptor,
Expand Down Expand Up @@ -4084,10 +4049,6 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
}

int64_t selfNumElems = std::accumulate(inputShape.begin(), inputShape.end(),
1, std::multiplies<int64_t>());
index = wrapIndicesAroundMax(index, selfNumElems, op, rewriter);

// Get positive dim
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
Expand Down
40 changes: 39 additions & 1 deletion lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,41 @@ std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
return indicesTf.getResult();
}

Value wrapIndicesAroundMax(Value index, int maxIndex, Operation *op,
PatternRewriter &rewriter) {
// performs the operation : index = index % maxIndex to wrap index around
// maxIndex

auto maxIndexValue =
tosa::getConstTensor<int32_t>(rewriter, op, maxIndex, {}).value();
auto maxIndexValueMinusOne =
tosa::getConstTensor<int32_t>(rewriter, op, maxIndex - 1, {}).value();

auto indexType = dyn_cast<RankedTensorType>(index.getType());
auto boolType = indexType.clone(rewriter.getIntegerType(1));

auto isBeyondMaxIndices = tosa::CreateOpAndInfer<tosa::GreaterOp>(
rewriter, op->getLoc(), boolType, index, maxIndexValueMinusOne);
auto wrappedBeyondMaxIndicesQuotient =
tosa::CreateOpAndInfer<tosa::IntDivOp>(rewriter, op->getLoc(), indexType,
index, maxIndexValue)
.getResult();
auto wrappedBeyondMaxIndicesQuotientTimesIndices =
tosa::CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), indexType,
wrappedBeyondMaxIndicesQuotient,
maxIndexValue, /*shift=*/0)
.getResult();
auto wrappedBeyondMaxIndices =
tosa::CreateOpAndInfer<tosa::SubOp>(
rewriter, op->getLoc(), indexType, index,
wrappedBeyondMaxIndicesQuotientTimesIndices)
.getResult();

return tosa::CreateOpAndInfer<tosa::SelectOp>(rewriter, op->getLoc(),
indexType, isBeyondMaxIndices,
wrappedBeyondMaxIndices, index);
}

// Lowers Gather operators to a sequence of TOSA ops.
// taken from
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc
Expand Down Expand Up @@ -403,14 +438,17 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
flattenedIndicesReduceOp.getResult(),
rewriter.getDenseI64ArrayAttr(tosaIndicesShape));

auto wrappedIndices = wrapIndicesAroundMax(tosaIndicesReshapeOp.getResult(),
K + 1, op, rewriter);

// Now the gather op itself
// %9 = "tosa.gather"(%2, %7) : (tensor<1x12x1xf32>, tensor<1x8xi32>) ->
// tensor<1x8x1xf32>
auto tosaGatherOp = tosa::CreateOpAndInfer<tosa::GatherOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(tosaGatherResultShape,
resultType.getElementType()),
tosaValuesReshapeOp.getResult(), tosaIndicesReshapeOp.getResult());
tosaValuesReshapeOp.getResult(), wrappedIndices);

// Finally, reshape back to the original output shape of [Indices,
// ParamChannels]. %10 = "tosa.reshape"(%9) {new_shape = [1, 4, 2]} :
Expand Down
Loading

0 comments on commit 8fcdfbe

Please sign in to comment.