From bc5713846727d01eacae92b6b5e808418919bd9e Mon Sep 17 00:00:00 2001 From: mlevesquedion Date: Fri, 5 Jan 2024 17:48:11 -0800 Subject: [PATCH] Port ReduceWindowOp to I64DenseArrayOrElements1DAttr (#1903) In the same vein as https://github.com/openxla/stablehlo/pull/1893. https://github.com/openxla/stablehlo/issues/1578 --- .../transforms/StablehloToLinalgReduce.cpp | 31 +++++----- stablehlo/dialect/Base.cpp | 7 ++- stablehlo/dialect/Base.h | 13 +++++ stablehlo/dialect/StablehloOps.td | 18 +++--- stablehlo/dialect/TypeInference.cpp | 56 +++++++------------ stablehlo/dialect/TypeInference.h | 16 +++--- stablehlo/reference/Ops.cpp | 9 +-- stablehlo/tests/verify_reduce_window.mlir | 8 +-- 8 files changed, 77 insertions(+), 81 deletions(-) diff --git a/stablehlo/conversions/linalg/transforms/StablehloToLinalgReduce.cpp b/stablehlo/conversions/linalg/transforms/StablehloToLinalgReduce.cpp index a6094f20274..835aa8722d3 100644 --- a/stablehlo/conversions/linalg/transforms/StablehloToLinalgReduce.cpp +++ b/stablehlo/conversions/linalg/transforms/StablehloToLinalgReduce.cpp @@ -337,8 +337,7 @@ struct ReduceWindowOpOnTensorsGenericConversion final return failure(); auto numOperands = initValues.size(); - llvm::SmallVector windowDimensions = - extract1DVector(op.getWindowDimensions()); + llvm::SmallVector windowDimensions(op.getWindowDimensions()); llvm::SmallVector padding; if (op.getPadding()) { @@ -347,17 +346,17 @@ struct ReduceWindowOpOnTensorsGenericConversion final llvm::SmallVector baseDilations; if (op.getBaseDilations()) { - baseDilations = extract1DVector(*op.getBaseDilations()); + baseDilations = *op.getBaseDilations(); } llvm::SmallVector windowStrides(windowDimensions.size(), 1); if (op.getWindowStrides()) { - windowStrides = extract1DVector(*op.getWindowStrides()); + windowStrides = *op.getWindowStrides(); } llvm::SmallVector windowDilations(windowDimensions.size(), 1); if (op.getWindowDilations()) { - windowDilations = extract1DVector(*op.getWindowDilations()); + windowDilations = *op.getWindowDilations(); } auto rank = static_cast(windowDimensions.size()); @@ -556,26 +555,26 @@ struct ReduceWindowOpConversion final return rewriter.notifyMatchFailure(op, "require paddings are all zero"); } - if (op.getBaseDilations() && !isSplatValue(*op.getBaseDilations(), 1)) { - return rewriter.notifyMatchFailure(op, "expected undilated base"); + if (auto bd = op.getBaseDilations()) { + if (!hlo::isSplatArray(*bd, 1)) { + return rewriter.notifyMatchFailure(op, "expected undilated base"); + } } int lastDim = rank - 1; SmallVector fakeWindowShapes; for (int i = 1; i < lastDim; ++i) { - fakeWindowShapes.push_back( - op.getWindowDimensions().getValues()[i]); + fakeWindowShapes.push_back(op.getWindowDimensions()[i]); } if (op.getWindowStrides() && - (op.getWindowStrides().value().getValues()[0] != 1 || - op.getWindowStrides().value().getValues()[lastDim] != 1)) { + (op.getWindowStrides().value()[0] != 1 || + op.getWindowStrides().value()[lastDim] != 1)) { return rewriter.notifyMatchFailure( op, "expected window_strides to be [1,x,y,(z),1]"); } - if (op.getWindowDimensions() && - (op.getWindowDimensions().getValues()[0] != 1 || - op.getWindowDimensions().getValues()[lastDim] != 1)) { + if (op.getWindowDimensions()[0] != 1 || + op.getWindowDimensions()[lastDim] != 1) { return rewriter.notifyMatchFailure( op, "expected window_dimensions to be [1,x,y,(z),1]"); } @@ -584,7 +583,7 @@ struct ReduceWindowOpConversion final SmallVector vec; if (op.getWindowStridesAttr()) { for (int i = 1; i < lastDim; ++i) { - vec.push_back(op.getWindowStrides().value().getValues()[i]); + vec.push_back(op.getWindowStrides().value()[i]); } } else { vec.assign(rank - 2, 1); @@ -595,7 +594,7 @@ struct ReduceWindowOpConversion final vec.clear(); if (op.getWindowDilations()) { for (int i = 1; i < lastDim; ++i) { - vec.push_back(op.getWindowDilations().value().getValues()[i]); + vec.push_back(op.getWindowDilations().value()[i]); } } else { vec.assign(rank - 2, 1); diff --git a/stablehlo/dialect/Base.cpp b/stablehlo/dialect/Base.cpp index c35bd8a0029..20bc88dbea1 100644 --- a/stablehlo/dialect/Base.cpp +++ b/stablehlo/dialect/Base.cpp @@ -600,8 +600,11 @@ ShapedType createShapedType(ShapedTypeComponents components) { return UnrankedTensorType::get(components.getElementType()); } -// TODO(#1578): Remove this code once all uses of I64DenseArrayOrElements1DAttr -// have been removed. +bool isSplatArray(ArrayRef arr, int64_t val) { + return std::all_of(arr.begin(), arr.end(), + [val](int64_t x) { return x == val; }); +} + SmallVector getI64Array(Attribute attr) { if (!attr) return {}; if (auto elements = attr.dyn_cast()) diff --git a/stablehlo/dialect/Base.h b/stablehlo/dialect/Base.h index 6d641de08c9..f36b9f7e207 100644 --- a/stablehlo/dialect/Base.h +++ b/stablehlo/dialect/Base.h @@ -56,6 +56,19 @@ inline static bool isStaticDimSize(int64_t val) { return !isDynamicDimSize(val); } +// Checks whether every position in the given array contains the given value. +// This is especially useful for dealing with instances of +// I64DenseArrayOrElements1DAttr, which returns a SmallVector as its +// value no matter what actual attribute is backing it. +// TODO(#1578): Remove this code once all uses of I64DenseArrayOrElements1DAttr +// have been removed. +bool isSplatArray(ArrayRef arr, int64_t val); + +// Returns a vector of the int64 values in a I64DenseArrayOrElements1DAttr. +// Such an Attr can be backed by either a 1-dimensional DenseIntElementsAttr or +// a DenseI64ArrayAttr. +// TODO(#1578): Remove this code once all uses of I64DenseArrayOrElements1DAttr +// have been removed. SmallVector getI64Array(Attribute); // Verifies that the two types have compatible shape with bounds but allows diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index db80e6e40ac..6d70ba80161 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -2870,10 +2870,10 @@ def StableHLO_ReduceWindowOp: StableHLO_Op<"reduce_window", [ %0 = stablehlo.add %arg0, %arg1 : tensor stablehlo.return %0 : tensor }) { - window_dimensions = dense<[2, 1]> : tensor<2xi64>, - window_strides = dense<[4, 1]> : tensor<2xi64>, - base_dilations = dense<[2, 1]> : tensor<2xi64>, - window_dilations = dense<[3, 1]> : tensor<2xi64>, + window_dimensions = array, + window_strides = array, + base_dilations = array, + window_dilations = array, padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64> } : (tensor<3x2xi64>, tensor) -> tensor<2x2xi64> ``` @@ -2882,13 +2882,13 @@ def StableHLO_ReduceWindowOp: StableHLO_Op<"reduce_window", [ let arguments = (ins Variadic:$inputs /*reduce_window_i1*/, Variadic:$init_values /*reduce_window_i2*/, - I64ElementsAttr:$window_dimensions /*reduce_window_i3*/, + I64DenseArrayOrElements1DAttr:$window_dimensions /*reduce_window_i3*/, // If strides or dilations attributes are missing then the default value is // one for each of the operand dimensions. Similarly, padding values are zero // for both low and high in each of the dimensions, if not specified. - OptionalAttr:$window_strides /*reduce_window_i4*/, - OptionalAttr:$base_dilations /*reduce_window_i5*/, - OptionalAttr:$window_dilations /*reduce_window_i6*/, + OptionalAttr:$window_strides /*reduce_window_i4*/, + OptionalAttr:$base_dilations /*reduce_window_i5*/, + OptionalAttr:$window_dilations /*reduce_window_i6*/, OptionalAttr:$padding /*reduce_window_i7*/ ); @@ -2898,6 +2898,7 @@ def StableHLO_ReduceWindowOp: StableHLO_Op<"reduce_window", [ let hasVerifier = 1; + // Builder for non-variadic version of the operation. let builders = [ OpBuilder<(ins "Type":$result_type, "Value":$operand, @@ -2922,7 +2923,6 @@ def StableHLO_ReduceWindowOp: StableHLO_Op<"reduce_window", [ "function_ref":$bodyBuilder )>, ]; - // TODO(hinsu): Implement custom printer and parser. } diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index 6b4cc24ec6e..33ae7ffa0a6 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -715,11 +715,10 @@ LogicalResult verifyReducerShape(std::optional loc, Block& block, LogicalResult verifyReduceWindowOpInputsAndInferWindow( std::optional location, SmallVector inputTypes, - SmallVector initValueTypes, - DenseIntElementsAttr windowDimensions, - std::optional windowStrides, - std::optional baseDilations, - std::optional windowDilations, + SmallVector initValueTypes, ArrayRef windowDimensions, + std::optional> windowStrides, + std::optional> baseDilations, + std::optional> windowDilations, std::optional padding, SmallVector& windowDims, SmallVector& inferredWindow) { @@ -749,22 +748,6 @@ LogicalResult verifyReduceWindowOpInputsAndInferWindow( " is not compatible with shape at input-index ", rankedInputIdx); } - // reduce_window_i3 - auto windowDimsOrErr = - convert1DAttribute(windowDimensions, location, "window_dimensions"); - if (failed(windowDimsOrErr)) return failure(); - // reduce_window_i4 - auto windowStridesOrErr = - convert1DAttribute(windowStrides, location, "window_strides"); - if (failed(windowStridesOrErr)) return failure(); - // reduce_window_i5 - auto baseDilationsOrErr = - convert1DAttribute(baseDilations, location, "base_dilations"); - if (failed(baseDilationsOrErr)) return failure(); - // reduce_window_i6 - auto windowDilationsOrErr = - convert1DAttribute(windowDilations, location, "window_dilations"); - if (failed(windowDilationsOrErr)) return failure(); // reduce_window_c12, reduce_window_i7 auto paddingOrErr = convertPaddingAttribute(padding, location); if (failed(paddingOrErr)) return failure(); @@ -772,22 +755,23 @@ LogicalResult verifyReduceWindowOpInputsAndInferWindow( // reduce_window_c4 for (const auto inputType : inputTypes) { if (!inputType.hasRank()) continue; - if (inputType.getRank() != static_cast((*windowDimsOrErr).size())) + if (inputType.getRank() != static_cast(windowDimensions.size())) return emitOptionalError( location, "expects window-dimensions size == input rank, but got ", - "window-dimensions size: ", (*windowDimsOrErr).size(), + "window-dimensions size: ", windowDimensions.size(), " and input: ", inputType, " with rank = ", inputType.getRank(), "."); } // reduce_window_c5...reduce_window_c12 auto windowOrErr = verifyWindowAttributesAndInferWindowDimensions( - *windowDimsOrErr, *windowStridesOrErr, *paddingOrErr, - /*lhsDilation=*/*baseDilationsOrErr, - /*rhsDilation=*/*windowDilationsOrErr, /*windowReversal=*/std::nullopt, - location); + windowDimensions, windowStrides.value_or(SmallVector{}), + *paddingOrErr, + /*lhsDilation=*/baseDilations.value_or(SmallVector{}), + /*rhsDilation=*/windowDilations.value_or(SmallVector{}), + /*windowReversal=*/std::nullopt, location); if (failed(windowOrErr)) return failure(); - windowDims.append(*windowDimsOrErr); + windowDims.append(windowDimensions.begin(), windowDimensions.end()); inferredWindow.append(*windowOrErr); return success(); } @@ -2624,10 +2608,10 @@ LogicalResult inferReduceOp( LogicalResult inferReduceWindowOp( std::optional location, ValueRange inputs, ValueRange initValues, - DenseIntElementsAttr windowDimensions, - std::optional windowStrides, - std::optional baseDilations, - std::optional windowDilations, + ArrayRef windowDimensions, + std::optional> windowStrides, + std::optional> baseDilations, + std::optional> windowDilations, std::optional padding, Region& body, SmallVectorImpl& inferredReturnShapes) { SmallVector inputTypes{llvm::map_range( @@ -3913,10 +3897,10 @@ LogicalResult verifyReduceScatterOp(std::optional location, LogicalResult verifyReduceWindowOp( std::optional location, ValueRange inputs, ValueRange initValues, - DenseIntElementsAttr windowDimensions, - std::optional windowStrides, - std::optional baseDilations, - std::optional windowDilations, + ArrayRef windowDimensions, + std::optional> windowStrides, + std::optional> baseDilations, + std::optional> windowDilations, std::optional padding, Region& body) { SmallVector inputTypes{llvm::map_range( inputs.getTypes(), [](Type t) { return t.cast(); })}; diff --git a/stablehlo/dialect/TypeInference.h b/stablehlo/dialect/TypeInference.h index 12e19e319e6..771c39ab8cd 100644 --- a/stablehlo/dialect/TypeInference.h +++ b/stablehlo/dialect/TypeInference.h @@ -290,10 +290,10 @@ LogicalResult inferReduceOp( LogicalResult inferReduceWindowOp( std::optional location, ValueRange inputs, ValueRange initValues, - DenseIntElementsAttr windowDimensions, - std::optional windowStrides, - std::optional baseDilations, - std::optional windowDilations, + ArrayRef windowDimensions, + std::optional> windowStrides, + std::optional> baseDilations, + std::optional> windowDilations, std::optional padding, Region& body, SmallVectorImpl& inferredReturnShapes); @@ -477,10 +477,10 @@ LogicalResult verifyReduceScatterOp(std::optional location, LogicalResult verifyReduceWindowOp( std::optional location, ValueRange inputs, ValueRange initValues, - DenseIntElementsAttr windowDimensions, - std::optional windowStrides, - std::optional baseDilations, - std::optional windowDilations, + ArrayRef windowDimensions, + std::optional> windowStrides, + std::optional> baseDilations, + std::optional> windowDilations, std::optional padding, Region& body); LogicalResult verifyReshapeOp(std::optional location, Value operand, diff --git a/stablehlo/reference/Ops.cpp b/stablehlo/reference/Ops.cpp index 64a04fcebdb..085b98068f5 100644 --- a/stablehlo/reference/Ops.cpp +++ b/stablehlo/reference/Ops.cpp @@ -603,18 +603,15 @@ SmallVector eval(Region ®ion, Sizes windowStrides(rank, 1); if (auto windowStridesAttr = reduceWindowOp.getWindowStrides()) - windowStrides.assign(windowStridesAttr->value_begin(), - windowStridesAttr->value_end()); + windowStrides = Sizes(*windowStridesAttr); Sizes baseDilations(rank, 1); if (auto baseDilationsAttr = reduceWindowOp.getBaseDilations()) - baseDilations.assign(baseDilationsAttr->value_begin(), - baseDilationsAttr->value_end()); + baseDilations = Sizes(*baseDilationsAttr); Sizes windowDilations(rank, 1); if (auto windowDilationsAttr = reduceWindowOp.getWindowDilations()) - windowDilations.assign(windowDilationsAttr->value_begin(), - windowDilationsAttr->value_end()); + windowDilations = Sizes(*windowDilationsAttr); Sizes paddingLow(rank, 0), paddingHigh(rank, 0); if (auto paddingAttr = reduceWindowOp.getPadding()) { diff --git a/stablehlo/tests/verify_reduce_window.mlir b/stablehlo/tests/verify_reduce_window.mlir index 62d9b4d80e0..80ab641e0a3 100644 --- a/stablehlo/tests/verify_reduce_window.mlir +++ b/stablehlo/tests/verify_reduce_window.mlir @@ -684,7 +684,7 @@ func.func @reduce_window_i2(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, func.func @reduce_window_i3(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %init0: tensor, %init1: tensor) -> (tensor<2x2xf32>, tensor<2x2xi32>) { - // expected-error@+1 {{expects the shape of window_dimensions attribute to be 1-D, but got {1, 2}}} + // expected-error@+1 {{attribute 'window_dimensions' failed to satisfy constraint: either a DenseI64ArrayAttr or a 1-dimensional I64ElementsAttr.}} %0:2 = "stablehlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ ^bb0(%a0: tensor, %a1: tensor, %b0: tensor, %b1: tensor): @@ -705,7 +705,7 @@ func.func @reduce_window_i3(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, func.func @reduce_window_i4(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %init0: tensor, %init1: tensor) -> (tensor<2x2xf32>, tensor<2x2xi32>) { - // expected-error@+1 {{expects the shape of window_strides attribute to be 1-D, but got {1, 2}}} + // expected-error@+1 {{attribute 'window_strides' failed to satisfy constraint: either a DenseI64ArrayAttr or a 1-dimensional I64ElementsAttr.}} %0:2 = "stablehlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ ^bb0(%a0: tensor, %a1: tensor, %b0: tensor, %b1: tensor): @@ -726,7 +726,7 @@ func.func @reduce_window_i4(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, func.func @reduce_window_i5(%arg0: tensor<*xf32>, %arg1: tensor<4x?xi32>, %init0: tensor, %init1: tensor) -> (tensor, tensor<*xi32>) { - // expected-error@+1 {{expects the shape of base_dilations attribute to be 1-D, but got {1, 2}}} + // expected-error@+1 {{attribute 'base_dilations' failed to satisfy constraint: either a DenseI64ArrayAttr or a 1-dimensional I64ElementsAttr.}} %0:2 = "stablehlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ ^bb0(%a0: tensor, %a1: tensor, %b0: tensor, %b1: tensor): @@ -749,7 +749,7 @@ func.func @reduce_window_i5(%arg0: tensor<*xf32>, func.func @reduce_window_i6(%arg0: tensor<*xf32>, %arg1: tensor<4x?xi32>, %init0: tensor, %init1: tensor) -> (tensor, tensor<*xi32>) { - // expected-error@+1 {{expects the shape of window_dilations attribute to be 1-D, but got {1, 2}}} + // expected-error@+1 {{attribute 'window_dilations' failed to satisfy constraint: either a DenseI64ArrayAttr or a 1-dimensional I64ElementsAttr.}} %0:2 = "stablehlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ ^bb0(%a0: tensor, %a1: tensor, %b0: tensor, %b1: tensor):