
Add interpreter for ConvolutionOp #1964

Merged 3 commits on Apr 4, 2024
6 changes: 4 additions & 2 deletions docs/spec.md
@@ -2317,16 +2317,18 @@ For quantized types, performs `dequantize_op_quantize(
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
-feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
+feature_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
-} : (tensor<1x4x4x1xi32>, tensor<3x3x1x1xi32>) -> tensor<1x2x2x1xi32>
+} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
```

+&nbsp;[More Examples](../stablehlo/tests/interpret/convolution.mlir)

### cosine

#### Semantics
2 changes: 1 addition & 1 deletion docs/status.md
@@ -70,7 +70,7 @@ one of the following tracking labels.
| concatenate | yes | yes | yes | yes | yes |
| constant | yes | yes | yes | yes | yes |
| convert | yes | yes | infeasible | yes | yes |
-| convolution | yes | yes | infeasible | revisit | no |
+| convolution | yes | yes | infeasible | revisit | yes |
| cosine | yes | yes | yes | yes | yes |
| count_leading_zeros | yes | yes | yes | yes | yes |
| create_token | no | yes\* | yes\* | yes | revisit |
4 changes: 2 additions & 2 deletions stablehlo/reference/Index.cpp
@@ -44,14 +44,14 @@ bool Sizes::inBounds(const Sizes &bounds) const {

IndexSpaceIterator Sizes::index_begin() const {
if (any_of(*this, [](int64_t dimSize) { return dimSize == 0; }))
-return IndexSpaceIterator(*this, std::nullopt);
+return IndexSpaceIterator(*this);

Index initialIndex(size());
return IndexSpaceIterator(*this, initialIndex);
}

IndexSpaceIterator Sizes::index_end() const {
-return IndexSpaceIterator(*this, std::nullopt);
+return IndexSpaceIterator(*this);
}

Sizes operator+(const Sizes &x, const Sizes &y) {
6 changes: 4 additions & 2 deletions stablehlo/reference/Index.h
@@ -121,9 +121,11 @@ using Index = Sizes;
class IndexSpaceIterator {
public:
/// \name Constructor
+IndexSpaceIterator(Sizes shape) : shape_(shape) { index_ = std::nullopt; }
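// Note: this one-argument form constructs the past-the-end iterator
// (index_ == std::nullopt), as used by Sizes::index_end() in Index.cpp above.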

IndexSpaceIterator(Sizes shape, std::optional<Index> index)
-: shape_(shape), index_(index) {
-if (index && !index->inBounds(shape)) index_ = std::nullopt;
+: shape_(shape), index_(std::nullopt) {
+if (index && index->inBounds(shape)) index_ = index;
}

/// Get the current index.
279 changes: 279 additions & 0 deletions stablehlo/reference/Ops.cpp
@@ -54,6 +54,25 @@ Index evalIndex(Tensor tensor) {
return result;
}

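// Infers the result type, then forwards to the main evalDotGeneralOp.
// evalConvolutionOp below uses this overload to compute the dot product of
// one lhs window with the kernel; e.g., for the spec example, a [1, 3, 3, 1]
// window contracted with the [3, 3, 1, 1] kernel over dimensions {1, 2, 3}
// and {0, 1, 2} yields a [1, 1] (batch x output feature) result.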
Tensor evalDotGeneralOp(const Tensor &lhs, const Tensor &rhs,
const Axes &lhsContractingDimensions,
const Axes &rhsContractingDimensions) {
SmallVector<ShapedTypeComponents> inferredDotGeneralType;
if (failed(hlo::inferDotGeneralOp(
/*location=*/{}, lhs.getType(), rhs.getType(),
/*lhsBatchingDimensions=*/{}, /*rhsBatchingDimensions*/ {},
lhsContractingDimensions, rhsContractingDimensions,
/*precisionConfig=*/{}, inferredDotGeneralType)))
report_fatal_error(
invalidArgument("Could not infer DotGeneralOp's return type"));

return evalDotGeneralOp(
lhs, rhs, /*lhsBatchingDimensions=*/{}, /*rhsBatchingDimensions*/ {},
lhsContractingDimensions, rhsContractingDimensions,
RankedTensorType::get(inferredDotGeneralType[0].getDims(),
lhs.getElementType()));
}

Tensor evalPadOp(const Tensor &operand, const Tensor &paddingValue,
const Sizes &edgePaddingLow, const Sizes &edgePaddingHigh,
const Sizes &interiorPadding) {
@@ -143,6 +162,12 @@ Tensor evalSliceOp(const Tensor &operand, const Index &index) {
return evalSliceOp(operand, start, limit, strides);
}

// Returns the elements of `arr` at the positions listed in `indices`; e.g.,
// extractElements({1, 4, 4, 1}, {1, 2}) returns {4, 4}.
Sizes extractElements(ArrayRef<int64_t> arr, ArrayRef<int64_t> indices) {
Sizes elements;
for (int64_t index : indices) elements.push_back(arr[index]);
return elements;
}

void failOnDecomposableOp(Operation &op) {
report_fatal_error(invalidArgument(
"Operation %s is unsupported at the moment. "
@@ -153,6 +178,13 @@ void failOnDecomposableOp(Operation &op) {
op.getName().getStringRef().str().c_str()));
}

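// Builds a DenseIntElementsAttr of the given shape; evalConvolutionOp below
// uses it to wrap the flattened padding pairs into the [N, 2] i64 attribute
// that hlo::inferConvolutionOp expects.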
template <typename T>
DenseIntElementsAttr getDenseIntElementsAttr(Type elementType, T values,
SmallVector<int64_t> valuesShape) {
return DenseIntElementsAttr::get(
RankedTensorType::get(valuesShape, elementType), values);
}

SmallVector<SmallVector<uint32_t>> getReplicaGroups(
DenseIntElementsAttr replicaGroupsAttr) {
auto replicaGroupsShape = replicaGroupsAttr.getShapedType().getShape();
@@ -168,6 +200,65 @@
return replicaGroups;
}

Tensor evalConvolutionOp(
const Tensor &lhs, const Tensor &rhs, ArrayRef<int64_t> windowStrides,
ArrayRef<std::pair<int64_t, int64_t>> padding,
ArrayRef<int64_t> lhsDilation, ArrayRef<int64_t> rhsDilation,
ArrayRef<bool> windowReversal, Axis inputBatchDimension,
Axis inputFeatureDimension, const Axes &inputSpatialDimensions,
Axis kernelInputFeatureDimension, Axis kernelOutputFeatureDimension,
const Axes &kernelSpatialDimensions, Axis outputBatchDimension,
Axis outputFeatureDimension, const Axes &outputSpatialDimensions,
int64_t featureGroupCount, int64_t batchGroupCount,
std::optional<ArrayAttr> precisionConfig, ShapedType resultType) {
SmallVector<int64_t> paddingVector;
for (auto pair : padding) {
paddingVector.push_back(pair.first);
paddingVector.push_back(pair.second);
}

SmallVector<ShapedTypeComponents> inferredConvolutionType;
if (failed(hlo::inferConvolutionOp(
/*location=*/{}, lhs.getType(), rhs.getType(), windowStrides,
/*padding=*/
getDenseIntElementsAttr(
IntegerType::get(lhs.getType().getContext(), 64), paddingVector,
SmallVector<int64_t>(padding.size(), 2)),
lhsDilation, rhsDilation, windowReversal, inputBatchDimension,
inputFeatureDimension, ArrayRef<int64_t>(inputSpatialDimensions),
kernelInputFeatureDimension, kernelOutputFeatureDimension,
ArrayRef<int64_t>(kernelSpatialDimensions), outputBatchDimension,
outputFeatureDimension, ArrayRef<int64_t>(outputSpatialDimensions),
featureGroupCount, batchGroupCount,
/*precisionConfig=*/{}, inferredConvolutionType)))
report_fatal_error(
invalidArgument("Could not infer ConvolutionOp's return type"));

return evalConvolutionOp(
lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation,
windowReversal, inputBatchDimension, inputFeatureDimension,
inputSpatialDimensions, kernelInputFeatureDimension,
kernelOutputFeatureDimension, kernelSpatialDimensions,
outputBatchDimension, outputFeatureDimension, outputSpatialDimensions,
featureGroupCount, batchGroupCount,
RankedTensorType::get(inferredConvolutionType[0].getDims(),
resultType.getElementType()));
}

// Returns `result` with the effect of applying `permutation`
// (= [dimA] + dimsB + [dimC]) to `input` (= [n] + hw + [c]) such that
// result[permutation[i]] = input[i].
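// For example, concatAndPermute(N, {H, W}, C, {0, 2, 3, 1}) returns
// {N, C, H, W}.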
template <typename T>
SmallVector<T> concatAndPermute(T n, SmallVector<T> hw, T c,
const Axes &permutation) {
SmallVector<T> result(permutation.size());
result[permutation[0]] = n;
result[permutation[permutation.size() - 1]] = c;
for (uint64_t i = 1; i < permutation.size() - 1; ++i)
result[permutation[i]] = hw[i - 1];
return result;
}

Tensor constant(Element initValue) {
Tensor result(RankedTensorType::get({}, initValue.getType()));
result.set({}, initValue);
@@ -420,6 +511,50 @@ SmallVector<InterpreterValue> eval(Region &region,
auto operand = scope.findTensor(convertOp.getOperand());
auto result = evalConvertOp(operand, convertOp.getType());
scope.add(convertOp.getResult(), result);
} else if (auto convolutionOp = dyn_cast<ConvolutionOp>(op)) {
auto lhs = scope.findTensor(convolutionOp.getLhs());
auto rhs = scope.findTensor(convolutionOp.getRhs());
auto rank = lhs.getRank();

SmallVector<int64_t> windowStrides(rank - 2, 1);
if (auto windowStridesAttr = convolutionOp.getWindowStridesAttr())
windowStrides = SmallVector<int64_t>(windowStridesAttr.asArrayRef());

SmallVector<std::pair<int64_t, int64_t>> padding(rank - 2, {0, 0});
if (auto paddingAttr = convolutionOp.getPaddingAttr()) {
auto paddingOrErr = hlo::convertPaddingAttribute(paddingAttr, {});
if (failed(paddingOrErr))
report_fatal_error(invalidArgument("Invalid padding format found."));
padding = *paddingOrErr;
}

SmallVector<int64_t> lhsDilation(rank - 2, 1);
if (auto lhsDilationAttr = convolutionOp.getLhsDilationAttr())
lhsDilation = SmallVector<int64_t>(lhsDilationAttr.asArrayRef());

SmallVector<int64_t> rhsDilation(rank - 2, 1);
if (auto rhsDilationAttr = convolutionOp.getRhsDilationAttr())
rhsDilation = SmallVector<int64_t>(rhsDilationAttr.asArrayRef());

SmallVector<bool> windowReversal(rank - 2, false);
if (auto windowReversalAttr = convolutionOp.getWindowReversalAttr())
windowReversal = SmallVector<bool>(windowReversalAttr.asArrayRef());

auto dimensionNumbers = convolutionOp.getDimensionNumbers();
auto result = evalConvolutionOp(
lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation,
windowReversal, dimensionNumbers.getInputBatchDimension(),
dimensionNumbers.getInputFeatureDimension(),
Axes(dimensionNumbers.getInputSpatialDimensions()),
dimensionNumbers.getKernelInputFeatureDimension(),
dimensionNumbers.getKernelOutputFeatureDimension(),
Axes(dimensionNumbers.getKernelSpatialDimensions()),
dimensionNumbers.getOutputBatchDimension(),
dimensionNumbers.getOutputFeatureDimension(),
Axes(dimensionNumbers.getOutputSpatialDimensions()),
convolutionOp.getFeatureGroupCount(),
convolutionOp.getBatchGroupCount(), convolutionOp.getType());
scope.add(convolutionOp.getResult(), result);
} else if (auto cosineOp = dyn_cast<CosineOp>(op)) {
auto operand = scope.findTensor(cosineOp.getOperand());
auto result = evalCosineOp(operand, cosineOp.getType());
@@ -1237,6 +1372,150 @@ Tensor evalConvertOp(const Tensor &operand, ShapedType resultType) {
return result;
}

Tensor evalConvolutionOp(
const Tensor &lhs, const Tensor &rhs, ArrayRef<int64_t> windowStrides,
ArrayRef<std::pair<int64_t, int64_t>> padding,
ArrayRef<int64_t> lhsDilation, ArrayRef<int64_t> rhsDilation,
ArrayRef<bool> windowReversal, Axis inputBatchDimension,
Axis inputFeatureDimension, const Axes &inputSpatialDimensions,
Axis kernelInputFeatureDimension, Axis kernelOutputFeatureDimension,
const Axes &kernelSpatialDimensions, Axis outputBatchDimension,
Axis outputFeatureDimension, const Axes &outputSpatialDimensions,
int64_t featureGroupCount, int64_t batchGroupCount, ShapedType resultType) {
Tensor result(resultType);
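
// With featureGroupCount = g, lhs is split into g slices along the input
// feature dimension and rhs into g slices along the kernel output feature
// dimension; each lhs/rhs pair is convolved recursively with
// featureGroupCount = 1, and the per-group results are concatenated along
// the output feature dimension. batchGroupCount > 1 is handled analogously.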

if (featureGroupCount > 1) {
auto lhses = split(lhs, featureGroupCount, inputFeatureDimension,
resultType.getContext());
auto rhses = split(rhs, featureGroupCount, kernelOutputFeatureDimension,
resultType.getContext());
SmallVector<Tensor> results;
for (auto [left, right] : llvm::zip(lhses, rhses))
results.push_back(evalConvolutionOp(
left, right, windowStrides, padding, lhsDilation, rhsDilation,
windowReversal, inputBatchDimension, inputFeatureDimension,
inputSpatialDimensions, kernelInputFeatureDimension,
kernelOutputFeatureDimension, kernelSpatialDimensions,
outputBatchDimension, outputFeatureDimension, outputSpatialDimensions,
/*featureGroupCount=*/1, batchGroupCount, /*precisionConfig=*/{},
resultType));

return evalConcatenateOp(results, outputFeatureDimension, result.getType());
}

if (batchGroupCount > 1) {
auto lhses = split(lhs, batchGroupCount, inputBatchDimension,
resultType.getContext());
auto rhses = split(rhs, batchGroupCount, kernelOutputFeatureDimension,
resultType.getContext());
SmallVector<Tensor> results;
for (auto [left, right] : llvm::zip(lhses, rhses))
results.push_back(evalConvolutionOp(
left, right, windowStrides, padding, lhsDilation, rhsDilation,
windowReversal, inputBatchDimension, inputFeatureDimension,
inputSpatialDimensions, kernelInputFeatureDimension,
kernelOutputFeatureDimension, kernelSpatialDimensions,
outputBatchDimension, outputFeatureDimension, outputSpatialDimensions,
featureGroupCount, /*batchGroupCount=*/1, /*precisionConfig=*/{},
resultType));

return evalConcatenateOp(results, outputFeatureDimension, result.getType());
}

Axes lhsPermutation;
lhsPermutation.push_back(inputBatchDimension);
lhsPermutation.append(inputSpatialDimensions.begin(),
inputSpatialDimensions.end());
lhsPermutation.push_back(inputFeatureDimension);

auto lhsWindowDimensions =
concatAndPermute(lhs.getShape()[inputBatchDimension],
extractElements(rhs.getShape(), kernelSpatialDimensions),
lhs.getShape()[inputFeatureDimension], lhsPermutation);

auto lhsWindowStrides =
concatAndPermute(1L, llvm::to_vector(windowStrides), 1L, lhsPermutation);

auto lhsBaseDilations =
concatAndPermute(0L, Sizes(lhsDilation) - 1, 0L, lhsPermutation);

auto lhsWindowDilations =
concatAndPermute(1L, llvm::to_vector(rhsDilation), 1L, lhsPermutation);

Sizes lhsPaddingLow, lhsPaddingHigh;
for (auto paddingPair : concatAndPermute({0, 0}, llvm::to_vector(padding),
{0, 0}, lhsPermutation)) {
lhsPaddingLow.push_back(paddingPair.first);
lhsPaddingHigh.push_back(paddingPair.second);
}

auto paddingValue = constant(0.0, result.getElementType());
auto paddedLhs = evalPadOp(lhs, paddingValue, lhsPaddingLow, lhsPaddingHigh,
Sizes(lhsBaseDilations));
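
// Each output spatial position selects one strided window into the padded
// lhs; e.g., in the spec example above ([1, 4, 4, 1] input, [3, 3, 1, 1]
// kernel, unit strides, no padding), output position (0, 1) reads the
// [1, 3, 3, 1] window whose spatial origin is (0, 1).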

IndexSpaceIterator outputSpatialIndexIt(
extractElements(result.getShape(), outputSpatialDimensions),
Index(outputSpatialDimensions.size()));
IndexSpaceIterator outputSpatialIndexItEnd(
extractElements(result.getShape(), outputSpatialDimensions));
for (; outputSpatialIndexIt != outputSpatialIndexItEnd;
++outputSpatialIndexIt) {
Sizes lhsWindowStart;
for (auto [i, offset] : llvm::enumerate(
concatAndPermute(0L, *outputSpatialIndexIt, 0L, lhsPermutation)))
lhsWindowStart.push_back(lhsWindowStrides[i] * offset);

Sizes limitIndices;
for (size_t i = 0; i < lhsWindowStart.size(); ++i)
limitIndices.push_back(std::min(
lhsWindowStart[i] + lhsWindowDimensions[i] * lhsWindowDilations[i],
paddedLhs.getShape()[i]));

auto lhsWindow = evalSliceOp(paddedLhs, lhsWindowStart, limitIndices,
Sizes(lhsWindowDilations));

Axes reverseDims;
for (auto [i, isReverse] : llvm::enumerate(windowReversal))
if (isReverse) reverseDims.push_back(inputSpatialDimensions[i]);
auto reversedLhsWindow =
evalReverseOp(lhsWindow, reverseDims, lhsWindow.getType());

Axes lhsContractingDimensions(inputSpatialDimensions);
lhsContractingDimensions.push_back(inputFeatureDimension);

Axes rhsContractingDimensions(kernelSpatialDimensions);
rhsContractingDimensions.push_back(kernelInputFeatureDimension);

auto dotProduct =
evalDotGeneralOp(reversedLhsWindow, rhs, lhsContractingDimensions,
rhsContractingDimensions);

Sizes resultNonSpatialDims;
for (auto i = 0; i < result.getRank(); ++i)
if (llvm::find(outputSpatialDimensions, i) ==
outputSpatialDimensions.end())
resultNonSpatialDims.push_back(result.getShape()[i]);

Axes resultPermutation;
resultPermutation.push_back(outputBatchDimension);
resultPermutation.append(outputSpatialDimensions.begin(),
outputSpatialDimensions.end());
resultPermutation.push_back(outputFeatureDimension);

IndexSpaceIterator resultNonSpatialIt(resultNonSpatialDims,
Index(resultNonSpatialDims.size()));
for (auto dotProductIt = dotProduct.index_begin();
dotProductIt != dotProduct.index_end();
++dotProductIt, ++resultNonSpatialIt) {
Index resultIndex(
concatAndPermute((*resultNonSpatialIt)[0], *outputSpatialIndexIt,
(*resultNonSpatialIt)[1], resultPermutation));
result.set(resultIndex, dotProduct.get(*dotProductIt));
}
}
return result;
}

Tensor evalCosineOp(const Tensor &operand, ShapedType resultType) {
Tensor result(resultType);
for (auto it = result.index_begin(); it != result.index_end(); ++it)
10 changes: 10 additions & 0 deletions stablehlo/reference/Ops.h
@@ -77,6 +77,16 @@ Tensor evalConcatenateOp(ArrayRef<Tensor> inputs, Axis dimension,
ShapedType resultType);
Tensor evalConstantOp(ElementsAttr value);
Tensor evalConvertOp(const Tensor &operand, ShapedType resultType);
Tensor evalConvolutionOp(
const Tensor &lhs, const Tensor &rhs, ArrayRef<int64_t> windowStrides,
ArrayRef<std::pair<int64_t, int64_t>> padding,
ArrayRef<int64_t> lhsDilation, ArrayRef<int64_t> rhsDilation,
ArrayRef<bool> windowReversal, Axis inputBatchDimension,
Axis inputFeatureDimension, const Axes &inputSpatialDimensions,
Axis kernelInputFeatureDimension, Axis kernelOutputFeatureDimension,
const Axes &kernelSpatialDimensions, Axis outputBatchDimension,
Axis outputFeatureDimension, const Axes &outputSpatialDimensions,
int64_t featureGroupCount, int64_t batchGroupCount, ShapedType resultType);
Tensor evalCosineOp(const Tensor &operand, ShapedType resultType);
Tensor evalDivideOp(const Tensor &lhs, const Tensor &rhs,
ShapedType resultType);