From 651ee35fabe6e2cfaf5b7397a21f5ae1b479b555 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Thu, 11 May 2023 20:56:40 +0000 Subject: [PATCH] Organize code based on recent PRs --- stablehlo/dialect/TypeInference.cpp | 80 +++++--- stablehlo/reference/Ops.cpp | 283 +++++++++++----------------- stablehlo/reference/Ops.h | 6 +- 3 files changed, 165 insertions(+), 204 deletions(-) diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index 0384c3ed13a..d6af994b15b 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -318,26 +318,30 @@ verifyWindowAttributesAndInferWindowDimensions( " to have same dimension-size as size of window dimensions (", windowDimensions.size(), "), but got: ", attrSize, "."); }; - // convolution_c3, reduce_window_c6, select_and_scatter_c6 + // convolution_c2, reduce_window_c6, select_and_scatter_c6 if (failed(verifySize(windowStrides.size(), "window-strides"))) return failure(); - // convolution_c6, reduce_window_c8 + + // convolution_c5, reduce_window_c8 if (failed(verifySize(lhsDilation.size(), "base-dilation factors"))) return failure(); - // convolution_c8, reduce_window_c10 + + // convolution_c7, reduce_window_c10 if (failed(verifySize(rhsDilation.size(), "window-dilation factors"))) return failure(); - // convolution_c5, reduce_window_c12 + + // convolution_c4, reduce_window_c12 if (failed(verifySize(padding.size(), "padding-entries"))) return failure(); - // convolution_c10 + + // convolution_c9 if (failed(verifySize(windowReversal.size(), "window-reversal"))) return failure(); SmallVector window(windowDimensions.size()); for (size_t i = 0; i < windowDimensions.size(); i++) { WindowDimension& dim = window[i]; - dim.size = windowDimensions[i]; + // reduce_window_c5, select_and_scatter_c5 if (!isDynamicDimSize(dim.size) && dim.size <= 0) return emitOptionalError(loc, @@ -345,21 +349,24 @@ verifyWindowAttributesAndInferWindowDimensions( "-th window dimension, but got ", dim.size, "."); if (!windowStrides.empty()) dim.stride = windowStrides[i]; - // convolution_c4, reduce_window_c7, select_and_scatter_c7 + + // convolution_c3, reduce_window_c7, select_and_scatter_c7 if (dim.stride <= 0) return emitOptionalError( loc, "expects window to have positive stride for ", i, "-th window dimension, but got ", dim.stride, "."); if (!lhsDilation.empty()) dim.baseDilation = lhsDilation[i]; - // convolution_c7, reduce_window_c9 + + // convolution_c6, reduce_window_c9 if (dim.baseDilation <= 0) return emitOptionalError( loc, "expects window to have positive base dilation factor for ", i, "-th window dimension, but got ", dim.baseDilation, "."); if (!rhsDilation.empty()) dim.windowDilation = rhsDilation[i]; - // convolution_c9, reduce_window_c11 + + // convolution_c8, reduce_window_c11 if (dim.windowDilation <= 0) return emitOptionalError( loc, "expects window to have positive window dilation factor for ", i, @@ -755,7 +762,7 @@ LogicalResult isSpatialDimensionsValid( int64_t outputFeatureDimension, ArrayRef outputSpatialDimensions, std::optional location) { uint64_t spatialDimNum = inputSpatialDimensions.size(); - // convolution_c18, convolution_c20 + // convolution_c17, convolution_c19 if ((spatialDimNum != kernelSpatialDimensions.size()) || (spatialDimNum != outputSpatialDimensions.size())) return emitOptionalError(location, @@ -785,7 +792,7 @@ LogicalResult isSpatialDimensionsValid( auto numDims = lhsType.cast().getRank(); const auto inRange = [numDims](int64_t i) { return 0 <= i && i < numDims; }; - // convolution_c14, convolution_c19, convolution_c21 + // convolution_c13, convolution_c18, convolution_c20 if (!llvm::all_of(inputDimNums, inRange) || !llvm::all_of(windowDimNums, inRange) || !llvm::all_of(outputDimNums, inRange)) @@ -793,17 +800,20 @@ LogicalResult isSpatialDimensionsValid( "expects input, kernel, and output " "dimension-numbers to be in-range [0, ", numDims, ")."); - // convolution_c14 + + // convolution_c13 if (hasDuplicates(inputDimNums)) return emitOptionalError( location, "expects input dimension-numbers to be unique, got {", inputDimNums, "}."); - // convolution_c19 + + // convolution_c18 if (hasDuplicates(windowDimNums)) return emitOptionalError( location, "expects kernel dimension-numbers to be unique, got {", windowDimNums, "}."); - // convolution_c21 + + // convolution_c20 if (hasDuplicates(outputDimNums)) return emitOptionalError( location, "expects output dimension-numbers to be unique, got {", @@ -842,17 +852,19 @@ LogicalResult verifyConvolutionAttributes( location))) return failure(); - // convolution_c22 + // convolution_c21 if (featureGroupCount <= 0) return emitOptionalError( location, "expects feature_group_count to be a positive number, got ", featureGroupCount, "."); - // convolution_c23 + + // convolution_c22 if (batchGroupCount <= 0) return emitOptionalError( location, "expects batch_group_count to be a positive number, got ", batchGroupCount, "."); - // convolution_c24 + + // convolution_c23 if (batchGroupCount > 1 && featureGroupCount > 1) return emitOptionalError( location, @@ -870,22 +882,24 @@ LogicalResult verifyConvolutionAttributes( const int64_t kernelOutputFeatures = rankedRhsType.getShape()[kernelOutputFeatureDimension]; - // convolution_c11 + // convolution_c10 if (!isDynamicDimSize(inputBatch) && inputBatch % batchGroupCount != 0) return emitOptionalError(location, "expects input batch dimension (", inputBatch, ") to be divisible by " "batch_group_count. Got batch_group_count = ", batchGroupCount, "."); + if (!isDynamicDimSize(inputFeatures)) { - // convolution_c12 + // convolution_c11 if (inputFeatures % featureGroupCount != 0) return emitOptionalError(location, "expects input feature dimension (", inputFeatures, ") to be a multiple of feature_group_count. Got " "feature_group_count = ", featureGroupCount, "."); - // convolution_c15 + + // convolution_c14 if (!isDynamicDimSize(kernelInputFeatures) && inputFeatures / featureGroupCount != kernelInputFeatures) return emitOptionalError( @@ -895,15 +909,17 @@ LogicalResult verifyConvolutionAttributes( kernelInputFeatures, "). Got feature_group_count = ", featureGroupCount, "."); } + if (!isDynamicDimSize(kernelOutputFeatures)) { - // convolution_c16 + // convolution_c15 if (kernelOutputFeatures % batchGroupCount != 0) return emitOptionalError( location, "expects output feature dimension size (", kernelOutputFeatures, ") to be a multiple of batch_group_count. Got batch_group_count = ", batchGroupCount, "."); - // convolution_c17 + + // convolution_c16 if (kernelOutputFeatures % featureGroupCount != 0) return emitOptionalError(location, "expects kernel output feature dimension (", @@ -913,7 +929,7 @@ LogicalResult verifyConvolutionAttributes( featureGroupCount, "."); } - // convolution_c25 + // convolution_c24 if (failed(verifyPrecisionConfig(location, precisionConfig))) return failure(); @@ -1707,20 +1723,22 @@ LogicalResult inferConvolutionOp( return success(); } - // convolution_c14 + // convolution_c13 int numDims = rankedLhsType.getRank(); if (numDims < 2) return emitOptionalError( location, "expects convolution arguments to have >= 2 dimensions. Got: ", rankedLhsType, " and ", rankedRhsType, "."); + // convolution_c1 if (numDims != rankedRhsType.getRank()) return emitOptionalError(location, "expects convolution arguments to have same " "number of dimensions. Got: ", rankedLhsType, " and ", rankedRhsType, "."); - // convolution_c2 + + // convolution_c27 if (!isCompatibleForHloTypeInference(rankedLhsType.getElementType(), rankedRhsType.getElementType())) return emitOptionalError( @@ -1736,7 +1754,8 @@ LogicalResult inferConvolutionOp( outputSpatialDimensions, featureGroupCount, batchGroupCount, precisionConfig))) return failure(); - // convolution_c13 + + // convolution_c12 if ((size_t)numDims != inputSpatialDimensions.size() + 2) return emitOptionalError(location, "expects convolution arguments to have ", inputSpatialDimensions.size() + 2, @@ -1746,7 +1765,7 @@ LogicalResult inferConvolutionOp( for (size_t i = 0; i < windowDimensions.size(); i++) windowDimensions[i] = rankedRhsType.getShape()[kernelSpatialDimensions[i]]; - // convolution_c5, convolution_i4 + // convolution_c4, convolution_i4 auto paddingOrErr = convertPaddingAttribute(padding, location); if (failed(paddingOrErr)) return failure(); @@ -1754,14 +1773,17 @@ LogicalResult inferConvolutionOp( auto windowStridesOrErr = convert1DAttribute(windowStrides, location, "window_strides"); if (failed(windowStridesOrErr)) return failure(); + // convolution_i5 auto lhsDilationOrErr = convert1DAttribute(lhsDilation, location, "lhs_dilation"); if (failed(lhsDilationOrErr)) return failure(); + // convolution_i6 auto rhsDilationOrErr = convert1DAttribute(rhsDilation, location, "rhs_dilation"); if (failed(rhsDilationOrErr)) return failure(); + // convolution_i7 auto windowReversalOrErr = convertWindowReversalAttribute( windowReversal, location, "window_reversal"); @@ -1772,7 +1794,7 @@ LogicalResult inferConvolutionOp( *rhsDilationOrErr, *windowReversalOrErr, location); if (failed(windowOrErr)) return failure(); - // convolution_c26 + // convolution_c25, convolution_c26 SmallVector outputDimensions(rankedLhsType.getShape().size(), ShapedType::kDynamic); auto numSpatialDims = inputSpatialDimensions.size(); @@ -3284,7 +3306,7 @@ LogicalResult verifyConvolutionOp( auto inferredShape = inferredReturnShapes[0]; auto shapedResultType = resultType.cast(); - // convolution_c26 + // convolution_c25 if (inferredShape.hasRank() && shapedResultType.hasRank() && failed(verifyCompatibleShape(inferredShape.getDims(), shapedResultType.getShape()))) diff --git a/stablehlo/reference/Ops.cpp b/stablehlo/reference/Ops.cpp index 050f51f6180..9ff9d98322b 100644 --- a/stablehlo/reference/Ops.cpp +++ b/stablehlo/reference/Ops.cpp @@ -51,6 +51,28 @@ Index evalIndex(Tensor tensor) { return result; } +Tensor evalDotGeneralOp(const Tensor &lhs, const Tensor &rhs, + const Axes &lhsBatchingDimensions, + const Axes &rhsBatchingDimensions, + const Axes &lhsContractingDimensions, + const Axes &rhsContractingDimensions) { + SmallVector inferredDotGeneralType; + auto dotGeneralStatus = hlo::inferDotGeneralOp( + /*location=*/{}, lhs.getType(), rhs.getType(), + /*lhsBatchingDimensions=*/{}, /*rhsBatchingDimensions*/ {}, + lhsContractingDimensions, rhsContractingDimensions, + /*precisionConfig=*/{}, inferredDotGeneralType); + if (failed(dotGeneralStatus)) + report_fatal_error( + invalidArgument("Could not infer DotGeneralOp's return type")); + + return evalDotGeneralOp( + lhs, rhs, lhsBatchingDimensions, rhsBatchingDimensions, + lhsContractingDimensions, rhsContractingDimensions, + RankedTensorType::get(inferredDotGeneralType[0].getDims(), + lhs.getElementType())); +} + Tensor evalPadOp(const Tensor &operand, const Tensor &paddingValue, const Sizes &edgePaddingLow, const Sizes &edgePaddingHigh, const Sizes &interiorPadding) { @@ -130,9 +152,8 @@ Tensor evalSliceOp(const Tensor &operand, const Index &index) { return evalSliceOp(operand, start, limit, strides); } -SmallVector extractElements(ArrayRef arr, - ArrayRef indices) { - SmallVector elements; +Sizes extractElements(ArrayRef arr, ArrayRef indices) { + Sizes elements; for (auto index : indices) elements.push_back(arr[index]); return elements; } @@ -159,19 +180,19 @@ DenseIntElementsAttr getDenseIntElementsAttr( values); } -ShapedType inferConvolutionOpType( - ShapedType lhsType, ShapedType rhsType, ArrayRef windowStrides, +Tensor evalConvolutionOp( + const Tensor &lhs, const Tensor &rhs, ArrayRef windowStrides, ArrayRef> padding, ArrayRef lhsDilation, ArrayRef rhsDilation, ArrayRef windowReversal, Axis inputBatchDimension, - Axis inputFeatureDimension, Axes inputSpatialDimensions, + Axis inputFeatureDimension, const Axes &inputSpatialDimensions, Axis kernelInputFeatureDimension, Axis kernelOutputFeatureDimension, - Axes kernelSpatialDimensions, Axis outputBatchDimension, - Axis outputFeatureDimension, Axes outputSpatialDimensions, + const Axes &kernelSpatialDimensions, Axis outputBatchDimension, + Axis outputFeatureDimension, const Axes &outputSpatialDimensions, int64_t featureGroupCount, int64_t batchGroupCount, std::optional precisionConfig, ShapedType resultType) { - auto i64Type = IntegerType::get(lhsType.getContext(), 64); - auto i1Type = IntegerType::get(lhsType.getContext(), 1); + auto i64Type = IntegerType::get(lhs.getType().getContext(), 64); + auto i1Type = IntegerType::get(lhs.getType().getContext(), 1); SmallVector paddingVector; for (auto pair : padding) { @@ -182,7 +203,7 @@ ShapedType inferConvolutionOpType( SmallVector paddingShape{static_cast(padding.size()), 2}; SmallVector inferredConvolutionType; auto convolutionStatus = hlo::inferConvolutionOp( - /*location=*/{}, lhsType, rhsType, + /*location=*/{}, lhs.getType(), rhs.getType(), getDenseIntElementsAttr(i64Type, windowStrides, {}), getDenseIntElementsAttr(i64Type, paddingVector, paddingShape), getDenseIntElementsAttr(i64Type, lhsDilation, {}), @@ -202,65 +223,15 @@ ShapedType inferConvolutionOpType( if (failed(convolutionStatus)) report_fatal_error( invalidArgument("Could not infer ConvolutionOp's return type")); - return RankedTensorType::get(inferredConvolutionType[0].getDims(), - resultType.getElementType()); -} - -ShapedType inferDotGeneralOpType(ShapedType lhsType, ShapedType rhsType, - ArrayRef lhsContractingDimensions, - ArrayRef rhsContractingDimensions) { - SmallVector inferredDotGeneralType; - auto dotGeneralStatus = hlo::inferDotGeneralOp( - /*location=*/{}, lhsType, rhsType, - /*lhsBatchingDimensions=*/{}, /*rhsBatchingDimensions*/ {}, - lhsContractingDimensions, rhsContractingDimensions, - /*precisionConfig=*/{}, inferredDotGeneralType); - if (failed(dotGeneralStatus)) - report_fatal_error( - invalidArgument("Could not infer DotGeneralOp's return type")); - return RankedTensorType::get(inferredDotGeneralType[0].getDims(), - lhsType.getElementType()); -} - -ShapedType inferPadOpType(ArrayRef> padding, - Type operandType, Type paddingValueElementType, - ArrayRef interiorPadding) { - SmallVector lhsPaddingLow; - SmallVector lhsPaddingHigh; - for (auto paddingPair : padding) { - lhsPaddingLow.push_back(paddingPair.first); - lhsPaddingHigh.push_back(paddingPair.second); - } - - SmallVector inferredPadType; - auto i64Type = IntegerType::get(operandType.getContext(), 64); - auto padStatus = hlo::inferPadOp( - {}, operandType, RankedTensorType::get({}, paddingValueElementType), - getDenseIntElementsAttr(i64Type, lhsPaddingLow, {}), - getDenseIntElementsAttr(i64Type, lhsPaddingHigh, {}), - getDenseIntElementsAttr(i64Type, interiorPadding, {}), inferredPadType); - - if (failed(padStatus)) - report_fatal_error(invalidArgument("Could not infer PadOp's return type")); - return inferredPadType[0].cast(); -} - -ShapedType inferSliceOpType(Type operandType, - SmallVector lhsWindowStart, - SmallVector limitIndices, - SmallVector lhsWindowDilations) { - SmallVector inferredSliceType; - auto i64Type = IntegerType::get(operandType.getContext(), 64); - auto sliceStatus = hlo::inferSliceOp( - {}, operandType, getDenseIntElementsAttr(i64Type, lhsWindowStart, {}), - getDenseIntElementsAttr(i64Type, limitIndices, {}), - getDenseIntElementsAttr(i64Type, lhsWindowDilations, {}), - inferredSliceType); - - if (failed(sliceStatus)) - report_fatal_error( - invalidArgument("Could not infer SliceOp's return type")); - return inferredSliceType[0].cast(); + return evalConvolutionOp( + lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation, + windowReversal, inputBatchDimension, inputFeatureDimension, + inputSpatialDimensions, kernelInputFeatureDimension, + kernelOutputFeatureDimension, kernelSpatialDimensions, + outputBatchDimension, outputFeatureDimension, outputSpatialDimensions, + featureGroupCount, batchGroupCount, + RankedTensorType::get(inferredConvolutionType[0].getDims(), + resultType.getElementType())); } // Returns `result` with the effect of applying `permutation` @@ -279,22 +250,23 @@ SmallVector concatAndPermute(T n, SmallVector hw, T c, SmallVector split(const Tensor &input, int64_t groupSize, Axis splitDimension, MLIRContext *context) { - auto i64Type = IntegerType::get(context, 64); - auto getScalarTensor = [&](int64_t value) { + auto i64Type = IntegerType::get(context, 64); return Tensor(RankedTensorType::get({}, i64Type), convert(i64Type, value)); }; + Sizes splitInputShape(input.getShape()); splitInputShape[splitDimension] /= groupSize; - auto splitInputType = - RankedTensorType::get(splitInputShape, input.getElementType()); + SmallVector splitResults; for (auto idx = 0; idx < groupSize; ++idx) { SmallVector inputStartIndices(input.getRank(), getScalarTensor(0L)); inputStartIndices[splitDimension] = getScalarTensor(idx * splitInputShape[splitDimension]); - auto resultTensor = evalDynamicSliceOp(input, inputStartIndices, - splitInputShape, splitInputType); + + auto resultTensor = evalDynamicSliceOp( + input, inputStartIndices, splitInputShape, + RankedTensorType::get(splitInputShape, input.getElementType())); splitResults.push_back(resultTensor); } return splitResults; @@ -406,30 +378,35 @@ SmallVector eval( auto lhs = scope.findTensor(convolutionOp.getLhs()); auto rhs = scope.findTensor(convolutionOp.getRhs()); auto rank = lhs.getRank(); + SmallVector windowStrides(rank - 2, 1); - if (convolutionOp.getWindowStrides().has_value()) - windowStrides = llvm::to_vector( - convolutionOp.getWindowStridesAttr().getValues()); + if (auto windowStridesAttr = convolutionOp.getWindowStridesAttr()) + windowStrides.assign(windowStridesAttr.value_begin(), + windowStridesAttr.value_end()); + SmallVector> padding(rank - 2, {0, 0}); - if (convolutionOp.getPadding().has_value()) { - auto paddingOrErr = - hlo::convertPaddingAttribute(convolutionOp.getPadding(), {}); + if (auto paddingAttr = convolutionOp.getPaddingAttr()) { + auto paddingOrErr = hlo::convertPaddingAttribute(paddingAttr, {}); if (failed(paddingOrErr)) report_fatal_error(invalidArgument("Invalid padding format found.")); padding = *paddingOrErr; } + SmallVector lhsDilation(rank - 2, 1); - if (convolutionOp.getLhsDilation().has_value()) - lhsDilation = llvm::to_vector( - convolutionOp.getLhsDilationAttr().getValues()); + if (auto lhsDilationAttr = convolutionOp.getLhsDilationAttr()) + lhsDilation.assign(lhsDilationAttr.value_begin(), + lhsDilationAttr.value_end()); + SmallVector rhsDilation(rank - 2, 1); - if (convolutionOp.getRhsDilation().has_value()) - rhsDilation = llvm::to_vector( - convolutionOp.getRhsDilationAttr().getValues()); + if (auto rhsDilationAttr = convolutionOp.getRhsDilationAttr()) + rhsDilation.assign(rhsDilationAttr.value_begin(), + rhsDilationAttr.value_end()); + SmallVector windowReversal(rank - 2, false); - if (convolutionOp.getWindowReversal().has_value()) - windowReversal = llvm::to_vector( - convolutionOp.getWindowReversalAttr().getValues()); + if (auto windowReversalAttr = convolutionOp.getWindowReversalAttr()) + windowReversal.assign(windowReversalAttr.value_begin(), + windowReversalAttr.value_end()); + auto result = evalConvolutionOp( lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation, windowReversal, @@ -1009,10 +986,10 @@ Tensor evalConvolutionOp( ArrayRef> padding, ArrayRef lhsDilation, ArrayRef rhsDilation, ArrayRef windowReversal, Axis inputBatchDimension, - Axis inputFeatureDimension, Axes inputSpatialDimensions, + Axis inputFeatureDimension, const Axes &inputSpatialDimensions, Axis kernelInputFeatureDimension, Axis kernelOutputFeatureDimension, - Axes kernelSpatialDimensions, Axis outputBatchDimension, - Axis outputFeatureDimension, Axes outputSpatialDimensions, + const Axes &kernelSpatialDimensions, Axis outputBatchDimension, + Axis outputFeatureDimension, const Axes &outputSpatialDimensions, int64_t featureGroupCount, int64_t batchGroupCount, ShapedType resultType) { Tensor result(resultType); @@ -1029,16 +1006,8 @@ Tensor evalConvolutionOp( inputSpatialDimensions, kernelInputFeatureDimension, kernelOutputFeatureDimension, kernelSpatialDimensions, outputBatchDimension, outputFeatureDimension, outputSpatialDimensions, - /*featureGroupCount=*/1, batchGroupCount, - inferConvolutionOpType( - left.getType(), right.getType(), windowStrides, padding, - lhsDilation, rhsDilation, windowReversal, inputBatchDimension, - inputFeatureDimension, inputSpatialDimensions, - kernelInputFeatureDimension, kernelOutputFeatureDimension, - kernelSpatialDimensions, outputBatchDimension, - outputFeatureDimension, outputSpatialDimensions, - /*featureGroupCount=*/1, batchGroupCount, - /*precisionConfig=*/{}, resultType)); + /*featureGroupCount=*/1, batchGroupCount, /*precisionConfig=*/{}, + resultType); results.push_back(convolutionResult); } return evalConcatenateOp(results, outputFeatureDimension, result.getType()); @@ -1057,16 +1026,8 @@ Tensor evalConvolutionOp( inputSpatialDimensions, kernelInputFeatureDimension, kernelOutputFeatureDimension, kernelSpatialDimensions, outputBatchDimension, outputFeatureDimension, outputSpatialDimensions, - featureGroupCount, /*batchGroupCount=*/1, - inferConvolutionOpType( - left.getType(), right.getType(), windowStrides, padding, - lhsDilation, rhsDilation, windowReversal, inputBatchDimension, - inputFeatureDimension, inputSpatialDimensions, - kernelInputFeatureDimension, kernelOutputFeatureDimension, - kernelSpatialDimensions, outputBatchDimension, - outputFeatureDimension, outputSpatialDimensions, - featureGroupCount, /*batchGroupCount=*/1, - /*precisionConfig=*/{}, resultType)); + featureGroupCount, /*batchGroupCount=*/1, /*precisionConfig=*/{}, + resultType); results.push_back(convolutionResult); } return evalConcatenateOp(results, outputFeatureDimension, result.getType()); @@ -1086,86 +1047,63 @@ Tensor evalConvolutionOp( auto lhsWindowStrides = concatAndPermute(1L, llvm::to_vector(windowStrides), 1L, lhsPermutation); - auto lhsPadding = concatAndPermute({0, 0}, llvm::to_vector(padding), {0, 0}, - lhsPermutation); - auto lhsBaseDilations = concatAndPermute(0L, Sizes(lhsDilation) - 1, 0L, lhsPermutation); auto lhsWindowDilations = concatAndPermute(1L, llvm::to_vector(rhsDilation), 1L, lhsPermutation); - auto outputSpatialIndexIt = IndexSpaceIterator( - Sizes(extractElements(result.getShape(), outputSpatialDimensions)), - Index(outputSpatialDimensions.size())); - auto outputSpatialIndexItEnd = IndexSpaceIterator( - Sizes(extractElements(result.getShape(), outputSpatialDimensions)), - std::nullopt); - - SmallVector lhsPaddingLow; - for (auto paddingPair : lhsPadding) + Sizes lhsPaddingLow; + Sizes lhsPaddingHigh; + for (auto paddingPair : concatAndPermute({0, 0}, llvm::to_vector(padding), + {0, 0}, lhsPermutation)) { lhsPaddingLow.push_back(paddingPair.first); + lhsPaddingHigh.push_back(paddingPair.second); + } - auto inferredPadOpType = inferPadOpType( - lhsPadding, lhs.getType(), result.getElementType(), lhsBaseDilations); - - auto paddedLhs = evalPadOp( - lhs, - Tensor(RankedTensorType::get({}, result.getElementType()), - convert(result.getElementType(), 0.0)), - Sizes(lhsPaddingLow), Sizes(lhsBaseDilations), inferredPadOpType); + Tensor paddingValue(RankedTensorType::get({}, result.getElementType()), + convert(result.getElementType(), 0L)); + auto paddedLhs = evalPadOp(lhs, paddingValue, lhsPaddingLow, lhsPaddingHigh, + Sizes(lhsBaseDilations)); + IndexSpaceIterator outputSpatialIndexIt( + extractElements(result.getShape(), outputSpatialDimensions), + Index(outputSpatialDimensions.size())); + IndexSpaceIterator outputSpatialIndexItEnd( + extractElements(result.getShape(), outputSpatialDimensions), + std::nullopt); for (; outputSpatialIndexIt != outputSpatialIndexItEnd; ++outputSpatialIndexIt) { - SmallVector lhsPaddingLow; - for (auto paddingPair : lhsPadding) - lhsPaddingLow.push_back(paddingPair.first); - - auto paddedLhs = - evalPadOp(lhs, - Tensor(RankedTensorType::get({}, result.getElementType()), - convert(result.getElementType(), 0.0)), - Sizes(lhsPaddingLow), Sizes(lhsBaseDilations), - inferPadOpType(lhsPadding, lhs.getType(), - result.getElementType(), lhsBaseDilations)); - - SmallVector lhsWindowStart; + Sizes lhsWindowStart; for (auto [i, offset] : llvm::enumerate(concatAndPermute( 0L, llvm::to_vector(*outputSpatialIndexIt), 0L, lhsPermutation))) lhsWindowStart.push_back(lhsWindowStrides[i] * offset); - SmallVector limitIndices; + Sizes limitIndices; for (size_t i = 0; i < lhsWindowStart.size(); ++i) limitIndices.push_back(std::min( lhsWindowStart[i] + lhsWindowDimensions[i] * lhsWindowDilations[i], paddedLhs.getShape()[i])); - auto lhsWindow = - evalSliceOp(paddedLhs, Sizes(lhsWindowStart), Sizes(lhsWindowDilations), - inferSliceOpType(paddedLhs.getType(), lhsWindowStart, - limitIndices, lhsWindowDilations)); + auto lhsWindow = evalSliceOp(paddedLhs, lhsWindowStart, limitIndices, + Sizes(lhsWindowDilations)); - SmallVector reverseDims; + Axes reverseDims; for (auto [i, isReverse] : llvm::enumerate(windowReversal)) if (isReverse) reverseDims.push_back(inputSpatialDimensions[i]); auto reversedLhsWindow = - evalReverseOp(lhsWindow, Axes(reverseDims), lhsWindow.getType()); + evalReverseOp(lhsWindow, reverseDims, lhsWindow.getType()); - auto lhsContractingDimensions = llvm::to_vector(inputSpatialDimensions); - lhsContractingDimensions.push_back( - static_cast(inputFeatureDimension)); + Axes lhsContractingDimensions(inputSpatialDimensions); + lhsContractingDimensions.push_back(inputFeatureDimension); - auto rhsContractingDimensions = llvm::to_vector(kernelSpatialDimensions); - rhsContractingDimensions.push_back( - static_cast(kernelInputFeatureDimension)); + Axes rhsContractingDimensions(kernelSpatialDimensions); + rhsContractingDimensions.push_back(kernelInputFeatureDimension); - auto dotProduct = evalDotGeneralOp( - reversedLhsWindow, rhs, /*lhsBatchingDimensions=*/{}, - /*rhsBatchingDimensions=*/{}, Axes(lhsContractingDimensions), - Axes(rhsContractingDimensions), - inferDotGeneralOpType(reversedLhsWindow.getType(), rhs.getType(), - lhsContractingDimensions, - rhsContractingDimensions)); + auto dotProduct = + evalDotGeneralOp(reversedLhsWindow, rhs, /*lhsBatchingDimensions=*/{}, + /*rhsBatchingDimensions=*/{}, lhsContractingDimensions, + rhsContractingDimensions); Sizes resultNonSpatialDims; for (auto i = 0; i < result.getRank(); ++i) @@ -1178,15 +1116,16 @@ Tensor evalConvolutionOp( resultPermutation.append(outputSpatialDimensions.begin(), outputSpatialDimensions.end()); resultPermutation.push_back(outputFeatureDimension); - auto resultNonSpatialIt = IndexSpaceIterator( - resultNonSpatialDims, Index(resultNonSpatialDims.size())); + + IndexSpaceIterator resultNonSpatialIt(resultNonSpatialDims, + Index(resultNonSpatialDims.size())); for (auto dotProductIt = dotProduct.index_begin(); dotProductIt != dotProduct.index_end(); ++dotProductIt, ++resultNonSpatialIt) { - auto resultIndex = + Index resultIndex( concatAndPermute((*resultNonSpatialIt)[0], *outputSpatialIndexIt, - (*resultNonSpatialIt)[1], resultPermutation); - result.set(Index(resultIndex), dotProduct.get(*dotProductIt)); + (*resultNonSpatialIt)[1], resultPermutation)); + result.set(resultIndex, dotProduct.get(*dotProductIt)); } } return result; diff --git a/stablehlo/reference/Ops.h b/stablehlo/reference/Ops.h index f2cda3bc032..a18a040a33f 100644 --- a/stablehlo/reference/Ops.h +++ b/stablehlo/reference/Ops.h @@ -57,10 +57,10 @@ Tensor evalConvolutionOp( ArrayRef> padding, ArrayRef lhsDilation, ArrayRef rhsDilation, ArrayRef windowReversal, Axis inputBatchDimension, - Axis inputFeatureDimension, Axes inputSpatialDimensions, + Axis inputFeatureDimension, const Axes &inputSpatialDimensions, Axis kernelInputFeatureDimension, Axis kernelOutputFeatureDimension, - Axes kernelSpatialDimensions, Axis outputBatchDimension, - Axis outputFeatureDimension, Axes outputSpatialDimensions, + const Axes &kernelSpatialDimensions, Axis outputBatchDimension, + Axis outputFeatureDimension, const Axes &outputSpatialDimensions, int64_t featureGroupCount, int64_t batchGroupCount, ShapedType resultType); Tensor evalCosineOp(const Tensor &operand, ShapedType resultType); Tensor evalDivideOp(const Tensor &lhs, const Tensor &rhs,