Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an expander pattern for GatherOp/ScatterOp with batching dims #2555

Merged
merged 1 commit into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,170 @@ func.func @tan_op_complex(%arg0: tensor<4xf64>, %arg1: tensor<4xf64>) -> (tensor
%3 = stablehlo.imag %1 : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
func.return %2, %3 : tensor<4xf64>, tensor<4xf64>
}

// -----

// CHECK-LABEL: @gather_with_batching_dims
// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32>
// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32>
// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32>
// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{
// CHECK-SAME: dimension_numbers = #stablehlo.gather<
// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3],
// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>,
// CHECK-SAME: indices_are_sorted = false,
// CHECK-SAME: slice_sizes = array<i64: 1, 1, 1, 1, 8>
// CHECK-SAME: }> : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x4xi32>) -> tensor<4x3x5x8xi32>
// CHECK-NEXT: return %[[gather]] : tensor<4x3x5x8xi32>
func.func @gather_with_batching_dims(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> {
// CHECK-NO-DOWNGRADE: operand_batching_dims = [0, 2]
// CHECK-NO-DOWNGRADE: start_indices_batching_dims = [1, 0]
%0 = "stablehlo.gather"(%arg0, %arg1) {
dimension_numbers = #stablehlo.gather<
offset_dims = [3],
collapsed_slice_dims = [1, 3],
operand_batching_dims = [0, 2],
start_indices_batching_dims = [1, 0],
start_index_map = [1, 3],
index_vector_dim = 3
>,
slice_sizes = array<i64: 1, 1, 1, 1, 8>,
indices_are_sorted = true
} : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32>
func.return %0 : tensor<4x3x5x8xi32>
}

// -----

// CHECK-LABEL: @gather_with_batching_no_index_vector_dim
// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32>
// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32>
// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32>
// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32>
// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{
// CHECK-SAME: dimension_numbers = #stablehlo.gather<
// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2],
// CHECK-SAME: start_index_map = [0, 2, 1], index_vector_dim = 3>,
// CHECK-SAME: indices_are_sorted = false,
// CHECK-SAME: slice_sizes = array<i64: 1, 1, 1, 8>
// CHECK-SAME: }> : (tensor<3x2x4x9xi32>, tensor<4x3x5x3xi32>) -> tensor<4x3x5x8xi32>
// CHECK-NEXT: return %[[gather]] : tensor<4x3x5x8xi32>
func.func @gather_with_batching_no_index_vector_dim(%arg0: tensor<3x2x4x9xi32>, %arg1: tensor<4x3x5xi32>) -> tensor<4x3x5x8xi32> {
// CHECK-NO-DOWNGRADE: operand_batching_dims = [0, 2]
// CHECK-NO-DOWNGRADE: start_indices_batching_dims = [1, 0]
%0 = "stablehlo.gather"(%arg0, %arg1) <{
dimension_numbers = #stablehlo.gather<
offset_dims = [3],
collapsed_slice_dims = [1],
operand_batching_dims = [0, 2],
start_indices_batching_dims = [1, 0],
start_index_map = [1],
index_vector_dim = 3
>,
slice_sizes = array<i64: 1, 1, 1, 8>,
indices_are_sorted = true
}> : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>) -> tensor<4x3x5x8xi32>
func.return %0 : tensor<4x3x5x8xi32>
}

// -----

// CHECK-LABEL: @gather_with_batching_dim_size_zero
// CHECK-NEXT: %[[iota:.*]] = stablehlo.iota dim = 0 : tensor<0x3x5x1xi32>
// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota]], %arg1, dim = 3 : (tensor<0x3x5x1xi32>, tensor<0x3x5x1xi32>) -> tensor<0x3x5x2xi32>
// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{
// CHECK-SAME: dimension_numbers = #stablehlo.gather<
// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1],
// CHECK-SAME: start_index_map = [0, 1], index_vector_dim = 3>,
// CHECK-SAME: indices_are_sorted = false,
// CHECK-SAME: slice_sizes = array<i64: 0, 1, 8>
// CHECK-SAME: }> : (tensor<0x2x9xi32>, tensor<0x3x5x2xi32>) -> tensor<0x3x5x8xi32>
// CHECK-NEXT: return %[[gather]] : tensor<0x3x5x8xi32>
func.func @gather_with_batching_dim_size_zero(%arg0: tensor<0x2x9xi32>, %arg1: tensor<0x3x5x1xi32>) -> tensor<0x3x5x8xi32> {
// CHECK-NO-DOWNGRADE: operand_batching_dims = [0]
// CHECK-NO-DOWNGRADE: start_indices_batching_dims = [0]
%0 = "stablehlo.gather"(%arg0, %arg1) <{
dimension_numbers = #stablehlo.gather<
offset_dims = [3],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [0],
start_index_map = [1],
index_vector_dim = 3
>,
slice_sizes = array<i64: 0, 1, 8>,
indices_are_sorted = true
}> : (tensor<0x2x9xi32>, tensor<0x3x5x1xi32>) -> tensor<0x3x5x8xi32>
func.return %0 : tensor<0x3x5x8xi32>
}

// -----

// CHECK-LABEL: @scatter_with_batching_dims
// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32>
// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32>
// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32>
// CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{
// CHECK-SAME: indices_are_sorted = false,
// CHECK-SAME: dimension_numbers = #stablehlo.scatter<
// CHECK-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2, 3],
// CHECK-SAME: scatter_dims_to_operand_dims = [0, 2, 1, 3], index_vector_dim = 3>,
// CHECK-SAME: unique_indices = false}>
// CHECK: (tensor<3x2x4x7x9xi32>, tensor<4x3x5x4xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32>
// CHECK-NEXT: return %[[scatter]] : tensor<3x2x4x7x9xi32>
func.func @scatter_with_batching_dims(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>, %arg2: tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> {
// CHECK-NO-DOWNGRADE: input_batching_dims = [0, 2]
// CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0]
%0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{
indices_are_sorted = true,
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [3],
inserted_window_dims = [1, 3],
input_batching_dims = [0, 2],
scatter_indices_batching_dims = [1, 0],
scatter_dims_to_operand_dims = [1, 3],
index_vector_dim = 3
>,
unique_indices = false
}> ({
^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
stablehlo.return %arg4 : tensor<i32>
}) : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x2xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32>
func.return %0 : tensor<3x2x4x7x9xi32>
}

// -----

// CHECK-LABEL: @scatter_with_batching_no_index_vector_dim
// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32>
// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32>
// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32>
// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32>
// CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{
// CHECK-SAME: indices_are_sorted = false,
// CHECK-SAME: dimension_numbers = #stablehlo.scatter<
// CHECK-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2],
// CHECK-SAME: scatter_dims_to_operand_dims = [0, 2, 1], index_vector_dim = 3>,
// CHECK-SAME: unique_indices = true}>
// CHECK: (tensor<3x2x4x9xi32>, tensor<4x3x5x3xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32>
// CHECK-NEXT: return %[[scatter]] : tensor<3x2x4x9xi32>
func.func @scatter_with_batching_no_index_vector_dim(%arg0: tensor<3x2x4x9xi32>, %arg1: tensor<4x3x5xi32>, %arg2: tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32> {
// CHECK-NO-DOWNGRADE: input_batching_dims = [0, 2]
// CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0]
%0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{
indices_are_sorted = true,
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [3],
inserted_window_dims = [1],
input_batching_dims = [0, 2],
scatter_indices_batching_dims = [1, 0],
scatter_dims_to_operand_dims = [1],
index_vector_dim = 3
>,
unique_indices = true
}> ({
^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
stablehlo.return %arg4 : tensor<i32>
}) : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32>
func.return %0 : tensor<3x2x4x9xi32>
}
144 changes: 142 additions & 2 deletions stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,22 @@ limitations under the License.

#include <fcntl.h>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <utility>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
Expand Down Expand Up @@ -58,6 +66,132 @@ vhlo::Version validateTargetVersion(llvm::StringRef versionRef) {
return targetVersion;
}

SmallVector<int64_t> mergeSortedDims(ArrayRef<int64_t> dims1,
ArrayRef<int64_t> dims2) {
SmallVector<int64_t> result;
result.reserve(dims1.size() + dims2.size());
std::merge(dims1.begin(), dims1.end(), dims2.begin(), dims2.end(),
std::back_inserter(result));
return result;
}

// Returns an updated indices tensor such that an `IotaOp` is prepended for each
// dim in `indicesBatchingDims` with a `ConcatenateOp`.
//
// If `indexVectorDim` is equal to the rank of `indices`, it is reshaped to have
// a trailing dimension of size 1 so it can be concatenated with the `IotaOp`s.
Value createConcatIndices(Value indices, int64_t indexVectorDim,
ArrayRef<int64_t> indicesBatchingDims,
PatternRewriter &rewriter) {
Location loc = indices.getLoc();
auto indicesType = cast<RankedTensorType>(indices.getType());
bool indexVectorDimOnLastDim = indexVectorDim == indicesType.getRank();

SmallVector<int64_t> iotaShape(indicesType.getShape());
if (indexVectorDimOnLastDim) {
iotaShape.push_back(1);
} else {
iotaShape[indexVectorDim] = 1;
}
auto iotaType =
RankedTensorType::get(iotaShape, indicesType.getElementType());

SmallVector<Value> indicesToConcat;
indicesToConcat.reserve(indicesBatchingDims.size() + 1);
for (int64_t batchingDim : indicesBatchingDims) {
indicesToConcat.push_back(
rewriter.create<IotaOp>(loc, iotaType, batchingDim));
}
if (indexVectorDimOnLastDim) {
indicesToConcat.push_back(
rewriter.create<ReshapeOp>(loc, iotaType, indices));
} else {
indicesToConcat.push_back(indices);
}
return rewriter.create<ConcatenateOp>(loc, indicesToConcat, indexVectorDim);
}

//===----------------------------------------------------------------------===//
// Patterns (non DRR)
//===----------------------------------------------------------------------===//

// Converts a `GatherOp` with batching dims to a `GatherOp` without batching
// dims, such that each batching dim becomes a collapsed slice dim with a
// corresponding `IotaOp` concatenated to the start indices.
class GatherWithBatchingDimsExpander : public OpRewritePattern<GatherOp> {
using OpRewritePattern<GatherOp>::OpRewritePattern;

LogicalResult matchAndRewrite(GatherOp op,
PatternRewriter &rewriter) const override {
GatherDimensionNumbersAttr dimNumbers = op.getDimensionNumbers();
ArrayRef<int64_t> operandBatchingDims = dimNumbers.getOperandBatchingDims();
if (operandBatchingDims.empty()) {
return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
diag << "gather op has no batching dims";
});
}

SmallVector<int64_t> newCollapsedSliceDims = mergeSortedDims(
operandBatchingDims, dimNumbers.getCollapsedSliceDims());
SmallVector<int64_t> newStartIndexMap =
llvm::to_vector(llvm::concat<const int64_t>(
operandBatchingDims, dimNumbers.getStartIndexMap()));
Value newIndices = createConcatIndices(
op.getStartIndices(), dimNumbers.getIndexVectorDim(),
dimNumbers.getStartIndicesBatchingDims(), rewriter);
rewriter.replaceOpWithNewOp<GatherOp>(
op, op.getOperand(), newIndices,
GatherDimensionNumbersAttr::get(
op.getContext(), dimNumbers.getOffsetDims(), newCollapsedSliceDims,
/*operandBatchingDims=*/{}, /*startIndicesBatchingDims=*/{},
newStartIndexMap, dimNumbers.getIndexVectorDim()),
op.getSliceSizes(), /*indicesAreSorted=*/false);

return success();
}
};

// Converts a `ScatterOp` with batching dims to a `ScatterOp` without batching
// dims, such that each batching dim becomes an inserted window dim with a
// corresponding `IotaOp` concatenated to the scatter indices.
class ScatterWithBatchingDimsExpander : public OpRewritePattern<ScatterOp> {
using OpRewritePattern<ScatterOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ScatterOp op,
PatternRewriter &rewriter) const override {
ScatterDimensionNumbersAttr dimNumbers = op.getScatterDimensionNumbers();
ArrayRef<int64_t> inputBatchingDims = dimNumbers.getInputBatchingDims();
if (inputBatchingDims.empty()) {
return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
diag << "scatter op has no batching dims";
});
}

SmallVector<int64_t> newInsertedWindowDims =
mergeSortedDims(inputBatchingDims, dimNumbers.getInsertedWindowDims());
SmallVector<int64_t> newScatterDimsToOperandDims =
llvm::to_vector(llvm::concat<const int64_t>(
inputBatchingDims, dimNumbers.getScatterDimsToOperandDims()));
Value newIndices = createConcatIndices(
op.getScatterIndices(), dimNumbers.getIndexVectorDim(),
dimNumbers.getScatterIndicesBatchingDims(), rewriter);
auto newScatterOp = rewriter.create<ScatterOp>(
op.getLoc(), op->getResultTypes(), op.getInputs(), newIndices,
op.getUpdates(),
ScatterDimensionNumbersAttr::get(
op.getContext(), dimNumbers.getUpdateWindowDims(),
newInsertedWindowDims,
/*inputBatchingDims=*/{}, /*scatterIndicesBatchingDims=*/{},
newScatterDimsToOperandDims, dimNumbers.getIndexVectorDim()),
/*indicesAreSorted=*/false, op.getUniqueIndices());

newScatterOp.getUpdateComputation().takeBody(op.getUpdateComputation());
rewriter.replaceOp(op, newScatterOp.getResults());

return success();
}
};

//===----------------------------------------------------------------------===//
// Pass
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -107,10 +241,16 @@ struct StablehloCreateCompatibilityExpanderPass
void populateStablehloCreateCompatibilityExpanderPatterns(
RewritePatternSet *patterns, MLIRContext *context,
vhlo::Version targetVersion) {
// StableHLO GatherOp/ScatterOp with batching dims is introduced in v1.1.0.
if (targetVersion < vhlo::Version(1, 1, 0)) {
patterns
->add<GatherWithBatchingDimsExpander, ScatterWithBatchingDimsExpander>(
context);
}
// StableHLO TanOp is introduced in v1.4.0.
if (targetVersion < vhlo::Version(1, 4, 0)) {
patterns->add<TanOp_ComplexElementType_CompatiblityExpander>(context);
patterns->add<TanOp_CompatiblityExpander>(context);
patterns->add<TanOp_ComplexElementType_CompatiblityExpander,
TanOp_CompatiblityExpander>(context);
}
}

Expand Down
Loading