Skip to content

Commit

Permalink
Port {Convolution,DynamicConv,Gather}Op to dense elements or array at…
Browse files Browse the repository at this point in the history
…trs (#1905)

In the same vein as #1893.

Also introduce an attr that can be either dense elements of bool or dense array of bool.

Not porting the MLIR source files yet.

#1578
  • Loading branch information
mlevesquedion authored Jan 8, 2024
1 parent a2d2624 commit c472822
Show file tree
Hide file tree
Showing 14 changed files with 214 additions and 155 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1687,7 +1687,7 @@ struct GatherConversion final : OpConversionPattern<mlir::stablehlo::GatherOp> {
int64_t resultRank = resultType.getRank();
// slice_sizes has to have the same size as operand.rank, and doing it this
// way permits an unranked operand.
int64_t operandRank = gatherOp.getSliceSizes().getNumElements();
int64_t operandRank = gatherOp.getSliceSizes().size();

int64_t indexVectorDim = gatherOp.getDimensionNumbers().getIndexVectorDim();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,22 @@ limitations under the License.
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/conversions/linalg/transforms/LegalizeToLinalgUtils.h"
#include "stablehlo/conversions/linalg/transforms/Rewriters.h"
#include "stablehlo/dialect/Base.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace mlir::stablehlo {
namespace {
/// Apply dilation and padding to the input of a convolution.
Value applyConvolutionPadding(Location loc, Value input,
DenseIntElementsAttr padding,
DenseIntElementsAttr lhsDilation,
Attribute lhsDilation,
llvm::ArrayRef<int64_t> dimMappings,
OpBuilder &rewriter) {
if ((!padding || isSplatValue(padding, 0)) &&
(!lhsDilation || isSplatValue(lhsDilation, 1))) {
return input;
}
SmallVector<int64_t> lhsDilationValues;
if (lhsDilation) lhsDilationValues = hlo::getI64Array(lhsDilation);
bool noPadding = !padding || isSplatValue(padding, 0);
bool noDilation = !lhsDilation || hlo::isSplatArray(lhsDilationValues, 1);
if (noPadding && noDilation) return input;

auto inputType = cast<ShapedType>(input.getType());
int64_t rank = inputType.getRank();
Expand All @@ -58,10 +60,10 @@ Value applyConvolutionPadding(Location loc, Value input,
// Translate input dilation into interior padding.
SmallVector<int64_t, 8> padInterior(rank, 0);
if (lhsDilation) {
assert(rank == lhsDilation.size() + 2);
for (int64_t i : llvm::seq<int64_t>(0, lhsDilation.size())) {
assert(rank == static_cast<int64_t>(lhsDilationValues.size()) + 2);
for (int64_t i : llvm::seq<int64_t>(0, lhsDilationValues.size())) {
int64_t dim = dimMappings[i];
padInterior[dim] = lhsDilation.getValues<int64_t>()[i] - 1;
padInterior[dim] = lhsDilationValues[i] - 1;
}
}

Expand Down Expand Up @@ -91,8 +93,7 @@ Value applyConvolutionReversal(Location loc, OpBuilder &b,
return filter;
}
llvm::SmallVector<int64_t> reversedDims;
for (auto [idx, reversed] :
llvm::enumerate(reversals.value().getValues<bool>())) {
for (auto [idx, reversed] : llvm::enumerate(reversals.value())) {
if (reversed) {
reversedDims.push_back(
op.getDimensionNumbers().getKernelSpatialDimensions()[idx]);
Expand Down Expand Up @@ -219,8 +220,10 @@ struct NormalConvolutionOpConversion final
loc, resultType.getShape(), resultType.getElementType(), dynSizes);
Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor);
linalg::LinalgOp res;
Attribute strides = op.getWindowStridesAttr();
Attribute dilations = op.getRhsDilationAttr();
Attribute strides;
if (auto s = op.getWindowStrides()) strides = rewriter.getI64TensorAttr(*s);
Attribute dilations;
if (auto d = op.getRhsDilation()) dilations = rewriter.getI64TensorAttr(*d);

// Apply padding and input dilation.
llvm::SmallVector<int64_t> spatialDimMapping(rank - 2);
Expand Down Expand Up @@ -512,7 +515,7 @@ struct ConvolutionOpGeneralConversion final

AffineExpr stride = dim0;
if (op.getWindowStrides().has_value())
stride = stride * op.getWindowStrides().value().getValues<int64_t>()[i];
stride = stride * op.getWindowStrides().value()[i];
AffineExpr srcExpr = stride + dim1;

srcExprs[lhsIndexMapping[inputSpatialDimensions[i]]] = srcExpr;
Expand Down Expand Up @@ -599,7 +602,7 @@ struct DepthwiseConvolutionOpConversion final

Attribute windowStrides;
if (op.getWindowStrides()) {
windowStrides = op.getWindowStrides().value();
windowStrides = rewriter.getI64TensorAttr(op.getWindowStrides().value());
} else {
windowStrides = SplatElementsAttr::get(
VectorType::get({spatialRank}, rewriter.getI64Type()),
Expand All @@ -608,7 +611,7 @@ struct DepthwiseConvolutionOpConversion final

Attribute rhsDilation;
if (op.getRhsDilation()) {
rhsDilation = op.getRhsDilation().value();
rhsDilation = rewriter.getI64TensorAttr(op.getRhsDilation().value());
} else {
rhsDilation = SplatElementsAttr::get(
VectorType::get({spatialRank}, rewriter.getI64Type()),
Expand Down
15 changes: 14 additions & 1 deletion stablehlo/dialect/Base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -612,10 +612,23 @@ SmallVector<int64_t> getI64Array(Attribute attr) {
if (auto array = attr.dyn_cast<DenseI64ArrayAttr>())
return llvm::to_vector(array.asArrayRef());
llvm::report_fatal_error(
"called i64ArrayOrElementsValues on Attribute that was neither a "
"called getI64Array on Attribute that was neither a "
"DenseIntElementsAttr or a DenseI64ArrayAttr",
false);
}

SmallVector<bool> getBoolArray(Attribute attr) {
if (!attr) return {};
if (auto elements = attr.dyn_cast<DenseIntOrFPElementsAttr>())
return llvm::to_vector(elements.getValues<bool>());
if (auto array = attr.dyn_cast<DenseBoolArrayAttr>()) {
return SmallVector<bool>(array.asArrayRef());
}
llvm::report_fatal_error(
"called getBoolArray on Attribute that was neither a "
"DenseIntOrFPElementsAttr or a DenseBoolArrayAttr",
false);
}

} // namespace hlo
} // namespace mlir
7 changes: 7 additions & 0 deletions stablehlo/dialect/Base.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ bool isSplatArray(ArrayRef<int64_t> arr, int64_t val);
// have been removed.
SmallVector<int64_t> getI64Array(Attribute);

// Returns a vector of the bool values in a BoolDenseArrayOrElementsAttr.
// Such an Attr can be backed by either a DenseIntOrFPElementsAttr or
// a DenseBoolArrayAttr.
// TODO(#1578): Remove this code once all uses of BoolDenseArrayOrElementsAttr
// have been removed.
SmallVector<bool> getBoolArray(Attribute);

// Verifies that the two types have compatible shape with bounds but allows
// different element types.
LogicalResult verifyCompatibleShapeWithBounds(Type type1, Type type2);
Expand Down
36 changes: 21 additions & 15 deletions stablehlo/dialect/StablehloAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -160,17 +160,6 @@ def StableHLO_FlatSymbolRefArrayAttr :
let constBuilderCall = "::mlir::ArrayAttr::get($_builder.getContext(), $0)";
}

def StableHLO_BoolElementsAttr :
ElementsAttrBase<
And<[CPred<"$_self.isa<::mlir::DenseIntOrFPElementsAttr>()">,
CPred<"$_self.cast<::mlir::DenseIntOrFPElementsAttr>().getType().getElementType().isInteger(1)">]>,
"constant boolean vector/tensor attribute"> {
let storageType = [{ ::mlir::DenseElementsAttr }];
let returnType = [{ ::mlir::DenseElementsAttr }];

let convertFromStorage = "$_self";
}

def StableHLO_ConvDimensionNumbers : AttrDef<StableHLO_Dialect, "ConvDimensionNumbers"> {
let mnemonic = "conv";
let summary = "Structure of dimension information for conv op";
Expand All @@ -190,18 +179,35 @@ def StableHLO_ConvDimensionNumbers : AttrDef<StableHLO_Dialect, "ConvDimensionNu
let hasCustomAssemblyFormat = 1;
}

def StableHLO_BoolElementsAttr :
ElementsAttrBase<
And<[CPred<"$_self.isa<::mlir::DenseIntOrFPElementsAttr>()">,
CPred<"$_self.cast<::mlir::DenseIntOrFPElementsAttr>().getType().getElementType().isInteger(1)">]>,
"constant boolean vector/tensor attribute"> {
let storageType = [{ ::mlir::DenseElementsAttr }];
let returnType = [{ ::mlir::DenseElementsAttr }];

let convertFromStorage = "$_self";
}

def BoolDenseArrayOrElementsAttr : Attr<Or<[DenseBoolArrayAttr.predicate, StableHLO_BoolElementsAttr.predicate]>, "either a DenseBoolArrayAttr or a StableHLO_BoolElementsAttr"> {
let storageType = "Attribute";
let returnType = "SmallVector<bool>";
let convertFromStorage = "hlo::getBoolArray($_self)";
}

def StableHLO_ConvolutionAttributes {
dag attributes = (ins
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$window_strides,
OptionalAttr<I64DenseArrayOrElements1DAttr>:$window_strides,
// Default value: two zeros for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$padding,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
OptionalAttr<I64DenseArrayOrElements1DAttr>:$lhs_dilation,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
OptionalAttr<I64DenseArrayOrElements1DAttr>:$rhs_dilation,
// Default value: false for each of the spatial dimension.
OptionalAttr<StableHLO_BoolElementsAttr>:$window_reversal,
OptionalAttr<BoolDenseArrayOrElementsAttr>:$window_reversal,
StableHLO_ConvDimensionNumbers:$dimension_numbers,
I64Attr:$feature_group_count,
I64Attr:$batch_group_count,
Expand Down
79 changes: 37 additions & 42 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ limitations under the License.
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/InliningUtils.h"
#include "stablehlo/dialect/AssemblyFormat.h"
#include "stablehlo/dialect/Base.h"
#include "stablehlo/dialect/StablehloBytecode.h"
#include "stablehlo/dialect/StablehloOps.h.inc"
#include "stablehlo/dialect/TypeInference.h"
Expand Down Expand Up @@ -613,7 +614,7 @@ namespace {
void getSliceSizeValues(GatherOp* gather, OpBuilder& builder, Location loc,
ValueRange operands,
SmallVectorImpl<Value>& sliceSizes) {
for (int64_t val : gather->getSliceSizes().getValues<int64_t>())
for (int64_t val : gather->getSliceSizes())
sliceSizes.push_back(builder.create<arith::ConstantIndexOp>(loc, val));
}

Expand Down Expand Up @@ -3090,42 +3091,30 @@ Attribute ConvDimensionNumbersAttr::parse(AsmParser& parser, Type type) {

namespace {
// Custom formatting for convolution window attributes.
void printWindowAttribute(OpAsmPrinter& p, DenseElementsAttr attribute) {
if (attribute.getElementType().isInteger(/*width=*/1)) {
// boolean attribute.
llvm::interleaveComma(attribute.getValues<bool>(), p,
[&](bool b) { p << (b ? 1 : 0); });
return;
}
if (attribute.getType().getRank() == 2) {
// Padding is Nx2 attribute.
auto it = attribute.value_begin<int64_t>();
std::vector<std::pair<int64_t, int64_t>> values(attribute.getNumElements() /
2);
for (auto& item : values) {
int64_t first = *it;
++it;
int64_t second = *it;
++it;
item = {first, second};
}
llvm::interleaveComma(
values, p, [&](const std::pair<int64_t, int64_t> pair) {
p << '[' << pair.first << ", " << pair.second << ']';
});
} else {
llvm::interleaveComma(attribute.getValues<int64_t>(), p);
void printWindowPadding(OpAsmPrinter& p, DenseElementsAttr padding) {
// Padding is Nx2 attribute.
auto it = padding.value_begin<int64_t>();
std::vector<std::pair<int64_t, int64_t>> values(padding.getNumElements() / 2);
for (auto& item : values) {
int64_t first = *it;
++it;
int64_t second = *it;
++it;
item = {first, second};
}
llvm::interleaveComma(values, p, [&](const std::pair<int64_t, int64_t> pair) {
p << '[' << pair.first << ", " << pair.second << ']';
});
}
} // namespace

void printWindowAttributes(OpAsmPrinter& p, Operation* /*op*/,
std::optional<DenseIntElementsAttr> windowStrides,
std::optional<Attribute> windowStrides,
std::optional<DenseIntElementsAttr> padding,
std::optional<DenseIntElementsAttr> lhsDilation,
std::optional<DenseIntElementsAttr> rhsDilation,
std::optional<DenseElementsAttr> windowReversal) {
using pair_t = std::pair<DenseElementsAttr, StringRef>;
std::optional<Attribute> lhsDilation,
std::optional<Attribute> rhsDilation,
std::optional<Attribute> windowReversal) {
using pair_t = std::pair<Attribute, StringRef>;
std::array<pair_t, 5> printedAttributes = {{
{windowStrides ? *windowStrides : nullptr, "stride"},
{padding ? *padding : nullptr, "pad"},
Expand All @@ -3139,19 +3128,26 @@ void printWindowAttributes(OpAsmPrinter& p, Operation* /*op*/,
printedAttributes,
[](const pair_t& a) { return static_cast<bool>(a.first); });

llvm::interleaveComma(nonNullAttributes, p, [&](const pair_t& a) {
p << a.second << " = [";
printWindowAttribute(p, a.first);
p << "]";
llvm::interleaveComma(nonNullAttributes, p, [&](const pair_t& attr) {
p << attr.second << " = [";

if (attr.second == "pad") {
printWindowPadding(p, attr.first.dyn_cast<DenseIntElementsAttr>());
} else if (attr.second == "reverse") {
llvm::interleaveComma(hlo::getBoolArray(attr.first), p);
} else {
llvm::interleaveComma(hlo::getI64Array(attr.first), p);
}

p << ']';
});
}

ParseResult parseWindowAttributes(OpAsmParser& parser,
DenseIntElementsAttr& windowStrides,
ParseResult parseWindowAttributes(OpAsmParser& parser, Attribute& windowStrides,
DenseIntElementsAttr& padding,
DenseIntElementsAttr& lhsDilation,
DenseIntElementsAttr& rhsDilation,
DenseElementsAttr& windowReversal) {
Attribute& lhsDilation,
Attribute& rhsDilation,
Attribute& windowReversal) {
StringRef attributeName;

llvm::StringSet<> allowedAttributeNames{
Expand Down Expand Up @@ -3205,9 +3201,8 @@ ParseResult parseWindowAttributes(OpAsmParser& parser,
if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
int64Parser))
return failure();
const int64_t size = static_cast<int64_t>(values.size());
if (attributeName == "reverse") {
auto ty = RankedTensorType::get({size},
auto ty = RankedTensorType::get({static_cast<int64_t>(values.size())},
parser.getBuilder().getIntegerType(1));
auto boolVector = llvm::to_vector<4>(
llvm::map_range(values, [](int64_t v) { return v != 0; }));
Expand Down
17 changes: 8 additions & 9 deletions stablehlo/dialect/StablehloOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,17 @@ ParseResult parseConvolutionDimensions(AsmParser &parser,

// Custom formatting for convolution window attributes.
void printWindowAttributes(OpAsmPrinter &p, Operation *op,
std::optional<DenseIntElementsAttr> windowStrides,
std::optional<Attribute> windowStrides,
std::optional<DenseIntElementsAttr> padding,
std::optional<DenseIntElementsAttr> lhsDilation,
std::optional<DenseIntElementsAttr> rhsDilation,
std::optional<DenseElementsAttr> windowReversal);
std::optional<Attribute> lhsDilation,
std::optional<Attribute> rhsDilation,
std::optional<Attribute> windowReversal);

ParseResult parseWindowAttributes(OpAsmParser &parser,
DenseIntElementsAttr &windowStrides,
ParseResult parseWindowAttributes(OpAsmParser &parser, Attribute &windowStrides,
DenseIntElementsAttr &padding,
DenseIntElementsAttr &lhsDilation,
DenseIntElementsAttr &rhsDilation,
DenseElementsAttr &windowReversal);
Attribute &lhsDilation,
Attribute &rhsDilation,
Attribute &windowReversal);

} // end namespace stablehlo
} // end namespace mlir
Expand Down
15 changes: 7 additions & 8 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2102,11 +2102,11 @@ def StableHLO_ConvolutionOp : StableHLO_Op<"convolution", [Pure]> {
Example:
```mlir
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = dense<4> : tensor<2xi64>,
window_strides = array<i64: 4, 4>,
padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = dense<2> : tensor<2xi64>,
rhs_dilation = dense<1> : tensor<2xi64>,
window_reversal = dense<false> : tensor<2xi1>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
Expand All @@ -2126,8 +2126,7 @@ def StableHLO_ConvolutionOp : StableHLO_Op<"convolution", [Pure]> {
let extraClassDeclaration = [{
bool hasWindowReversal() {
auto reversal = getWindowReversalAttr();
return reversal && llvm::any_of(reversal.getValues<bool>(),
[](bool v) { return v; });
return reversal && llvm::any_of(hlo::getBoolArray(reversal), [](bool v) { return v; });
}
}];

Expand Down Expand Up @@ -2387,7 +2386,7 @@ def StableHLO_GatherOp: StableHLO_Op<"gather", [InferTensorTypeWithReify /*gathe
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>,
slice_sizes = array<i64: 1, 2, 2>,
indices_are_sorted = false
} : (tensor<3x4x2xi32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi32>
```
Expand All @@ -2397,7 +2396,7 @@ def StableHLO_GatherOp: StableHLO_Op<"gather", [InferTensorTypeWithReify /*gathe
HLO_Tensor:$operand /*gather_i1*/,
HLO_IntTensor:$start_indices /*gather_i2*/,
StableHLO_GatherDimensionNumbers:$dimension_numbers /*gather_i3, gather_i4, gather_i5, gather_i6*/,
I64ElementsAttr:$slice_sizes /*gather_i7*/,
I64DenseArrayOrElements1DAttr:$slice_sizes /*gather_i7*/,
DefaultValuedOptionalAttr<BoolAttr, "false">:$indices_are_sorted /*gather_i8*/
);

Expand Down
Loading

0 comments on commit c472822

Please sign in to comment.