diff --git a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp index 5384650e30c..39d3fa6c27c 100644 --- a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp +++ b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp @@ -1687,7 +1687,7 @@ struct GatherConversion final : OpConversionPattern { int64_t resultRank = resultType.getRank(); // slice_sizes has to have the same size as operand.rank, and doing it this // way permits an unranked operand. - int64_t operandRank = gatherOp.getSliceSizes().getNumElements(); + int64_t operandRank = gatherOp.getSliceSizes().size(); int64_t indexVectorDim = gatherOp.getDimensionNumbers().getIndexVectorDim(); diff --git a/stablehlo/conversions/linalg/transforms/StablehloToLinalgConvolution.cpp b/stablehlo/conversions/linalg/transforms/StablehloToLinalgConvolution.cpp index 173e73325bd..cc922eeab5d 100644 --- a/stablehlo/conversions/linalg/transforms/StablehloToLinalgConvolution.cpp +++ b/stablehlo/conversions/linalg/transforms/StablehloToLinalgConvolution.cpp @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/conversions/linalg/transforms/LegalizeToLinalgUtils.h" #include "stablehlo/conversions/linalg/transforms/Rewriters.h" +#include "stablehlo/dialect/Base.h" #include "stablehlo/dialect/StablehloOps.h" namespace mlir::stablehlo { @@ -29,13 +30,14 @@ namespace { /// Apply dilation and padding to the input of a convolution. Value applyConvolutionPadding(Location loc, Value input, DenseIntElementsAttr padding, - DenseIntElementsAttr lhsDilation, + Attribute lhsDilation, llvm::ArrayRef dimMappings, OpBuilder &rewriter) { - if ((!padding || isSplatValue(padding, 0)) && - (!lhsDilation || isSplatValue(lhsDilation, 1))) { - return input; - } + SmallVector lhsDilationValues; + if (lhsDilation) lhsDilationValues = hlo::getI64Array(lhsDilation); + bool noPadding = !padding || isSplatValue(padding, 0); + bool noDilation = !lhsDilation || hlo::isSplatArray(lhsDilationValues, 1); + if (noPadding && noDilation) return input; auto inputType = cast(input.getType()); int64_t rank = inputType.getRank(); @@ -58,10 +60,10 @@ Value applyConvolutionPadding(Location loc, Value input, // Translate input dilation into interior padding. SmallVector padInterior(rank, 0); if (lhsDilation) { - assert(rank == lhsDilation.size() + 2); - for (int64_t i : llvm::seq(0, lhsDilation.size())) { + assert(rank == static_cast(lhsDilationValues.size()) + 2); + for (int64_t i : llvm::seq(0, lhsDilationValues.size())) { int64_t dim = dimMappings[i]; - padInterior[dim] = lhsDilation.getValues()[i] - 1; + padInterior[dim] = lhsDilationValues[i] - 1; } } @@ -91,8 +93,7 @@ Value applyConvolutionReversal(Location loc, OpBuilder &b, return filter; } llvm::SmallVector reversedDims; - for (auto [idx, reversed] : - llvm::enumerate(reversals.value().getValues())) { + for (auto [idx, reversed] : llvm::enumerate(reversals.value())) { if (reversed) { reversedDims.push_back( op.getDimensionNumbers().getKernelSpatialDimensions()[idx]); @@ -219,8 +220,10 @@ struct NormalConvolutionOpConversion final loc, resultType.getShape(), resultType.getElementType(), dynSizes); Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor); linalg::LinalgOp res; - Attribute strides = op.getWindowStridesAttr(); - Attribute dilations = op.getRhsDilationAttr(); + Attribute strides; + if (auto s = op.getWindowStrides()) strides = rewriter.getI64TensorAttr(*s); + Attribute dilations; + if (auto d = op.getRhsDilation()) dilations = rewriter.getI64TensorAttr(*d); // Apply padding and input dilation. llvm::SmallVector spatialDimMapping(rank - 2); @@ -512,7 +515,7 @@ struct ConvolutionOpGeneralConversion final AffineExpr stride = dim0; if (op.getWindowStrides().has_value()) - stride = stride * op.getWindowStrides().value().getValues()[i]; + stride = stride * op.getWindowStrides().value()[i]; AffineExpr srcExpr = stride + dim1; srcExprs[lhsIndexMapping[inputSpatialDimensions[i]]] = srcExpr; @@ -599,7 +602,7 @@ struct DepthwiseConvolutionOpConversion final Attribute windowStrides; if (op.getWindowStrides()) { - windowStrides = op.getWindowStrides().value(); + windowStrides = rewriter.getI64TensorAttr(op.getWindowStrides().value()); } else { windowStrides = SplatElementsAttr::get( VectorType::get({spatialRank}, rewriter.getI64Type()), @@ -608,7 +611,7 @@ struct DepthwiseConvolutionOpConversion final Attribute rhsDilation; if (op.getRhsDilation()) { - rhsDilation = op.getRhsDilation().value(); + rhsDilation = rewriter.getI64TensorAttr(op.getRhsDilation().value()); } else { rhsDilation = SplatElementsAttr::get( VectorType::get({spatialRank}, rewriter.getI64Type()), diff --git a/stablehlo/dialect/Base.cpp b/stablehlo/dialect/Base.cpp index 20bc88dbea1..300492e2602 100644 --- a/stablehlo/dialect/Base.cpp +++ b/stablehlo/dialect/Base.cpp @@ -612,10 +612,23 @@ SmallVector getI64Array(Attribute attr) { if (auto array = attr.dyn_cast()) return llvm::to_vector(array.asArrayRef()); llvm::report_fatal_error( - "called i64ArrayOrElementsValues on Attribute that was neither a " + "called getI64Array on Attribute that was neither a " "DenseIntElementsAttr or a DenseI64ArrayAttr", false); } +SmallVector getBoolArray(Attribute attr) { + if (!attr) return {}; + if (auto elements = attr.dyn_cast()) + return llvm::to_vector(elements.getValues()); + if (auto array = attr.dyn_cast()) { + return SmallVector(array.asArrayRef()); + } + llvm::report_fatal_error( + "called getBoolArray on Attribute that was neither a " + "DenseIntOrFPElementsAttr or a DenseBoolArrayAttr", + false); +} + } // namespace hlo } // namespace mlir diff --git a/stablehlo/dialect/Base.h b/stablehlo/dialect/Base.h index f36b9f7e207..5045146e936 100644 --- a/stablehlo/dialect/Base.h +++ b/stablehlo/dialect/Base.h @@ -71,6 +71,13 @@ bool isSplatArray(ArrayRef arr, int64_t val); // have been removed. SmallVector getI64Array(Attribute); +// Returns a vector of the bool values in a BoolDenseArrayOrElementsAttr. +// Such an Attr can be backed by either a DenseIntOrFPElementsAttr or +// a DenseBoolArrayAttr. +// TODO(#1578): Remove this code once all uses of BoolDenseArrayOrElementsAttr +// have been removed. +SmallVector getBoolArray(Attribute); + // Verifies that the two types have compatible shape with bounds but allows // different element types. LogicalResult verifyCompatibleShapeWithBounds(Type type1, Type type2); diff --git a/stablehlo/dialect/StablehloAttrs.td b/stablehlo/dialect/StablehloAttrs.td index 418cab10806..6371425327c 100644 --- a/stablehlo/dialect/StablehloAttrs.td +++ b/stablehlo/dialect/StablehloAttrs.td @@ -160,17 +160,6 @@ def StableHLO_FlatSymbolRefArrayAttr : let constBuilderCall = "::mlir::ArrayAttr::get($_builder.getContext(), $0)"; } -def StableHLO_BoolElementsAttr : - ElementsAttrBase< - And<[CPred<"$_self.isa<::mlir::DenseIntOrFPElementsAttr>()">, - CPred<"$_self.cast<::mlir::DenseIntOrFPElementsAttr>().getType().getElementType().isInteger(1)">]>, - "constant boolean vector/tensor attribute"> { - let storageType = [{ ::mlir::DenseElementsAttr }]; - let returnType = [{ ::mlir::DenseElementsAttr }]; - - let convertFromStorage = "$_self"; -} - def StableHLO_ConvDimensionNumbers : AttrDef { let mnemonic = "conv"; let summary = "Structure of dimension information for conv op"; @@ -190,18 +179,35 @@ def StableHLO_ConvDimensionNumbers : AttrDef()">, + CPred<"$_self.cast<::mlir::DenseIntOrFPElementsAttr>().getType().getElementType().isInteger(1)">]>, + "constant boolean vector/tensor attribute"> { + let storageType = [{ ::mlir::DenseElementsAttr }]; + let returnType = [{ ::mlir::DenseElementsAttr }]; + + let convertFromStorage = "$_self"; +} + +def BoolDenseArrayOrElementsAttr : Attr, "either a DenseBoolArrayAttr or a StableHLO_BoolElementsAttr"> { + let storageType = "Attribute"; + let returnType = "SmallVector"; + let convertFromStorage = "hlo::getBoolArray($_self)"; +} + def StableHLO_ConvolutionAttributes { dag attributes = (ins // Default value: one for each of the spatial dimension. - OptionalAttr:$window_strides, + OptionalAttr:$window_strides, // Default value: two zeros for each of the spatial dimension. OptionalAttr:$padding, // Default value: one for each of the spatial dimension. - OptionalAttr:$lhs_dilation, + OptionalAttr:$lhs_dilation, // Default value: one for each of the spatial dimension. - OptionalAttr:$rhs_dilation, + OptionalAttr:$rhs_dilation, // Default value: false for each of the spatial dimension. - OptionalAttr:$window_reversal, + OptionalAttr:$window_reversal, StableHLO_ConvDimensionNumbers:$dimension_numbers, I64Attr:$feature_group_count, I64Attr:$batch_group_count, diff --git a/stablehlo/dialect/StablehloOps.cpp b/stablehlo/dialect/StablehloOps.cpp index 672fc349027..3568e912e8e 100644 --- a/stablehlo/dialect/StablehloOps.cpp +++ b/stablehlo/dialect/StablehloOps.cpp @@ -72,6 +72,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/InliningUtils.h" #include "stablehlo/dialect/AssemblyFormat.h" +#include "stablehlo/dialect/Base.h" #include "stablehlo/dialect/StablehloBytecode.h" #include "stablehlo/dialect/StablehloOps.h.inc" #include "stablehlo/dialect/TypeInference.h" @@ -613,7 +614,7 @@ namespace { void getSliceSizeValues(GatherOp* gather, OpBuilder& builder, Location loc, ValueRange operands, SmallVectorImpl& sliceSizes) { - for (int64_t val : gather->getSliceSizes().getValues()) + for (int64_t val : gather->getSliceSizes()) sliceSizes.push_back(builder.create(loc, val)); } @@ -3090,42 +3091,30 @@ Attribute ConvDimensionNumbersAttr::parse(AsmParser& parser, Type type) { namespace { // Custom formatting for convolution window attributes. -void printWindowAttribute(OpAsmPrinter& p, DenseElementsAttr attribute) { - if (attribute.getElementType().isInteger(/*width=*/1)) { - // boolean attribute. - llvm::interleaveComma(attribute.getValues(), p, - [&](bool b) { p << (b ? 1 : 0); }); - return; - } - if (attribute.getType().getRank() == 2) { - // Padding is Nx2 attribute. - auto it = attribute.value_begin(); - std::vector> values(attribute.getNumElements() / - 2); - for (auto& item : values) { - int64_t first = *it; - ++it; - int64_t second = *it; - ++it; - item = {first, second}; - } - llvm::interleaveComma( - values, p, [&](const std::pair pair) { - p << '[' << pair.first << ", " << pair.second << ']'; - }); - } else { - llvm::interleaveComma(attribute.getValues(), p); +void printWindowPadding(OpAsmPrinter& p, DenseElementsAttr padding) { + // Padding is Nx2 attribute. + auto it = padding.value_begin(); + std::vector> values(padding.getNumElements() / 2); + for (auto& item : values) { + int64_t first = *it; + ++it; + int64_t second = *it; + ++it; + item = {first, second}; } + llvm::interleaveComma(values, p, [&](const std::pair pair) { + p << '[' << pair.first << ", " << pair.second << ']'; + }); } } // namespace void printWindowAttributes(OpAsmPrinter& p, Operation* /*op*/, - std::optional windowStrides, + std::optional windowStrides, std::optional padding, - std::optional lhsDilation, - std::optional rhsDilation, - std::optional windowReversal) { - using pair_t = std::pair; + std::optional lhsDilation, + std::optional rhsDilation, + std::optional windowReversal) { + using pair_t = std::pair; std::array printedAttributes = {{ {windowStrides ? *windowStrides : nullptr, "stride"}, {padding ? *padding : nullptr, "pad"}, @@ -3139,19 +3128,26 @@ void printWindowAttributes(OpAsmPrinter& p, Operation* /*op*/, printedAttributes, [](const pair_t& a) { return static_cast(a.first); }); - llvm::interleaveComma(nonNullAttributes, p, [&](const pair_t& a) { - p << a.second << " = ["; - printWindowAttribute(p, a.first); - p << "]"; + llvm::interleaveComma(nonNullAttributes, p, [&](const pair_t& attr) { + p << attr.second << " = ["; + + if (attr.second == "pad") { + printWindowPadding(p, attr.first.dyn_cast()); + } else if (attr.second == "reverse") { + llvm::interleaveComma(hlo::getBoolArray(attr.first), p); + } else { + llvm::interleaveComma(hlo::getI64Array(attr.first), p); + } + + p << ']'; }); } -ParseResult parseWindowAttributes(OpAsmParser& parser, - DenseIntElementsAttr& windowStrides, +ParseResult parseWindowAttributes(OpAsmParser& parser, Attribute& windowStrides, DenseIntElementsAttr& padding, - DenseIntElementsAttr& lhsDilation, - DenseIntElementsAttr& rhsDilation, - DenseElementsAttr& windowReversal) { + Attribute& lhsDilation, + Attribute& rhsDilation, + Attribute& windowReversal) { StringRef attributeName; llvm::StringSet<> allowedAttributeNames{ @@ -3205,9 +3201,8 @@ ParseResult parseWindowAttributes(OpAsmParser& parser, if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, int64Parser)) return failure(); - const int64_t size = static_cast(values.size()); if (attributeName == "reverse") { - auto ty = RankedTensorType::get({size}, + auto ty = RankedTensorType::get({static_cast(values.size())}, parser.getBuilder().getIntegerType(1)); auto boolVector = llvm::to_vector<4>( llvm::map_range(values, [](int64_t v) { return v != 0; })); diff --git a/stablehlo/dialect/StablehloOps.h b/stablehlo/dialect/StablehloOps.h index 9ee96d58e99..e85bdba0fb7 100644 --- a/stablehlo/dialect/StablehloOps.h +++ b/stablehlo/dialect/StablehloOps.h @@ -88,18 +88,17 @@ ParseResult parseConvolutionDimensions(AsmParser &parser, // Custom formatting for convolution window attributes. void printWindowAttributes(OpAsmPrinter &p, Operation *op, - std::optional windowStrides, + std::optional windowStrides, std::optional padding, - std::optional lhsDilation, - std::optional rhsDilation, - std::optional windowReversal); + std::optional lhsDilation, + std::optional rhsDilation, + std::optional windowReversal); -ParseResult parseWindowAttributes(OpAsmParser &parser, - DenseIntElementsAttr &windowStrides, +ParseResult parseWindowAttributes(OpAsmParser &parser, Attribute &windowStrides, DenseIntElementsAttr &padding, - DenseIntElementsAttr &lhsDilation, - DenseIntElementsAttr &rhsDilation, - DenseElementsAttr &windowReversal); + Attribute &lhsDilation, + Attribute &rhsDilation, + Attribute &windowReversal); } // end namespace stablehlo } // end namespace mlir diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index 6d70ba80161..b987561aa51 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -2102,11 +2102,11 @@ def StableHLO_ConvolutionOp : StableHLO_Op<"convolution", [Pure]> { Example: ```mlir %result = "stablehlo.convolution"(%lhs, %rhs) { - window_strides = dense<4> : tensor<2xi64>, + window_strides = array, padding = dense<0> : tensor<2x2xi64>, - lhs_dilation = dense<2> : tensor<2xi64>, - rhs_dilation = dense<1> : tensor<2xi64>, - window_reversal = dense : tensor<2xi1>, + lhs_dilation = array, + rhs_dilation = array, + window_reversal = array, dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, feature_group_count = 1 : i64, batch_group_count = 1 : i64, @@ -2126,8 +2126,7 @@ def StableHLO_ConvolutionOp : StableHLO_Op<"convolution", [Pure]> { let extraClassDeclaration = [{ bool hasWindowReversal() { auto reversal = getWindowReversalAttr(); - return reversal && llvm::any_of(reversal.getValues(), - [](bool v) { return v; }); + return reversal && llvm::any_of(hlo::getBoolArray(reversal), [](bool v) { return v; }); } }]; @@ -2387,7 +2386,7 @@ def StableHLO_GatherOp: StableHLO_Op<"gather", [InferTensorTypeWithReify /*gathe collapsed_slice_dims = [0], start_index_map = [1, 0], index_vector_dim = 2>, - slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>, + slice_sizes = array, indices_are_sorted = false } : (tensor<3x4x2xi32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi32> ``` @@ -2397,7 +2396,7 @@ def StableHLO_GatherOp: StableHLO_Op<"gather", [InferTensorTypeWithReify /*gathe HLO_Tensor:$operand /*gather_i1*/, HLO_IntTensor:$start_indices /*gather_i2*/, StableHLO_GatherDimensionNumbers:$dimension_numbers /*gather_i3, gather_i4, gather_i5, gather_i6*/, - I64ElementsAttr:$slice_sizes /*gather_i7*/, + I64DenseArrayOrElements1DAttr:$slice_sizes /*gather_i7*/, DefaultValuedOptionalAttr:$indices_are_sorted /*gather_i8*/ ); diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index 33ae7ffa0a6..a0ccd5e4d3d 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -1176,7 +1176,8 @@ static LogicalResult verifyGather( // gather_i7 if (sliceSizesShape.hasRank() && sliceSizesShape.getRank() != 1) - return emitOptionalError(location, "slice_sizes.rank != 1"); + return emitOptionalError(location, "slice_sizes.rank != 1 (got ", + sliceSizesShape.getRank(), ')'); if (sliceSizesShape.hasStaticShape()) { int64_t sliceSize = sliceSizesShape.getNumElements(); @@ -1778,13 +1779,12 @@ LogicalResult inferConvertOp( */ LogicalResult inferConvolutionOp( std::optional location, Type lhsType, Type rhsType, - std::optional windowStrides, + std::optional> windowStrides, std::optional padding, - std::optional lhsDilation, - std::optional rhsDilation, - std::optional windowReversal, - int64_t inputBatchDimension, int64_t inputFeatureDimension, - ArrayRef inputSpatialDimensions, + std::optional> lhsDilation, + std::optional> rhsDilation, + std::optional> windowReversal, int64_t inputBatchDimension, + int64_t inputFeatureDimension, ArrayRef inputSpatialDimensions, int64_t kernelInputFeatureDimension, int64_t kernelOutputFeatureDimension, ArrayRef kernelSpatialDimensions, int64_t outputBatchDimension, int64_t outputFeatureDimension, ArrayRef outputSpatialDimensions, @@ -1834,22 +1834,12 @@ LogicalResult inferConvolutionOp( if (failed(paddingOrErr)) return failure(); // TODO: add missing tests for ConvolutionOp. - auto windowStridesOrErr = - convert1DAttribute(windowStrides, location, "window_strides"); - if (failed(windowStridesOrErr)) return failure(); - auto lhsDilationOrErr = - convert1DAttribute(lhsDilation, location, "lhs_dilation"); - if (failed(lhsDilationOrErr)) return failure(); - auto rhsDilationOrErr = - convert1DAttribute(rhsDilation, location, "rhs_dilation"); - if (failed(rhsDilationOrErr)) return failure(); - auto windowReversalOrErr = convertWindowReversalAttribute( - windowReversal, location, "window_reversal"); - if (failed(windowReversalOrErr)) return failure(); auto windowOrErr = verifyWindowAttributesAndInferWindowDimensions( - windowDimensions, *windowStridesOrErr, *paddingOrErr, *lhsDilationOrErr, - *rhsDilationOrErr, *windowReversalOrErr, location); + windowDimensions, windowStrides.value_or(ArrayRef{}), + *paddingOrErr, lhsDilation.value_or(ArrayRef{}), + rhsDilation.value_or(ArrayRef{}), + windowReversal.value_or(ArrayRef{}), location); if (failed(windowOrErr)) return failure(); // P3. @@ -2300,22 +2290,25 @@ LogicalResult inferGatherOp( std::optional location, Value operand, Value startIndices, ArrayRef offsetDims, ArrayRef collapsedSliceDims, ArrayRef startIndexMap, int64_t indexVectorDim, - DenseIntElementsAttr sliceSizes, + ArrayRef sliceSizes, SmallVectorImpl& inferredReturnShapes) { ShapeAdaptor operandShape(operand.getType()); ShapeAdaptor startIndicesShape(startIndices.getType()); + SmallVector ssShape{static_cast(sliceSizes.size())}; + ShapedTypeComponents ssSTC{ssShape}; + ShapeAdaptor sliceSizesShape(ssSTC); // For some reason the getType call is necessary here if (failed(verifyGather(location, /*operandShape=*/operandShape, /*startIndicesShape=*/startIndicesShape, - /*sliceSizesShape=*/sliceSizes.getType(), offsetDims, + /*sliceSizesShape=*/sliceSizesShape, offsetDims, collapsedSliceDims, startIndexMap, indexVectorDim))) return failure(); // gather_c8 for (auto dim : collapsedSliceDims) { - int64_t sliceDimSize = sliceSizes.getValues()[dim]; + int64_t sliceDimSize = sliceSizes[dim]; if (sliceDimSize > 1) return emitOptionalError(location, "slice_sizes collapsed dimension ", dim, " should <= 1 but got ", sliceDimSize); @@ -2323,7 +2316,7 @@ LogicalResult inferGatherOp( // gather_c12 if (operandShape.hasRank()) { - for (const auto& it : llvm::enumerate(sliceSizes.getValues())) { + for (const auto& it : llvm::enumerate(sliceSizes)) { if (operandShape.isDynamicDim(it.index())) continue; auto operandDimSize = operandShape.getDimSize(it.index()); auto sliceDimSize = it.value(); @@ -2335,7 +2328,7 @@ LogicalResult inferGatherOp( } auto getSliceDim = [&sliceSizes](int64_t index) -> int64_t { - return sliceSizes.getValues()[index]; + return sliceSizes[index]; }; return inferGatherReturnTypeComponents( @@ -3380,13 +3373,12 @@ LogicalResult verifyCollectivePermuteOp( LogicalResult verifyConvolutionOp( std::optional location, Type lhsType, Type rhsType, - std::optional windowStrides, + std::optional> windowStrides, std::optional padding, - std::optional lhsDilation, - std::optional rhsDilation, - std::optional windowReversal, - int64_t inputBatchDimension, int64_t inputFeatureDimension, - ArrayRef inputSpatialDimensions, + std::optional> lhsDilation, + std::optional> rhsDilation, + std::optional> windowReversal, int64_t inputBatchDimension, + int64_t inputFeatureDimension, ArrayRef inputSpatialDimensions, int64_t kernelInputFeatureDimension, int64_t kernelOutputFeatureDimension, ArrayRef kernelSpatialDimensions, int64_t outputBatchDimension, int64_t outputFeatureDimension, ArrayRef outputSpatialDimensions, diff --git a/stablehlo/dialect/TypeInference.h b/stablehlo/dialect/TypeInference.h index 771c39ab8cd..8a6b2cc3b11 100644 --- a/stablehlo/dialect/TypeInference.h +++ b/stablehlo/dialect/TypeInference.h @@ -176,13 +176,12 @@ LogicalResult inferConvertOp( LogicalResult inferConvolutionOp( std::optional location, Type lhsType, Type rhsType, - std::optional windowStrides, + std::optional> windowStrides, std::optional padding, - std::optional lhsDilation, - std::optional rhsDilation, - std::optional windowReversal, - int64_t inputBatchDimension, int64_t inputFeatureDimension, - ArrayRef inputSpatialDimensions, + std::optional> lhsDilation, + std::optional> rhsDilation, + std::optional> windowReversal, int64_t inputBatchDimension, + int64_t inputFeatureDimension, ArrayRef inputSpatialDimensions, int64_t kernelInputFeatureDimension, int64_t kernelOutputFeatureDimension, ArrayRef kernelSpatialDimensions, int64_t outputBatchDimension, int64_t outputFeatureDimension, ArrayRef outputSpatialDimensions, @@ -234,7 +233,7 @@ LogicalResult inferGatherOp( std::optional location, Value operand, Value startIndices, ArrayRef offsetDims, ArrayRef collapsedSliceDims, ArrayRef startIndexMap, int64_t indexVectorDim, - DenseIntElementsAttr sliceSizes, + ArrayRef sliceSizes, SmallVectorImpl& inferredReturnShapes); LogicalResult inferGetDimensionSizeOp( @@ -400,13 +399,12 @@ LogicalResult verifyCollectivePermuteOp(std::optional location, LogicalResult verifyConvolutionOp( std::optional location, Type lhsType, Type rhsType, - std::optional windowStrides, + std::optional> windowStrides, std::optional padding, - std::optional lhsDilation, - std::optional rhsDilation, - std::optional windowReversal, - int64_t inputBatchDimension, int64_t inputFeatureDimension, - ArrayRef inputSpatialDimensions, + std::optional> lhsDilation, + std::optional> rhsDilation, + std::optional> windowReversal, int64_t inputBatchDimension, + int64_t inputFeatureDimension, ArrayRef inputSpatialDimensions, int64_t kernelInputFeatureDimension, int64_t kernelOutputFeatureDimension, ArrayRef kernelSpatialDimensions, int64_t outputBatchDimension, int64_t outputFeatureDimension, ArrayRef outputSpatialDimensions, diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir index f7cb3da317b..d8db0afa975 100644 --- a/stablehlo/tests/ops_stablehlo.mlir +++ b/stablehlo/tests/ops_stablehlo.mlir @@ -4035,8 +4035,7 @@ func.func @gather_c14(%operand : tensor<*xi32>, %start_indices : tensor, %start_indices : tensor<1x5x2xi32>) -> tensor<1x5x8xi32> { - // expected-error@+2 {{failed to infer returned types}} - // expected-error@+1 {{slice_sizes.rank != 1}} + // expected-error@+1 {{attribute 'slice_sizes' failed to satisfy constraint: either a DenseI64ArrayAttr or a 1-dimensional I64ElementsAttr.}} %res = "stablehlo.gather"(%operand, %start_indices) { dimension_numbers = #stablehlo.gather< offset_dims = [2], @@ -5787,7 +5786,7 @@ func.func @dynamic_iota_output_shape_mismatching_size() -> tensor<4xf32> { func.return %1 : tensor<4xf32> } -// Tests for I64DenseArrayOrElementsAttr. +// Tests for I64DenseArrayOrElements1DAttr. // ----- @@ -5805,3 +5804,32 @@ func.func @broadcast_in_dim_dense_array(%arg0: tensor<1x2xi32>) -> tensor<1x2x2x func.return %0 : tensor<1x2x2xi32> } +// Tests for BoolDenseArrayOrElementsAttr. + +// ----- + +// CHECK-LABEL: func @convolution_elements +// CHECK: reverse = [true, false] +func.func @convolution_elements(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> { + %0 = "stablehlo.convolution"(%arg0, %arg1) { + window_reversal = dense<[true, false]> : tensor<2xi1>, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> + func.return %0 : tensor<1x6x6x16xf32> +} + +// ----- + +// CHECK-LABEL: func @convolution_array +// CHECK: reverse = [true, false] +func.func @convolution_array(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> { + %0 = "stablehlo.convolution"(%arg0, %arg1) { + window_reversal = array, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> + func.return %0 : tensor<1x6x6x16xf32> +} diff --git a/stablehlo/tests/ops_stablehlo_roundtrip.mlir b/stablehlo/tests/ops_stablehlo_roundtrip.mlir index 2e6176f47ff..d32c295ebb3 100644 --- a/stablehlo/tests/ops_stablehlo_roundtrip.mlir +++ b/stablehlo/tests/ops_stablehlo_roundtrip.mlir @@ -262,7 +262,7 @@ func.func @test_convolution3(%arg0 : tensor<100x26x26x32xi8>, %arg1 : tensor<3x3 padding = dense<2> : tensor<2x2xi64>, rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>, - window_reversal = dense<1> : tensor<2xi1> + window_reversal = dense : tensor<2xi1> } : (tensor<100x26x26x32xi8>, tensor<3x3x1x32xi8>) -> tensor<100x28x28x1xi32> func.return %result : tensor<100x28x28x1xi32> } diff --git a/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stablehlo/transforms/StablehloLegalizeToVhlo.cpp index ec8e35cd8bb..c61ba027a74 100644 --- a/stablehlo/transforms/StablehloLegalizeToVhlo.cpp +++ b/stablehlo/transforms/StablehloLegalizeToVhlo.cpp @@ -151,6 +151,19 @@ Attribute convertGeneric(Attribute stablehloAttr, return vhlo::TensorV1Attr::get(attr.getContext(), vhloType, attr.getRawData()); } + if (auto attr = stablehloAttr.dyn_cast()) { + // Dense arrays of bool need special handling as their raw data is + // encoded differently from a DenseElementsAttr of bool. Using a similar + // conversion as for DenseI64ArrayAttr causes issues when attempting to + // construct a DenseIntOrFPElementsAttr from the underlying raw data. + // Converting to a DenseElementsAttr up front avoids issues. + auto data = attr.asArrayRef(); + auto ty = + RankedTensorType::get({static_cast(attr.asArrayRef().size())}, + IntegerType::get(attr.getContext(), 1)); + auto dense = DenseElementsAttr::get(ty, data); + return convertGeneric(dense, typeConverter); + } if (auto attr = stablehloAttr.dyn_cast()) { SmallVector> vhloAttrs; for (auto namedAttr : attr.getValue()) { diff --git a/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stablehlo/transforms/VhloLegalizeToStablehlo.cpp index ce1c2c3000e..89e32fcb0e4 100644 --- a/stablehlo/transforms/VhloLegalizeToStablehlo.cpp +++ b/stablehlo/transforms/VhloLegalizeToStablehlo.cpp @@ -426,15 +426,15 @@ LogicalResult implodeSpecial(const OpConversionPattern& pattern, return success(); } +template SpecialResult convertDenseArray(StringAttr vhloName, Attribute vhloAttr, SmallVector& stablehloAttrs) { auto tensorAttr = dyn_cast(vhloAttr); if (!tensorAttr) return specialFailure(); - auto data = ArrayRef( - reinterpret_cast(tensorAttr.getData().data()), - tensorAttr.getData().size() / sizeof(int64_t)) - .vec(); + auto data = SmallVector( + ArrayRef(reinterpret_cast(tensorAttr.getData().data()), + tensorAttr.getData().size() / sizeof(T))); // Handle splats if (data.size() == 1) { @@ -445,11 +445,17 @@ SpecialResult convertDenseArray(StringAttr vhloName, Attribute vhloAttr, data.resize(size, data[0]); } - stablehloAttrs.emplace_back( - vhloName, DenseI64ArrayAttr::get(vhloAttr.getContext(), data)); + stablehloAttrs.emplace_back(vhloName, Attr::get(vhloAttr.getContext(), data)); return specialSuccess(); } +SpecialResult convertDenseI64Array( + StringAttr vhloName, Attribute vhloAttr, + SmallVector& stablehloAttrs) { + return convertDenseArray(vhloName, vhloAttr, + stablehloAttrs); +} + template SpecialResult convertSpecial(const OpConversionPattern& pattern, StringAttr vhloName, Attribute vhloAttr, @@ -493,44 +499,44 @@ SpecialResult convertSpecial(const OpConversionPattern& pattern, } if constexpr (std::is_same::value) { if (vhloName == "fft_length") - return convertDenseArray(vhloName, vhloAttr, stablehloAttrs); + return convertDenseI64Array(vhloName, vhloAttr, stablehloAttrs); } if constexpr (std::is_same::value) { if (vhloName == "broadcast_sizes") - return convertDenseArray(vhloName, vhloAttr, stablehloAttrs); + return convertDenseI64Array(vhloName, vhloAttr, stablehloAttrs); } if constexpr (std::is_same::value) { if (vhloName == "slice_sizes") - return convertDenseArray(vhloName, vhloAttr, stablehloAttrs); + return convertDenseI64Array(vhloName, vhloAttr, stablehloAttrs); } if constexpr (std::is_same::value) { if (vhloName == "dimensions") - return convertDenseArray(vhloName, vhloAttr, stablehloAttrs); + return convertDenseI64Array(vhloName, vhloAttr, stablehloAttrs); } if constexpr (std::is_same::value) { if (vhloName == "permutation") - return convertDenseArray(vhloName, vhloAttr, stablehloAttrs); + return convertDenseI64Array(vhloName, vhloAttr, stablehloAttrs); } if constexpr (std::is_same::value) { if (vhloName == "edge_padding_low" || vhloName == "edge_padding_high" || vhloName == "interior_padding") - return convertDenseArray(vhloName, vhloAttr, stablehloAttrs); + return convertDenseI64Array(vhloName, vhloAttr, stablehloAttrs); } if constexpr (std::is_same::value) { if (vhloName == "start_indices" || vhloName == "limit_indices" || vhloName == "strides") - return convertDenseArray(vhloName, vhloAttr, stablehloAttrs); + return convertDenseI64Array(vhloName, vhloAttr, stablehloAttrs); } if constexpr (std::is_same::value) { if (vhloName == "broadcast_dimensions") - return convertDenseArray(vhloName, vhloAttr, stablehloAttrs); + return convertDenseI64Array(vhloName, vhloAttr, stablehloAttrs); } if constexpr (std::is_same::value) { if (vhloName == "broadcast_dimensions" || vhloName == "known_expanding_dimensions" || vhloName == "known_nonexpanding_dimensions") - return convertDenseArray(vhloName, vhloAttr, stablehloAttrs); + return convertDenseI64Array(vhloName, vhloAttr, stablehloAttrs); } return notSpecial(); }