diff --git a/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir b/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir index 024f1108c92..569db7cb1d1 100644 --- a/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir +++ b/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir @@ -41,3 +41,170 @@ func.func @tan_op_complex(%arg0: tensor<4xf64>, %arg1: tensor<4xf64>) -> (tensor %3 = stablehlo.imag %1 : (tensor<4xcomplex>) -> tensor<4xf64> func.return %2, %3 : tensor<4xf64>, tensor<4xf64> } + +// ----- + +// CHECK-LABEL: @gather_with_batching_dims +// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> +// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> +// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32> +// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ +// CHECK-SAME: dimension_numbers = #stablehlo.gather< +// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], +// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, +// CHECK-SAME: indices_are_sorted = false, +// CHECK-SAME: slice_sizes = array +// CHECK-SAME: }> : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x4xi32>) -> tensor<4x3x5x8xi32> +// CHECK-NEXT: return %[[gather]] : tensor<4x3x5x8xi32> +func.func @gather_with_batching_dims(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> { + // CHECK-NO-DOWNGRADE: operand_batching_dims = [0, 2] + // CHECK-NO-DOWNGRADE: start_indices_batching_dims = [1, 0] + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [3], + collapsed_slice_dims = [1, 3], + operand_batching_dims = [0, 2], + start_indices_batching_dims = [1, 0], + start_index_map = [1, 3], + index_vector_dim = 3 + >, + slice_sizes = array, + indices_are_sorted = true + } : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> + func.return %0 : tensor<4x3x5x8xi32> +} + +// ----- + +// CHECK-LABEL: @gather_with_batching_no_index_vector_dim +// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> +// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> +// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> +// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32> +// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ +// CHECK-SAME: dimension_numbers = #stablehlo.gather< +// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2], +// CHECK-SAME: start_index_map = [0, 2, 1], index_vector_dim = 3>, +// CHECK-SAME: indices_are_sorted = false, +// CHECK-SAME: slice_sizes = array +// CHECK-SAME: }> : (tensor<3x2x4x9xi32>, tensor<4x3x5x3xi32>) -> tensor<4x3x5x8xi32> +// CHECK-NEXT: return %[[gather]] : tensor<4x3x5x8xi32> +func.func @gather_with_batching_no_index_vector_dim(%arg0: tensor<3x2x4x9xi32>, %arg1: tensor<4x3x5xi32>) -> tensor<4x3x5x8xi32> { + // CHECK-NO-DOWNGRADE: operand_batching_dims = [0, 2] + // CHECK-NO-DOWNGRADE: start_indices_batching_dims = [1, 0] + %0 = "stablehlo.gather"(%arg0, %arg1) <{ + dimension_numbers = #stablehlo.gather< + offset_dims = [3], + collapsed_slice_dims = [1], + operand_batching_dims = [0, 2], + start_indices_batching_dims = [1, 0], + start_index_map = [1], + index_vector_dim = 3 + >, + slice_sizes = array, + indices_are_sorted = true + }> : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>) -> tensor<4x3x5x8xi32> + func.return %0 : tensor<4x3x5x8xi32> +} + +// ----- + +// CHECK-LABEL: @gather_with_batching_dim_size_zero +// CHECK-NEXT: %[[iota:.*]] = stablehlo.iota dim = 0 : tensor<0x3x5x1xi32> +// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota]], %arg1, dim = 3 : (tensor<0x3x5x1xi32>, tensor<0x3x5x1xi32>) -> tensor<0x3x5x2xi32> +// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ +// CHECK-SAME: dimension_numbers = #stablehlo.gather< +// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1], +// CHECK-SAME: start_index_map = [0, 1], index_vector_dim = 3>, +// CHECK-SAME: indices_are_sorted = false, +// CHECK-SAME: slice_sizes = array +// CHECK-SAME: }> : (tensor<0x2x9xi32>, tensor<0x3x5x2xi32>) -> tensor<0x3x5x8xi32> +// CHECK-NEXT: return %[[gather]] : tensor<0x3x5x8xi32> +func.func @gather_with_batching_dim_size_zero(%arg0: tensor<0x2x9xi32>, %arg1: tensor<0x3x5x1xi32>) -> tensor<0x3x5x8xi32> { + // CHECK-NO-DOWNGRADE: operand_batching_dims = [0] + // CHECK-NO-DOWNGRADE: start_indices_batching_dims = [0] + %0 = "stablehlo.gather"(%arg0, %arg1) <{ + dimension_numbers = #stablehlo.gather< + offset_dims = [3], + collapsed_slice_dims = [1], + operand_batching_dims = [0], + start_indices_batching_dims = [0], + start_index_map = [1], + index_vector_dim = 3 + >, + slice_sizes = array, + indices_are_sorted = true + }> : (tensor<0x2x9xi32>, tensor<0x3x5x1xi32>) -> tensor<0x3x5x8xi32> + func.return %0 : tensor<0x3x5x8xi32> +} + +// ----- + +// CHECK-LABEL: @scatter_with_batching_dims +// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> +// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> +// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32> +// CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ +// CHECK-SAME: indices_are_sorted = false, +// CHECK-SAME: dimension_numbers = #stablehlo.scatter< +// CHECK-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2, 3], +// CHECK-SAME: scatter_dims_to_operand_dims = [0, 2, 1, 3], index_vector_dim = 3>, +// CHECK-SAME: unique_indices = false}> +// CHECK: (tensor<3x2x4x7x9xi32>, tensor<4x3x5x4xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> +// CHECK-NEXT: return %[[scatter]] : tensor<3x2x4x7x9xi32> +func.func @scatter_with_batching_dims(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>, %arg2: tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> { + // CHECK-NO-DOWNGRADE: input_batching_dims = [0, 2] + // CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0] + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ + indices_are_sorted = true, + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [3], + inserted_window_dims = [1, 3], + input_batching_dims = [0, 2], + scatter_indices_batching_dims = [1, 0], + scatter_dims_to_operand_dims = [1, 3], + index_vector_dim = 3 + >, + unique_indices = false + }> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + stablehlo.return %arg4 : tensor + }) : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x2xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> + func.return %0 : tensor<3x2x4x7x9xi32> +} + +// ----- + +// CHECK-LABEL: @scatter_with_batching_no_index_vector_dim +// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> +// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> +// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> +// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32> +// CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ +// CHECK-SAME: indices_are_sorted = false, +// CHECK-SAME: dimension_numbers = #stablehlo.scatter< +// CHECK-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2], +// CHECK-SAME: scatter_dims_to_operand_dims = [0, 2, 1], index_vector_dim = 3>, +// CHECK-SAME: unique_indices = true}> +// CHECK: (tensor<3x2x4x9xi32>, tensor<4x3x5x3xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32> +// CHECK-NEXT: return %[[scatter]] : tensor<3x2x4x9xi32> +func.func @scatter_with_batching_no_index_vector_dim(%arg0: tensor<3x2x4x9xi32>, %arg1: tensor<4x3x5xi32>, %arg2: tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32> { + // CHECK-NO-DOWNGRADE: input_batching_dims = [0, 2] + // CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0] + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ + indices_are_sorted = true, + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [3], + inserted_window_dims = [1], + input_batching_dims = [0, 2], + scatter_indices_batching_dims = [1, 0], + scatter_dims_to_operand_dims = [1], + index_vector_dim = 3 + >, + unique_indices = true + }> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + stablehlo.return %arg4 : tensor + }) : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32> + func.return %0 : tensor<3x2x4x9xi32> +} diff --git a/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp b/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp index 28113bf1a4a..0761a8d80cb 100644 --- a/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp +++ b/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp @@ -12,14 +12,22 @@ limitations under the License. #include +#include #include +#include +#include +#include #include "llvm/ADT/APFloat.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -58,6 +66,132 @@ vhlo::Version validateTargetVersion(llvm::StringRef versionRef) { return targetVersion; } +SmallVector mergeSortedDims(ArrayRef dims1, + ArrayRef dims2) { + SmallVector result; + result.reserve(dims1.size() + dims2.size()); + std::merge(dims1.begin(), dims1.end(), dims2.begin(), dims2.end(), + std::back_inserter(result)); + return result; +} + +// Returns an updated indices tensor such that an `IotaOp` is prepended for each +// dim in `indicesBatchingDims` with a `ConcatenateOp`. +// +// If `indexVectorDim` is equal to the rank of `indices`, it is reshaped to have +// a trailing dimension of size 1 so it can be concatenated with the `IotaOp`s. +Value createConcatIndices(Value indices, int64_t indexVectorDim, + ArrayRef indicesBatchingDims, + PatternRewriter &rewriter) { + Location loc = indices.getLoc(); + auto indicesType = cast(indices.getType()); + bool indexVectorDimOnLastDim = indexVectorDim == indicesType.getRank(); + + SmallVector iotaShape(indicesType.getShape()); + if (indexVectorDimOnLastDim) { + iotaShape.push_back(1); + } else { + iotaShape[indexVectorDim] = 1; + } + auto iotaType = + RankedTensorType::get(iotaShape, indicesType.getElementType()); + + SmallVector indicesToConcat; + indicesToConcat.reserve(indicesBatchingDims.size() + 1); + for (int64_t batchingDim : indicesBatchingDims) { + indicesToConcat.push_back( + rewriter.create(loc, iotaType, batchingDim)); + } + if (indexVectorDimOnLastDim) { + indicesToConcat.push_back( + rewriter.create(loc, iotaType, indices)); + } else { + indicesToConcat.push_back(indices); + } + return rewriter.create(loc, indicesToConcat, indexVectorDim); +} + +//===----------------------------------------------------------------------===// +// Patterns (non DRR) +//===----------------------------------------------------------------------===// + +// Converts a `GatherOp` with batching dims to a `GatherOp` without batching +// dims, such that each batching dim becomes a collapsed slice dim with a +// corresponding `IotaOp` concatenated to the start indices. +class GatherWithBatchingDimsExpander : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GatherOp op, + PatternRewriter &rewriter) const override { + GatherDimensionNumbersAttr dimNumbers = op.getDimensionNumbers(); + ArrayRef operandBatchingDims = dimNumbers.getOperandBatchingDims(); + if (operandBatchingDims.empty()) { + return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { + diag << "gather op has no batching dims"; + }); + } + + SmallVector newCollapsedSliceDims = mergeSortedDims( + operandBatchingDims, dimNumbers.getCollapsedSliceDims()); + SmallVector newStartIndexMap = + llvm::to_vector(llvm::concat( + operandBatchingDims, dimNumbers.getStartIndexMap())); + Value newIndices = createConcatIndices( + op.getStartIndices(), dimNumbers.getIndexVectorDim(), + dimNumbers.getStartIndicesBatchingDims(), rewriter); + rewriter.replaceOpWithNewOp( + op, op.getOperand(), newIndices, + GatherDimensionNumbersAttr::get( + op.getContext(), dimNumbers.getOffsetDims(), newCollapsedSliceDims, + /*operandBatchingDims=*/{}, /*startIndicesBatchingDims=*/{}, + newStartIndexMap, dimNumbers.getIndexVectorDim()), + op.getSliceSizes(), /*indicesAreSorted=*/false); + + return success(); + } +}; + +// Converts a `ScatterOp` with batching dims to a `ScatterOp` without batching +// dims, such that each batching dim becomes an inserted window dim with a +// corresponding `IotaOp` concatenated to the scatter indices. +class ScatterWithBatchingDimsExpander : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ScatterOp op, + PatternRewriter &rewriter) const override { + ScatterDimensionNumbersAttr dimNumbers = op.getScatterDimensionNumbers(); + ArrayRef inputBatchingDims = dimNumbers.getInputBatchingDims(); + if (inputBatchingDims.empty()) { + return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { + diag << "scatter op has no batching dims"; + }); + } + + SmallVector newInsertedWindowDims = + mergeSortedDims(inputBatchingDims, dimNumbers.getInsertedWindowDims()); + SmallVector newScatterDimsToOperandDims = + llvm::to_vector(llvm::concat( + inputBatchingDims, dimNumbers.getScatterDimsToOperandDims())); + Value newIndices = createConcatIndices( + op.getScatterIndices(), dimNumbers.getIndexVectorDim(), + dimNumbers.getScatterIndicesBatchingDims(), rewriter); + auto newScatterOp = rewriter.create( + op.getLoc(), op->getResultTypes(), op.getInputs(), newIndices, + op.getUpdates(), + ScatterDimensionNumbersAttr::get( + op.getContext(), dimNumbers.getUpdateWindowDims(), + newInsertedWindowDims, + /*inputBatchingDims=*/{}, /*scatterIndicesBatchingDims=*/{}, + newScatterDimsToOperandDims, dimNumbers.getIndexVectorDim()), + /*indicesAreSorted=*/false, op.getUniqueIndices()); + + newScatterOp.getUpdateComputation().takeBody(op.getUpdateComputation()); + rewriter.replaceOp(op, newScatterOp.getResults()); + + return success(); + } +}; + //===----------------------------------------------------------------------===// // Pass //===----------------------------------------------------------------------===// @@ -107,10 +241,16 @@ struct StablehloCreateCompatibilityExpanderPass void populateStablehloCreateCompatibilityExpanderPatterns( RewritePatternSet *patterns, MLIRContext *context, vhlo::Version targetVersion) { + // StableHLO GatherOp/ScatterOp with batching dims is introduced in v1.1.0. + if (targetVersion < vhlo::Version(1, 1, 0)) { + patterns + ->add( + context); + } // StableHLO TanOp is introduced in v1.4.0. if (targetVersion < vhlo::Version(1, 4, 0)) { - patterns->add(context); - patterns->add(context); + patterns->add(context); } }