Skip to content

Commit

Permalink
Port ReduceWindowOp to I64DenseArrayOrElements1DAttr (#1903)
Browse files Browse the repository at this point in the history
In the same vein as #1893.

#1578
  • Loading branch information
mlevesquedion authored Jan 6, 2024
1 parent 98eeea5 commit bc57138
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 81 deletions.
31 changes: 15 additions & 16 deletions stablehlo/conversions/linalg/transforms/StablehloToLinalgReduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,8 +337,7 @@ struct ReduceWindowOpOnTensorsGenericConversion final
return failure();
auto numOperands = initValues.size();

llvm::SmallVector<int64_t> windowDimensions =
extract1DVector(op.getWindowDimensions());
llvm::SmallVector<int64_t> windowDimensions(op.getWindowDimensions());

llvm::SmallVector<int64_t> padding;
if (op.getPadding()) {
Expand All @@ -347,17 +346,17 @@ struct ReduceWindowOpOnTensorsGenericConversion final

llvm::SmallVector<int64_t> baseDilations;
if (op.getBaseDilations()) {
baseDilations = extract1DVector(*op.getBaseDilations());
baseDilations = *op.getBaseDilations();
}

llvm::SmallVector<int64_t> windowStrides(windowDimensions.size(), 1);
if (op.getWindowStrides()) {
windowStrides = extract1DVector(*op.getWindowStrides());
windowStrides = *op.getWindowStrides();
}

llvm::SmallVector<int64_t> windowDilations(windowDimensions.size(), 1);
if (op.getWindowDilations()) {
windowDilations = extract1DVector(*op.getWindowDilations());
windowDilations = *op.getWindowDilations();
}

auto rank = static_cast<int64_t>(windowDimensions.size());
Expand Down Expand Up @@ -556,26 +555,26 @@ struct ReduceWindowOpConversion final
return rewriter.notifyMatchFailure(op, "require paddings are all zero");
}

if (op.getBaseDilations() && !isSplatValue(*op.getBaseDilations(), 1)) {
return rewriter.notifyMatchFailure(op, "expected undilated base");
if (auto bd = op.getBaseDilations()) {
if (!hlo::isSplatArray(*bd, 1)) {
return rewriter.notifyMatchFailure(op, "expected undilated base");
}
}

int lastDim = rank - 1;
SmallVector<int64_t, 2> fakeWindowShapes;
for (int i = 1; i < lastDim; ++i) {
fakeWindowShapes.push_back(
op.getWindowDimensions().getValues<int64_t>()[i]);
fakeWindowShapes.push_back(op.getWindowDimensions()[i]);
}

if (op.getWindowStrides() &&
(op.getWindowStrides().value().getValues<int64_t>()[0] != 1 ||
op.getWindowStrides().value().getValues<int64_t>()[lastDim] != 1)) {
(op.getWindowStrides().value()[0] != 1 ||
op.getWindowStrides().value()[lastDim] != 1)) {
return rewriter.notifyMatchFailure(
op, "expected window_strides to be [1,x,y,(z),1]");
}
if (op.getWindowDimensions() &&
(op.getWindowDimensions().getValues<int64_t>()[0] != 1 ||
op.getWindowDimensions().getValues<int64_t>()[lastDim] != 1)) {
if (op.getWindowDimensions()[0] != 1 ||
op.getWindowDimensions()[lastDim] != 1) {
return rewriter.notifyMatchFailure(
op, "expected window_dimensions to be [1,x,y,(z),1]");
}
Expand All @@ -584,7 +583,7 @@ struct ReduceWindowOpConversion final
SmallVector<int64_t> vec;
if (op.getWindowStridesAttr()) {
for (int i = 1; i < lastDim; ++i) {
vec.push_back(op.getWindowStrides().value().getValues<int64_t>()[i]);
vec.push_back(op.getWindowStrides().value()[i]);
}
} else {
vec.assign(rank - 2, 1);
Expand All @@ -595,7 +594,7 @@ struct ReduceWindowOpConversion final
vec.clear();
if (op.getWindowDilations()) {
for (int i = 1; i < lastDim; ++i) {
vec.push_back(op.getWindowDilations().value().getValues<int64_t>()[i]);
vec.push_back(op.getWindowDilations().value()[i]);
}
} else {
vec.assign(rank - 2, 1);
Expand Down
7 changes: 5 additions & 2 deletions stablehlo/dialect/Base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -600,8 +600,11 @@ ShapedType createShapedType(ShapedTypeComponents components) {
return UnrankedTensorType::get(components.getElementType());
}

// TODO(#1578): Remove this code once all uses of I64DenseArrayOrElements1DAttr
// have been removed.
bool isSplatArray(ArrayRef<int64_t> arr, int64_t val) {
return std::all_of(arr.begin(), arr.end(),
[val](int64_t x) { return x == val; });
}

SmallVector<int64_t> getI64Array(Attribute attr) {
if (!attr) return {};
if (auto elements = attr.dyn_cast<DenseIntElementsAttr>())
Expand Down
13 changes: 13 additions & 0 deletions stablehlo/dialect/Base.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,19 @@ inline static bool isStaticDimSize(int64_t val) {
return !isDynamicDimSize(val);
}

// Checks whether every position in the given array contains the given value.
// This is especially useful for dealing with instances of
// I64DenseArrayOrElements1DAttr, which returns a SmallVector<int64_t> as its
// value no matter what actual attribute is backing it.
// TODO(#1578): Remove this code once all uses of I64DenseArrayOrElements1DAttr
// have been removed.
bool isSplatArray(ArrayRef<int64_t> arr, int64_t val);

// Returns a vector of the int64 values in a I64DenseArrayOrElements1DAttr.
// Such an Attr can be backed by either a 1-dimensional DenseIntElementsAttr or
// a DenseI64ArrayAttr.
// TODO(#1578): Remove this code once all uses of I64DenseArrayOrElements1DAttr
// have been removed.
SmallVector<int64_t> getI64Array(Attribute);

// Verifies that the two types have compatible shape with bounds but allows
Expand Down
18 changes: 9 additions & 9 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2870,10 +2870,10 @@ def StableHLO_ReduceWindowOp: StableHLO_Op<"reduce_window", [
%0 = stablehlo.add %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
window_dimensions = dense<[2, 1]> : tensor<2xi64>,
window_strides = dense<[4, 1]> : tensor<2xi64>,
base_dilations = dense<[2, 1]> : tensor<2xi64>,
window_dilations = dense<[3, 1]> : tensor<2xi64>,
window_dimensions = array<i64: 2, 1>,
window_strides = array<i64: 4, 1>,
base_dilations = array<i64: 2, 1>,
window_dilations = array<i64: 3, 1>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
```
Expand All @@ -2882,13 +2882,13 @@ def StableHLO_ReduceWindowOp: StableHLO_Op<"reduce_window", [
let arguments = (ins
Variadic<HLO_Tensor>:$inputs /*reduce_window_i1*/,
Variadic<HLO_Tensor>:$init_values /*reduce_window_i2*/,
I64ElementsAttr:$window_dimensions /*reduce_window_i3*/,
I64DenseArrayOrElements1DAttr:$window_dimensions /*reduce_window_i3*/,
// If strides or dilations attributes are missing then the default value is
// one for each of the operand dimensions. Similarly, padding values are zero
// for both low and high in each of the dimensions, if not specified.
OptionalAttr<I64ElementsAttr>:$window_strides /*reduce_window_i4*/,
OptionalAttr<I64ElementsAttr>:$base_dilations /*reduce_window_i5*/,
OptionalAttr<I64ElementsAttr>:$window_dilations /*reduce_window_i6*/,
OptionalAttr<I64DenseArrayOrElements1DAttr>:$window_strides /*reduce_window_i4*/,
OptionalAttr<I64DenseArrayOrElements1DAttr>:$base_dilations /*reduce_window_i5*/,
OptionalAttr<I64DenseArrayOrElements1DAttr>:$window_dilations /*reduce_window_i6*/,
OptionalAttr<I64ElementsAttr>:$padding /*reduce_window_i7*/
);

Expand All @@ -2898,6 +2898,7 @@ def StableHLO_ReduceWindowOp: StableHLO_Op<"reduce_window", [

let hasVerifier = 1;


// Builder for non-variadic version of the operation.
let builders = [
OpBuilder<(ins "Type":$result_type, "Value":$operand,
Expand All @@ -2922,7 +2923,6 @@ def StableHLO_ReduceWindowOp: StableHLO_Op<"reduce_window", [
"function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder
)>,
];

// TODO(hinsu): Implement custom printer and parser.
}

Expand Down
56 changes: 20 additions & 36 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -715,11 +715,10 @@ LogicalResult verifyReducerShape(std::optional<Location> loc, Block& block,

LogicalResult verifyReduceWindowOpInputsAndInferWindow(
std::optional<Location> location, SmallVector<ShapedType> inputTypes,
SmallVector<ShapedType> initValueTypes,
DenseIntElementsAttr windowDimensions,
std::optional<DenseIntElementsAttr> windowStrides,
std::optional<DenseIntElementsAttr> baseDilations,
std::optional<DenseIntElementsAttr> windowDilations,
SmallVector<ShapedType> initValueTypes, ArrayRef<int64_t> windowDimensions,
std::optional<ArrayRef<int64_t>> windowStrides,
std::optional<ArrayRef<int64_t>> baseDilations,
std::optional<ArrayRef<int64_t>> windowDilations,
std::optional<DenseIntElementsAttr> padding,
SmallVector<int64_t>& windowDims,
SmallVector<WindowDimension>& inferredWindow) {
Expand Down Expand Up @@ -749,45 +748,30 @@ LogicalResult verifyReduceWindowOpInputsAndInferWindow(
" is not compatible with shape at input-index ", rankedInputIdx);
}

// reduce_window_i3
auto windowDimsOrErr =
convert1DAttribute(windowDimensions, location, "window_dimensions");
if (failed(windowDimsOrErr)) return failure();
// reduce_window_i4
auto windowStridesOrErr =
convert1DAttribute(windowStrides, location, "window_strides");
if (failed(windowStridesOrErr)) return failure();
// reduce_window_i5
auto baseDilationsOrErr =
convert1DAttribute(baseDilations, location, "base_dilations");
if (failed(baseDilationsOrErr)) return failure();
// reduce_window_i6
auto windowDilationsOrErr =
convert1DAttribute(windowDilations, location, "window_dilations");
if (failed(windowDilationsOrErr)) return failure();
// reduce_window_c12, reduce_window_i7
auto paddingOrErr = convertPaddingAttribute(padding, location);
if (failed(paddingOrErr)) return failure();

// reduce_window_c4
for (const auto inputType : inputTypes) {
if (!inputType.hasRank()) continue;
if (inputType.getRank() != static_cast<int64_t>((*windowDimsOrErr).size()))
if (inputType.getRank() != static_cast<int64_t>(windowDimensions.size()))
return emitOptionalError(
location, "expects window-dimensions size == input rank, but got ",
"window-dimensions size: ", (*windowDimsOrErr).size(),
"window-dimensions size: ", windowDimensions.size(),
" and input: ", inputType, " with rank = ", inputType.getRank(), ".");
}

// reduce_window_c5...reduce_window_c12
auto windowOrErr = verifyWindowAttributesAndInferWindowDimensions(
*windowDimsOrErr, *windowStridesOrErr, *paddingOrErr,
/*lhsDilation=*/*baseDilationsOrErr,
/*rhsDilation=*/*windowDilationsOrErr, /*windowReversal=*/std::nullopt,
location);
windowDimensions, windowStrides.value_or(SmallVector<int64_t, 0>{}),
*paddingOrErr,
/*lhsDilation=*/baseDilations.value_or(SmallVector<int64_t, 0>{}),
/*rhsDilation=*/windowDilations.value_or(SmallVector<int64_t, 0>{}),
/*windowReversal=*/std::nullopt, location);
if (failed(windowOrErr)) return failure();

windowDims.append(*windowDimsOrErr);
windowDims.append(windowDimensions.begin(), windowDimensions.end());
inferredWindow.append(*windowOrErr);
return success();
}
Expand Down Expand Up @@ -2624,10 +2608,10 @@ LogicalResult inferReduceOp(

LogicalResult inferReduceWindowOp(
std::optional<Location> location, ValueRange inputs, ValueRange initValues,
DenseIntElementsAttr windowDimensions,
std::optional<DenseIntElementsAttr> windowStrides,
std::optional<DenseIntElementsAttr> baseDilations,
std::optional<DenseIntElementsAttr> windowDilations,
ArrayRef<int64_t> windowDimensions,
std::optional<ArrayRef<int64_t>> windowStrides,
std::optional<ArrayRef<int64_t>> baseDilations,
std::optional<ArrayRef<int64_t>> windowDilations,
std::optional<DenseIntElementsAttr> padding, Region& body,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
SmallVector<ShapedType> inputTypes{llvm::map_range(
Expand Down Expand Up @@ -3913,10 +3897,10 @@ LogicalResult verifyReduceScatterOp(std::optional<Location> location,

LogicalResult verifyReduceWindowOp(
std::optional<Location> location, ValueRange inputs, ValueRange initValues,
DenseIntElementsAttr windowDimensions,
std::optional<DenseIntElementsAttr> windowStrides,
std::optional<DenseIntElementsAttr> baseDilations,
std::optional<DenseIntElementsAttr> windowDilations,
ArrayRef<int64_t> windowDimensions,
std::optional<ArrayRef<int64_t>> windowStrides,
std::optional<ArrayRef<int64_t>> baseDilations,
std::optional<ArrayRef<int64_t>> windowDilations,
std::optional<DenseIntElementsAttr> padding, Region& body) {
SmallVector<ShapedType> inputTypes{llvm::map_range(
inputs.getTypes(), [](Type t) { return t.cast<ShapedType>(); })};
Expand Down
16 changes: 8 additions & 8 deletions stablehlo/dialect/TypeInference.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,10 +290,10 @@ LogicalResult inferReduceOp(

LogicalResult inferReduceWindowOp(
std::optional<Location> location, ValueRange inputs, ValueRange initValues,
DenseIntElementsAttr windowDimensions,
std::optional<DenseIntElementsAttr> windowStrides,
std::optional<DenseIntElementsAttr> baseDilations,
std::optional<DenseIntElementsAttr> windowDilations,
ArrayRef<int64_t> windowDimensions,
std::optional<ArrayRef<int64_t>> windowStrides,
std::optional<ArrayRef<int64_t>> baseDilations,
std::optional<ArrayRef<int64_t>> windowDilations,
std::optional<DenseIntElementsAttr> padding, Region& body,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes);

Expand Down Expand Up @@ -477,10 +477,10 @@ LogicalResult verifyReduceScatterOp(std::optional<Location> location,

LogicalResult verifyReduceWindowOp(
std::optional<Location> location, ValueRange inputs, ValueRange initValues,
DenseIntElementsAttr windowDimensions,
std::optional<DenseIntElementsAttr> windowStrides,
std::optional<DenseIntElementsAttr> baseDilations,
std::optional<DenseIntElementsAttr> windowDilations,
ArrayRef<int64_t> windowDimensions,
std::optional<ArrayRef<int64_t>> windowStrides,
std::optional<ArrayRef<int64_t>> baseDilations,
std::optional<ArrayRef<int64_t>> windowDilations,
std::optional<DenseIntElementsAttr> padding, Region& body);

LogicalResult verifyReshapeOp(std::optional<Location> location, Value operand,
Expand Down
9 changes: 3 additions & 6 deletions stablehlo/reference/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -603,18 +603,15 @@ SmallVector<InterpreterValue> eval(Region &region,

Sizes windowStrides(rank, 1);
if (auto windowStridesAttr = reduceWindowOp.getWindowStrides())
windowStrides.assign(windowStridesAttr->value_begin<int64_t>(),
windowStridesAttr->value_end<int64_t>());
windowStrides = Sizes(*windowStridesAttr);

Sizes baseDilations(rank, 1);
if (auto baseDilationsAttr = reduceWindowOp.getBaseDilations())
baseDilations.assign(baseDilationsAttr->value_begin<int64_t>(),
baseDilationsAttr->value_end<int64_t>());
baseDilations = Sizes(*baseDilationsAttr);

Sizes windowDilations(rank, 1);
if (auto windowDilationsAttr = reduceWindowOp.getWindowDilations())
windowDilations.assign(windowDilationsAttr->value_begin<int64_t>(),
windowDilationsAttr->value_end<int64_t>());
windowDilations = Sizes(*windowDilationsAttr);

Sizes paddingLow(rank, 0), paddingHigh(rank, 0);
if (auto paddingAttr = reduceWindowOp.getPadding()) {
Expand Down
8 changes: 4 additions & 4 deletions stablehlo/tests/verify_reduce_window.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ func.func @reduce_window_i2(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>,
func.func @reduce_window_i3(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>,
%init0: tensor<f32>, %init1: tensor<i32>) ->
(tensor<2x2xf32>, tensor<2x2xi32>) {
// expected-error@+1 {{expects the shape of window_dimensions attribute to be 1-D, but got {1, 2}}}
// expected-error@+1 {{attribute 'window_dimensions' failed to satisfy constraint: either a DenseI64ArrayAttr or a 1-dimensional I64ElementsAttr.}}
%0:2 = "stablehlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({
^bb0(%a0: tensor<f32>, %a1: tensor<i32>,
%b0: tensor<f32>, %b1: tensor<i32>):
Expand All @@ -705,7 +705,7 @@ func.func @reduce_window_i3(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>,
func.func @reduce_window_i4(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>,
%init0: tensor<f32>, %init1: tensor<i32>) ->
(tensor<2x2xf32>, tensor<2x2xi32>) {
// expected-error@+1 {{expects the shape of window_strides attribute to be 1-D, but got {1, 2}}}
// expected-error@+1 {{attribute 'window_strides' failed to satisfy constraint: either a DenseI64ArrayAttr or a 1-dimensional I64ElementsAttr.}}
%0:2 = "stablehlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({
^bb0(%a0: tensor<f32>, %a1: tensor<i32>,
%b0: tensor<f32>, %b1: tensor<i32>):
Expand All @@ -726,7 +726,7 @@ func.func @reduce_window_i4(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>,
func.func @reduce_window_i5(%arg0: tensor<*xf32>,
%arg1: tensor<4x?xi32>, %init0: tensor<f32>, %init1: tensor<i32>) ->
(tensor<?x?xf32>, tensor<*xi32>) {
// expected-error@+1 {{expects the shape of base_dilations attribute to be 1-D, but got {1, 2}}}
// expected-error@+1 {{attribute 'base_dilations' failed to satisfy constraint: either a DenseI64ArrayAttr or a 1-dimensional I64ElementsAttr.}}
%0:2 = "stablehlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({
^bb0(%a0: tensor<f32>, %a1: tensor<i32>,
%b0: tensor<f32>, %b1: tensor<i32>):
Expand All @@ -749,7 +749,7 @@ func.func @reduce_window_i5(%arg0: tensor<*xf32>,
func.func @reduce_window_i6(%arg0: tensor<*xf32>,
%arg1: tensor<4x?xi32>, %init0: tensor<f32>, %init1: tensor<i32>) ->
(tensor<?x?xf32>, tensor<*xi32>) {
// expected-error@+1 {{expects the shape of window_dilations attribute to be 1-D, but got {1, 2}}}
// expected-error@+1 {{attribute 'window_dilations' failed to satisfy constraint: either a DenseI64ArrayAttr or a 1-dimensional I64ElementsAttr.}}
%0:2 = "stablehlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({
^bb0(%a0: tensor<f32>, %a1: tensor<i32>,
%b0: tensor<f32>, %b1: tensor<i32>):
Expand Down

0 comments on commit bc57138

Please sign in to comment.