diff --git a/docs/spec.md b/docs/spec.md
index ce9bc9cd051..5ec33c7e297 100644
--- a/docs/spec.md
+++ b/docs/spec.md
@@ -2317,16 +2317,18 @@ For quantized types, performs `dequantize_op_quantize(
   // "i" is input feature dimension, "o" is output feature dimension,
   // "0/1/etc" are spatial dimensions.
   dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
-  feature_group_count = 1 : i64,
   batch_group_count = 1 : i64,
+  feature_group_count = 1 : i64,
   precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
-} : (tensor<1x4x4x1xi32>, tensor<3x3x1x1xi32>) -> tensor<1x2x2x1xi32>
+} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
 // %result: [[
 //            [[10], [26]],
 //            [[46], [62]]
 //          ]]
 ```
 
+ [More Examples](../stablehlo/tests/interpret/convolution.mlir)
+
 ### cosine
 
 #### Semantics
diff --git a/docs/status.md b/docs/status.md
index 4b46d8430f1..c758844795a 100644
--- a/docs/status.md
+++ b/docs/status.md
@@ -70,7 +70,7 @@ one of the following tracking labels.
 | concatenate | yes | yes | yes | yes | yes |
 | constant | yes | yes | yes | yes | yes |
 | convert | yes | yes | infeasible | yes | yes |
-| convolution | yes | yes | infeasible | revisit | no |
+| convolution | yes | yes | infeasible | revisit | yes |
 | cosine | yes | yes | yes | yes | yes |
 | count_leading_zeros | yes | yes | yes | yes | yes |
 | create_token | no | yes\* | yes\* | yes | revisit |
diff --git a/stablehlo/reference/Index.cpp b/stablehlo/reference/Index.cpp
index 9f8ff31f51d..3ef82ecf58f 100644
--- a/stablehlo/reference/Index.cpp
+++ b/stablehlo/reference/Index.cpp
@@ -44,14 +44,14 @@ bool Sizes::inBounds(const Sizes &bounds) const {
 
 IndexSpaceIterator Sizes::index_begin() const {
   if (any_of(*this, [](int64_t dimSize) { return dimSize == 0; }))
-    return IndexSpaceIterator(*this, std::nullopt);
+    return IndexSpaceIterator(*this);
 
   Index initialIndex(size());
   return IndexSpaceIterator(*this, initialIndex);
 }
 
 IndexSpaceIterator Sizes::index_end() const {
-  return IndexSpaceIterator(*this, std::nullopt);
+  return IndexSpaceIterator(*this);
 }
 
 Sizes operator+(const Sizes &x, const Sizes &y) {
diff --git a/stablehlo/reference/Index.h b/stablehlo/reference/Index.h
index 5e98a7f3dac..173765955ba 100644
--- a/stablehlo/reference/Index.h
+++ b/stablehlo/reference/Index.h
@@ -121,9 +121,11 @@ using Index = Sizes;
 class IndexSpaceIterator {
  public:
   /// \name Constructor
+  IndexSpaceIterator(Sizes shape) : shape_(shape) { index_ = std::nullopt; }
+
   IndexSpaceIterator(Sizes shape, std::optional<Index> index)
-      : shape_(shape), index_(index) {
-    if (index && !index->inBounds(shape)) index_ = std::nullopt;
+      : shape_(shape), index_(std::nullopt) {
+    if (index && index->inBounds(shape)) index_ = index;
   }
 
   /// Get the current index.
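Editor's note, not part of the patch: with the new single-argument constructor above, `IndexSpaceIterator(shape)` is the past-the-end iterator, and `Sizes::index_end()` (as well as `index_begin()` for empty index spaces) now returns exactly that. A minimal usage sketch follows; it relies only on the `Sizes`/`IndexSpaceIterator` API visible in this diff, while the `walkIndexSpace` wrapper, the `mlir::stablehlo` namespace, and the lexicographic visiting order are assumptions.

```cpp
#include "stablehlo/reference/Index.h"  // Sizes, Index, IndexSpaceIterator

using namespace mlir::stablehlo;  // assumed namespace of the reference library

// Walks every index of a 2x3 index space. index_end() is now simply
// IndexSpaceIterator(shape), i.e. an iterator whose index_ is std::nullopt.
void walkIndexSpace() {
  Sizes shape;
  shape.push_back(2);
  shape.push_back(3);
  for (auto it = shape.index_begin(); it != shape.index_end(); ++it) {
    auto index = *it;  // presumably visits {0,0}, {0,1}, ..., {1,2}
    (void)index;
  }
}
```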
diff --git a/stablehlo/reference/Ops.cpp b/stablehlo/reference/Ops.cpp
index ac860a66fb4..e25acb8c7ee 100644
--- a/stablehlo/reference/Ops.cpp
+++ b/stablehlo/reference/Ops.cpp
@@ -54,6 +54,25 @@ Index evalIndex(Tensor tensor) {
   return result;
 }
 
+Tensor evalDotGeneralOp(const Tensor &lhs, const Tensor &rhs,
+                        const Axes &lhsContractingDimensions,
+                        const Axes &rhsContractingDimensions) {
+  SmallVector<ShapedTypeComponents> inferredDotGeneralType;
+  if (failed(hlo::inferDotGeneralOp(
+          /*location=*/{}, lhs.getType(), rhs.getType(),
+          /*lhsBatchingDimensions=*/{}, /*rhsBatchingDimensions*/ {},
+          lhsContractingDimensions, rhsContractingDimensions,
+          /*precisionConfig=*/{}, inferredDotGeneralType)))
+    report_fatal_error(
+        invalidArgument("Could not infer DotGeneralOp's return type"));
+
+  return evalDotGeneralOp(
+      lhs, rhs, /*lhsBatchingDimensions=*/{}, /*rhsBatchingDimensions*/ {},
+      lhsContractingDimensions, rhsContractingDimensions,
+      RankedTensorType::get(inferredDotGeneralType[0].getDims(),
+                            lhs.getElementType()));
+}
+
 Tensor evalPadOp(const Tensor &operand, const Tensor &paddingValue,
                  const Sizes &edgePaddingLow, const Sizes &edgePaddingHigh,
                  const Sizes &interiorPadding) {
@@ -143,6 +162,12 @@ Tensor evalSliceOp(const Tensor &operand, const Index &index) {
   return evalSliceOp(operand, start, limit, strides);
 }
 
+Sizes extractElements(ArrayRef<int64_t> arr, ArrayRef<int64_t> indices) {
+  Sizes elements;
+  for (int64_t index : indices) elements.push_back(arr[index]);
+  return elements;
+}
+
 void failOnDecomposableOp(Operation &op) {
   report_fatal_error(invalidArgument(
       "Operation %s is unsupported at the moment. "
@@ -153,6 +178,13 @@ void failOnDecomposableOp(Operation &op) {
       op.getName().getStringRef().str().c_str()));
 }
 
+template <typename T>
+DenseIntElementsAttr getDenseIntElementsAttr(Type elementType, T values,
+                                             SmallVector<int64_t> valuesShape) {
+  return DenseIntElementsAttr::get(
+      RankedTensorType::get(valuesShape, elementType), values);
+}
+
 SmallVector<SmallVector<uint32_t>> getReplicaGroups(
     DenseIntElementsAttr replicaGroupsAttr) {
   auto replicaGroupsShape = replicaGroupsAttr.getShapedType().getShape();
@@ -168,6 +200,65 @@ SmallVector<SmallVector<uint32_t>> getReplicaGroups(
   return replicaGroups;
 }
 
+Tensor evalConvolutionOp(
+    const Tensor &lhs, const Tensor &rhs, ArrayRef<int64_t> windowStrides,
+    ArrayRef<std::pair<int64_t, int64_t>> padding,
+    ArrayRef<int64_t> lhsDilation, ArrayRef<int64_t> rhsDilation,
+    ArrayRef<bool> windowReversal, Axis inputBatchDimension,
+    Axis inputFeatureDimension, const Axes &inputSpatialDimensions,
+    Axis kernelInputFeatureDimension, Axis kernelOutputFeatureDimension,
+    const Axes &kernelSpatialDimensions, Axis outputBatchDimension,
+    Axis outputFeatureDimension, const Axes &outputSpatialDimensions,
+    int64_t featureGroupCount, int64_t batchGroupCount,
+    std::optional<ArrayAttr> precisionConfig, ShapedType resultType) {
+  SmallVector<int64_t> paddingVector;
+  for (auto pair : padding) {
+    paddingVector.push_back(pair.first);
+    paddingVector.push_back(pair.second);
+  }
+
+  SmallVector<ShapedTypeComponents> inferredConvolutionType;
+  if (failed(hlo::inferConvolutionOp(
+          /*location=*/{}, lhs.getType(), rhs.getType(), windowStrides,
+          /*padding=*/
+          getDenseIntElementsAttr(
+              IntegerType::get(lhs.getType().getContext(), 64), paddingVector,
+              SmallVector<int64_t>(padding.size(), 2)),
+          lhsDilation, rhsDilation, windowReversal, inputBatchDimension,
+          inputFeatureDimension, ArrayRef(inputSpatialDimensions),
+          kernelInputFeatureDimension, kernelOutputFeatureDimension,
+          ArrayRef(kernelSpatialDimensions), outputBatchDimension,
+          outputFeatureDimension, ArrayRef(outputSpatialDimensions),
+          featureGroupCount, batchGroupCount,
+          /*precisionConfig=*/{}, inferredConvolutionType)))
+    report_fatal_error(
+        invalidArgument("Could not infer ConvolutionOp's return type"));
+
+  return evalConvolutionOp(
+      lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation,
+      windowReversal, inputBatchDimension, inputFeatureDimension,
+      inputSpatialDimensions, kernelInputFeatureDimension,
+      kernelOutputFeatureDimension, kernelSpatialDimensions,
+      outputBatchDimension, outputFeatureDimension, outputSpatialDimensions,
+      featureGroupCount, batchGroupCount,
+      RankedTensorType::get(inferredConvolutionType[0].getDims(),
+                            resultType.getElementType()));
+}
+
+// Returns `result` with the effect of applying `permutation`
+// (= [dimA] + dimsB + [dimC]) to `input` (= [n] + hw + [c]) such that
+// result[permutation[i]] = input[i].
+template <typename T>
+SmallVector<T> concatAndPermute(T n, SmallVector<T> hw, T c,
+                                const Axes &permutation) {
+  SmallVector<T> result(permutation.size());
+  result[permutation[0]] = n;
+  result[permutation[permutation.size() - 1]] = c;
+  for (uint64_t i = 1; i < permutation.size() - 1; ++i)
+    result[permutation[i]] = hw[i - 1];
+  return result;
+}
+
 Tensor constant(Element initValue) {
   Tensor result(RankedTensorType::get({}, initValue.getType()));
   result.set({}, initValue);
@@ -420,6 +511,50 @@ SmallVector<Tensor> eval(Region &region,
       auto operand = scope.findTensor(convertOp.getOperand());
       auto result = evalConvertOp(operand, convertOp.getType());
       scope.add(convertOp.getResult(), result);
+    } else if (auto convolutionOp = dyn_cast<ConvolutionOp>(op)) {
+      auto lhs = scope.findTensor(convolutionOp.getLhs());
+      auto rhs = scope.findTensor(convolutionOp.getRhs());
+      auto rank = lhs.getRank();
+
+      SmallVector<int64_t> windowStrides(rank - 2, 1);
+      if (auto windowStridesAttr = convolutionOp.getWindowStridesAttr())
+        windowStrides = SmallVector<int64_t>(windowStridesAttr.asArrayRef());
+
+      SmallVector<std::pair<int64_t, int64_t>> padding(rank - 2, {0, 0});
+      if (auto paddingAttr = convolutionOp.getPaddingAttr()) {
+        auto paddingOrErr = hlo::convertPaddingAttribute(paddingAttr, {});
+        if (failed(paddingOrErr))
+          report_fatal_error(invalidArgument("Invalid padding format found."));
+        padding = *paddingOrErr;
+      }
+
+      SmallVector<int64_t> lhsDilation(rank - 2, 1);
+      if (auto lhsDilationAttr = convolutionOp.getLhsDilationAttr())
+        lhsDilation = SmallVector<int64_t>(lhsDilationAttr.asArrayRef());
+
+      SmallVector<int64_t> rhsDilation(rank - 2, 1);
+      if (auto rhsDilationAttr = convolutionOp.getRhsDilationAttr())
+        rhsDilation = SmallVector<int64_t>(rhsDilationAttr.asArrayRef());
+
+      SmallVector<bool> windowReversal(rank - 2, false);
+      if (auto windowReversalAttr = convolutionOp.getWindowReversalAttr())
+        windowReversal = SmallVector<bool>(windowReversalAttr.asArrayRef());
+
+      auto dimensionNumbers = convolutionOp.getDimensionNumbers();
+      auto result = evalConvolutionOp(
+          lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation,
+          windowReversal, dimensionNumbers.getInputBatchDimension(),
+          dimensionNumbers.getInputFeatureDimension(),
+          Axes(dimensionNumbers.getInputSpatialDimensions()),
+          dimensionNumbers.getKernelInputFeatureDimension(),
+          dimensionNumbers.getKernelOutputFeatureDimension(),
+          Axes(dimensionNumbers.getKernelSpatialDimensions()),
+          dimensionNumbers.getOutputBatchDimension(),
+          dimensionNumbers.getOutputFeatureDimension(),
+          Axes(dimensionNumbers.getOutputSpatialDimensions()),
+          convolutionOp.getFeatureGroupCount(),
+          convolutionOp.getBatchGroupCount(), convolutionOp.getType());
+      scope.add(convolutionOp.getResult(), result);
     } else if (auto cosineOp = dyn_cast<CosineOp>(op)) {
       auto operand = scope.findTensor(cosineOp.getOperand());
       auto result = evalCosineOp(operand, cosineOp.getType());
@@ -1237,6 +1372,150 @@ Tensor evalConvertOp(const Tensor &operand, ShapedType resultType) {
   return result;
 }
 
+Tensor evalConvolutionOp(
+    const Tensor &lhs, const Tensor &rhs, ArrayRef<int64_t> windowStrides,
+    ArrayRef<std::pair<int64_t, int64_t>> padding,
+    ArrayRef<int64_t> lhsDilation, ArrayRef<int64_t> rhsDilation,
+    ArrayRef<bool> windowReversal, Axis inputBatchDimension,
+    Axis inputFeatureDimension, const Axes &inputSpatialDimensions,
+    Axis kernelInputFeatureDimension, Axis kernelOutputFeatureDimension,
+    const Axes &kernelSpatialDimensions, Axis outputBatchDimension,
+    Axis outputFeatureDimension, const Axes &outputSpatialDimensions,
+    int64_t featureGroupCount, int64_t batchGroupCount, ShapedType resultType) {
+  Tensor result(resultType);
+
+  if (featureGroupCount > 1) {
+    auto lhses = split(lhs, featureGroupCount, inputFeatureDimension,
+                       resultType.getContext());
+    auto rhses = split(rhs, featureGroupCount, kernelOutputFeatureDimension,
+                       resultType.getContext());
+    SmallVector<Tensor> results;
+    for (auto [left, right] : llvm::zip(lhses, rhses))
+      results.push_back(evalConvolutionOp(
+          left, right, windowStrides, padding, lhsDilation, rhsDilation,
+          windowReversal, inputBatchDimension, inputFeatureDimension,
+          inputSpatialDimensions, kernelInputFeatureDimension,
+          kernelOutputFeatureDimension, kernelSpatialDimensions,
+          outputBatchDimension, outputFeatureDimension, outputSpatialDimensions,
+          /*featureGroupCount=*/1, batchGroupCount, /*precisionConfig=*/{},
+          resultType));
+
+    return evalConcatenateOp(results, outputFeatureDimension, result.getType());
+  }
+
+  if (batchGroupCount > 1) {
+    auto lhses = split(lhs, batchGroupCount, inputBatchDimension,
+                       resultType.getContext());
+    auto rhses = split(rhs, batchGroupCount, kernelOutputFeatureDimension,
+                       resultType.getContext());
+    SmallVector<Tensor> results;
+    for (auto [left, right] : llvm::zip(lhses, rhses))
+      results.push_back(evalConvolutionOp(
+          left, right, windowStrides, padding, lhsDilation, rhsDilation,
+          windowReversal, inputBatchDimension, inputFeatureDimension,
+          inputSpatialDimensions, kernelInputFeatureDimension,
+          kernelOutputFeatureDimension, kernelSpatialDimensions,
+          outputBatchDimension, outputFeatureDimension, outputSpatialDimensions,
+          featureGroupCount, /*batchGroupCount=*/1, /*precisionConfig=*/{},
+          resultType));
+
+    return evalConcatenateOp(results, outputFeatureDimension, result.getType());
+  }
+
+  Axes lhsPermutation;
+  lhsPermutation.push_back(inputBatchDimension);
+  lhsPermutation.append(inputSpatialDimensions.begin(),
+                        inputSpatialDimensions.end());
+  lhsPermutation.push_back(inputFeatureDimension);
+
+  auto lhsWindowDimensions =
+      concatAndPermute(lhs.getShape()[inputBatchDimension],
+                       extractElements(rhs.getShape(), kernelSpatialDimensions),
+                       lhs.getShape()[inputFeatureDimension], lhsPermutation);
+
+  auto lhsWindowStrides =
+      concatAndPermute(1L, llvm::to_vector(windowStrides), 1L, lhsPermutation);
+
+  auto lhsBaseDilations =
+      concatAndPermute(0L, Sizes(lhsDilation) - 1, 0L, lhsPermutation);
+
+  auto lhsWindowDilations =
+      concatAndPermute(1L, llvm::to_vector(rhsDilation), 1L, lhsPermutation);
+
+  Sizes lhsPaddingLow, lhsPaddingHigh;
+  for (auto paddingPair : concatAndPermute({0, 0}, llvm::to_vector(padding),
+                                           {0, 0}, lhsPermutation)) {
+    lhsPaddingLow.push_back(paddingPair.first);
+    lhsPaddingHigh.push_back(paddingPair.second);
+  }
+
+  auto paddingValue = constant(0.0, result.getElementType());
+  auto paddedLhs = evalPadOp(lhs, paddingValue, lhsPaddingLow, lhsPaddingHigh,
+                             Sizes(lhsBaseDilations));
+
+  IndexSpaceIterator outputSpatialIndexIt(
+      extractElements(result.getShape(), outputSpatialDimensions),
+      Index(outputSpatialDimensions.size()));
+  IndexSpaceIterator outputSpatialIndexItEnd(
+      extractElements(result.getShape(), outputSpatialDimensions));
+  for (; outputSpatialIndexIt != outputSpatialIndexItEnd;
+       ++outputSpatialIndexIt) {
+    Sizes lhsWindowStart;
+    for (auto [i, offset] : llvm::enumerate(
+             concatAndPermute(0L, *outputSpatialIndexIt, 0L, lhsPermutation)))
+      lhsWindowStart.push_back(lhsWindowStrides[i] * offset);
+
+    Sizes limitIndices;
+    for (size_t i = 0; i < lhsWindowStart.size(); ++i)
+      limitIndices.push_back(std::min(
+          lhsWindowStart[i] + lhsWindowDimensions[i] * lhsWindowDilations[i],
+          paddedLhs.getShape()[i]));
+
+    auto lhsWindow = evalSliceOp(paddedLhs, lhsWindowStart, limitIndices,
+                                 Sizes(lhsWindowDilations));
+
+    Axes reverseDims;
+    for (auto [i, isReverse] : llvm::enumerate(windowReversal))
+      if (isReverse) reverseDims.push_back(inputSpatialDimensions[i]);
+    auto reversedLhsWindow =
+        evalReverseOp(lhsWindow, reverseDims, lhsWindow.getType());
+
+    Axes lhsContractingDimensions(inputSpatialDimensions);
+    lhsContractingDimensions.push_back(inputFeatureDimension);
+
+    Axes rhsContractingDimensions(kernelSpatialDimensions);
+    rhsContractingDimensions.push_back(kernelInputFeatureDimension);
+
+    auto dotProduct =
+        evalDotGeneralOp(reversedLhsWindow, rhs, lhsContractingDimensions,
+                         rhsContractingDimensions);
+
+    Sizes resultNonSpatialDims;
+    for (auto i = 0; i < result.getRank(); ++i)
+      if (llvm::find(outputSpatialDimensions, i) ==
+          outputSpatialDimensions.end())
+        resultNonSpatialDims.push_back(result.getShape()[i]);
+
+    Axes resultPermutation;
+    resultPermutation.push_back(outputBatchDimension);
+    resultPermutation.append(outputSpatialDimensions.begin(),
+                             outputSpatialDimensions.end());
+    resultPermutation.push_back(outputFeatureDimension);
+
+    IndexSpaceIterator resultNonSpatialIt(resultNonSpatialDims,
+                                          Index(resultNonSpatialDims.size()));
+    for (auto dotProductIt = dotProduct.index_begin();
+         dotProductIt != dotProduct.index_end();
+         ++dotProductIt, ++resultNonSpatialIt) {
+      Index resultIndex(
+          concatAndPermute((*resultNonSpatialIt)[0], *outputSpatialIndexIt,
+                           (*resultNonSpatialIt)[1], resultPermutation));
+      result.set(resultIndex, dotProduct.get(*dotProductIt));
+    }
+  }
+  return result;
+}
+
 Tensor evalCosineOp(const Tensor &operand, ShapedType resultType) {
   Tensor result(resultType);
   for (auto it = result.index_begin(); it != result.index_end(); ++it)
diff --git a/stablehlo/reference/Ops.h b/stablehlo/reference/Ops.h
index 6b58066caae..8db66de3af5 100644
--- a/stablehlo/reference/Ops.h
+++ b/stablehlo/reference/Ops.h
@@ -77,6 +77,16 @@ Tensor evalConcatenateOp(ArrayRef<Tensor> inputs, Axis dimension,
                          ShapedType resultType);
 Tensor evalConstantOp(ElementsAttr value);
 Tensor evalConvertOp(const Tensor &operand, ShapedType resultType);
+Tensor evalConvolutionOp(
+    const Tensor &lhs, const Tensor &rhs, ArrayRef<int64_t> windowStrides,
+    ArrayRef<std::pair<int64_t, int64_t>> padding,
+    ArrayRef<int64_t> lhsDilation, ArrayRef<int64_t> rhsDilation,
+    ArrayRef<bool> windowReversal, Axis inputBatchDimension,
+    Axis inputFeatureDimension, const Axes &inputSpatialDimensions,
+    Axis kernelInputFeatureDimension, Axis kernelOutputFeatureDimension,
+    const Axes &kernelSpatialDimensions, Axis outputBatchDimension,
+    Axis outputFeatureDimension, const Axes &outputSpatialDimensions,
+    int64_t featureGroupCount, int64_t batchGroupCount, ShapedType resultType);
 Tensor evalCosineOp(const Tensor &operand, ShapedType resultType);
 Tensor evalDivideOp(const Tensor &lhs, const Tensor &rhs,
                     ShapedType resultType);
diff --git a/stablehlo/tests/interpret/convolution.mlir b/stablehlo/tests/interpret/convolution.mlir
new file mode 100644
index 00000000000..4b7fdc3ad07
--- /dev/null
+++ b/stablehlo/tests/interpret/convolution.mlir
@@ -0,0 +1,79 @@
+// RUN: stablehlo-translate --interpret -split-input-file %s
+
+func.func @convolution_op_test_si64() {
+  %lhs = stablehlo.constant dense<[[
+                                    [[1], [2], [5], [6]],
+                                    [[3], [4], [7], [8]],
+                                    [[10], [11], [14], [15]],
+                                    [[12], [13], [16], [17]]
+                                  ]]> : tensor<1x4x4x1xi64>
+  %rhs = stablehlo.constant dense<1> : tensor<3x3x1x1xi64>
+  %result = stablehlo.convolution(%lhs, %rhs)
+    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+    window = {
+      stride = [4, 4],
+      lhs_dilate = [2, 2]
+    } {
+      batch_group_count = 1 : i64,
+      feature_group_count = 1 : i64,
+      precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
+    }
+    : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
+  check.expect_eq_const %result, dense<[[
+                                         [[10], [26]],
+                                         [[46], [62]]
+                                       ]]> : tensor<1x2x2x1xi64>
+  func.return
+}
+
+// -----
+
+func.func @convolution_batch_group_count_4() {
+  %lhs = stablehlo.constant dense<[[
+                                    [[1], [2], [5], [6]],
+                                    [[3], [4], [7], [8]],
+                                    [[10], [11], [14], [15]],
+                                    [[12], [13], [16], [17]]
+                                  ]]> : tensor<1x4x4x1xi64>
+  %rhs = stablehlo.constant dense<1> : tensor<1x2x1x4xi64>
+  %result = stablehlo.convolution(%lhs, %rhs)
+    dim_numbers = [0, b, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+    window = {
+      stride = [4, 4],
+      lhs_dilate = [2, 2]
+    } {
+      batch_group_count = 4 : i64,
+      feature_group_count = 1 : i64,
+      precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
+    }
+    : (tensor<1x4x4x1xi64>, tensor<1x2x1x4xi64>) -> tensor<1x1x2x4xi64>
+  check.expect_eq_const %result, dense<[[[[1, 3, 10, 12],
+                                          [5, 7, 14, 16]]]]> : tensor<1x1x2x4xi64>
+  func.return
+}
+
+// -----
+
+func.func @convolution_feature_group_count_2() {
+  %lhs = stablehlo.constant dense<[[
+                                    [[1], [2], [5], [6]],
+                                    [[3], [4], [7], [8]],
+                                    [[10], [11], [14], [15]],
+                                    [[12], [13], [16], [17]]
+                                  ]]> : tensor<1x4x4x1xi64>
+  %rhs = stablehlo.constant dense<1> : tensor<1x2x1x4xi64>
+  %result = stablehlo.convolution(%lhs, %rhs)
+    dim_numbers = [b, 0, f, 1]x[0, i, 1, o]->[b, 0, 1, f],
+    window = {
+      stride = [4, 4],
+      lhs_dilate = [2, 2]
+    } {
+      batch_group_count = 1 : i64,
+      feature_group_count = 2 : i64,
+      precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
+    }
+    : (tensor<1x4x4x1xi64>, tensor<1x2x1x4xi64>) -> tensor<1x2x1x4xi64>
+  check.expect_eq_const %result, dense<[[[[3, 3, 11, 11]],
+                                         [[21, 21, 29, 29]]]]> : tensor<1x2x1x4xi64>
+  func.return
+}
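Editor's sketch, not part of the change: a standalone check of where the expected values in `@convolution_op_test_si64` above come from. With `lhs_dilate = [2, 2]` the 4x4 input is base-dilated to 7x7, the 3x3 all-ones kernel is then applied with stride `[4, 4]` and no padding, which gives the 2x2 result `[[10, 26], [46, 62]]`. The snippet uses plain arrays and loops rather than the interpreter API.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

int main() {
  // Spatial slice of the 1x4x4x1 lhs from @convolution_op_test_si64.
  const std::array<std::array<int64_t, 4>, 4> lhs = {{{1, 2, 5, 6},
                                                      {3, 4, 7, 8},
                                                      {10, 11, 14, 15},
                                                      {12, 13, 16, 17}}};
  // Base-dilate by 2 in both spatial dimensions: (4 - 1) * 2 + 1 = 7.
  std::array<std::array<int64_t, 7>, 7> dilated = {};
  for (int i = 0; i < 4; ++i)
    for (int j = 0; j < 4; ++j) dilated[2 * i][2 * j] = lhs[i][j];

  // 3x3 kernel of ones, stride 4, no padding -> 2x2 output.
  const int stride = 4, kernel = 3;
  const int64_t expected[2][2] = {{10, 26}, {46, 62}};
  for (int oi = 0; oi < 2; ++oi)
    for (int oj = 0; oj < 2; ++oj) {
      int64_t acc = 0;
      for (int ki = 0; ki < kernel; ++ki)
        for (int kj = 0; kj < kernel; ++kj)
          acc += dilated[oi * stride + ki][oj * stride + kj];  // weights are 1
      assert(acc == expected[oi][oj]);
    }
  return 0;
}
```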
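Editor's sketch, not part of the change: the `concatAndPermute` helper added to `Ops.cpp` scatters a logical `[batch] + spatial + [feature]` triple into the physical dimension order described by a permutation, with `result[permutation[i]] = input[i]`. For the `[b, 0, 1, f]` layout used in the tests the permutation is `{0, 1, 2, 3}`, so the values land in order. The snippet below restates the contract with `std::vector` in place of `SmallVector`/`Axes` so it compiles on its own, and exercises an NCHW-style layout.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors the contract of concatAndPermute from the patch:
// input = [n] + hw + [c], and result[permutation[i]] = input[i].
std::vector<int64_t> concatAndPermute(int64_t n, std::vector<int64_t> hw,
                                      int64_t c,
                                      const std::vector<int64_t> &permutation) {
  std::vector<int64_t> result(permutation.size());
  result[permutation.front()] = n;
  result[permutation.back()] = c;
  for (size_t i = 1; i + 1 < permutation.size(); ++i)
    result[permutation[i]] = hw[i - 1];
  return result;
}

int main() {
  // NCHW-style layout: batch at axis 0, features at axis 1, spatial at 2 and 3,
  // i.e. [inputBatchDimension] + inputSpatialDimensions + [inputFeatureDimension].
  const std::vector<int64_t> permutation = {0, 2, 3, 1};
  auto shape = concatAndPermute(int64_t{1}, {4, 4}, int64_t{3}, permutation);
  assert((shape == std::vector<int64_t>{1, 3, 4, 4}));  // N=1, C=3, H=W=4
  return 0;
}
```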