[MLIR][TORCH] Add support for enable_gqa flag in SDPA op #3950

Open · wants to merge 2 commits into main
128 changes: 123 additions & 5 deletions lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
@@ -372,6 +372,54 @@ static FailureOr<SmallVector<Value>> createTMTensorTopkOp(
return SmallVector<Value>(topkOp.getResults());
}

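// Repeats each element of `self` `repeats` times along dimension `dim`,
// i.e. the single-dimension equivalent of aten.repeat_interleave: the
// tensor is unsqueezed at dim + 1, broadcast to `repeats` along the new
// axis, and the two axes are collapsed back into one.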
static FailureOr<Value>
repeatTensorElementsForDim(Operation *op, ConversionPatternRewriter &rewriter,
Type resType, Value self, int64_t repeats,
int64_t dim) {
Location loc = op->getLoc();
auto context = op->getContext();
auto selfTy = cast<BaseTensorType>(self.getType());

int64_t inputRank = selfTy.getSizes().size();
dim = toPositiveDim(dim, inputRank);
Value dimValue =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim));
Value dimValuePlusOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim + 1));

auto unsqueezedInfo = unsqueezeTensor(rewriter, op, self, dimValuePlusOne);
if (failed(unsqueezedInfo))
return rewriter.notifyMatchFailure(op,
"cannot generate unsqueeze tensor op");
self = *unsqueezedInfo;

Value constMinusOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
SmallVector<Value> expandShapeValueList(inputRank + 1, constMinusOne);
expandShapeValueList[dim + 1] =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(repeats));
Value expandShapeList = rewriter.create<PrimListConstructOp>(
loc, ListType::get(IntType::get(context)), expandShapeValueList);

SmallVector<int64_t> expandShape(inputRank + 1);
for (int64_t i = 0; i <= dim; i++) {
expandShape[i] = selfTy.getSizes()[i];
}
expandShape[dim + 1] = repeats;
for (int64_t i = dim + 1; i < inputRank; i++) {
expandShape[i + 1] = selfTy.getSizes()[i];
}

BaseTensorType expandTy =
rewriter.getType<ValueTensorType>(expandShape, selfTy.getOptionalDtype());
Value expandSelf =
rewriter.create<AtenBroadcastToOp>(loc, expandTy, self, expandShapeList);

Value result = rewriter.create<PrimsCollapseOp>(loc, resType, expandSelf,
dimValue, dimValuePlusOne);
return result;
}

namespace {
template <typename AtenOpT>
class ConvertAtenScatterOp : public OpConversionPattern<AtenOpT> {
@@ -1651,6 +1699,65 @@ class ConvertAtenScaledDotProductAttentionOp
: public OpConversionPattern<AtenScaledDotProductAttentionOp> {
public:
using OpConversionPattern::OpConversionPattern;

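// Grouped-query attention: when the query carries more heads than the
// key/value, repeat the key/value elements along the head dimension so
// that all three operands share the query's head count before the core
// SDPA lowering runs.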
static LogicalResult
preProcessGroupQueryAttentionInput(AtenScaledDotProductAttentionOp op,
ConversionPatternRewriter &rewriter,
const TypeConverter *typeConverter,
Value query, Value &key, Value &value) {
auto queryTy = cast<ShapedType>(query.getType());
auto valueTy = cast<ShapedType>(value.getType());
auto keyTy = cast<ShapedType>(key.getType());

int64_t rank = queryTy.getRank();

int64_t qNumHeads = queryTy.getDimSize(rank - 3);
int64_t kNumHeads = keyTy.getDimSize(rank - 3);
int64_t vNumHeads = valueTy.getDimSize(rank - 3);

if (llvm::any_of(llvm::ArrayRef<int64_t>{qNumHeads, kNumHeads, vNumHeads},
[](int64_t d) { return d == Torch::kUnknownSize; })) {
return llvm::failure();
}

if (llvm::all_equal(
llvm::ArrayRef<int64_t>{qNumHeads, kNumHeads, vNumHeads}))
return llvm::success();

// The query head count must be an integer multiple of both the key and
// the value head counts.
if ((qNumHeads % kNumHeads) || (qNumHeads % vNumHeads))
return llvm::failure();

int64_t repeatKeyShape = qNumHeads / kNumHeads;
int64_t repeatValueShape = qNumHeads / vNumHeads;

Location loc = op.getLoc();

// The repeated key/value keep their own shapes except along the head
// dimension (rank - 3), which grows to the query's head count. Deriving
// the result types from the key/value (rather than reusing the query
// type) keeps the lowering correct when the key/value sequence length or
// the value head dimension differs from the query's.
auto keyTensorTy = cast<BaseTensorType>(op.getKey().getType());
SmallVector<int64_t> newKeyShape(keyTensorTy.getSizes());
newKeyShape[rank - 3] = qNumHeads;
Type newKeyTy = rewriter.getType<ValueTensorType>(
newKeyShape, keyTensorTy.getOptionalDtype());
FailureOr<Value> keyRepeated = repeatTensorElementsForDim(
op.getOperation(), rewriter, /*resType=*/newKeyTy, op.getKey(),
/*repeats=*/repeatKeyShape, /*dim=*/rank - 3);
if (failed(keyRepeated))
return rewriter.notifyMatchFailure(
loc, "failed to repeat the tensor elements for key");

auto valueTensorTy = cast<BaseTensorType>(op.getValue().getType());
SmallVector<int64_t> newValueShape(valueTensorTy.getSizes());
newValueShape[rank - 3] = qNumHeads;
Type newValueTy = rewriter.getType<ValueTensorType>(
newValueShape, valueTensorTy.getOptionalDtype());
FailureOr<Value> valueRepeated = repeatTensorElementsForDim(
op.getOperation(), rewriter, /*resType=*/newValueTy, op.getValue(),
/*repeats=*/repeatValueShape, /*dim=*/rank - 3);
if (failed(valueRepeated))
return rewriter.notifyMatchFailure(
loc, "failed to repeat the tensor elements for value");

key = typeConverter->materializeTargetConversion(
rewriter, loc,
typeConverter->convertType(keyRepeated.value().getType()),
keyRepeated.value());
value = typeConverter->materializeTargetConversion(
rewriter, loc,
typeConverter->convertType(valueRepeated.value().getType()),
valueRepeated.value());
return success();
}

LogicalResult
matchAndRewrite(AtenScaledDotProductAttentionOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -1795,11 +1902,6 @@ class ConvertAtenScaledDotProductAttentionOp
scaleFloat != 1.0)
return rewriter.notifyMatchFailure(loc, "only default scale supported");
}
bool isGQAEnabled;
if (!matchPattern(enableGQA, m_TorchConstantBool(&isGQAEnabled)) ||
isGQAEnabled)
return rewriter.notifyMatchFailure(
loc, "grouped query attention not supported");

if (queryTy.getRank() != valueTy.getRank() ||
queryTy.getRank() != keyTy.getRank())
@@ -1808,6 +1910,22 @@ class ConvertAtenScaledDotProductAttentionOp
if (queryTy.getRank() < 3)
return rewriter.notifyMatchFailure(op, "missing batch dimension");

bool isGQAEnabled;
if (!matchPattern(enableGQA, m_TorchConstantBool(&isGQAEnabled)))
return rewriter.notifyMatchFailure(
loc, "expected enable_gqa flag to be a constant bool");

// When the `enable_gqa` flag is true, the key and value tensors have to
// be pre-processed by repeating their elements along the head dimension
// so that all three operands carry the same number of heads.
// The reference code is available here:
// https://github.com/pytorch/pytorch/pull/132689/files#diff-e726853e9795dfb6c74ab1e10945f5d5f24540eb7bc633e5c76f69bc258f24d6R612
if (isGQAEnabled) {
if (failed(preProcessGroupQueryAttentionInput(
op, rewriter, getTypeConverter(), query, key, value)))
return failure();
}

llvm::SmallVector<ReassociationIndices, 3> reassociation(3);
for (int i = 0, s = valueTy.getRank() - 2; i < s; ++i)
reassociation.front().push_back(i);
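For intuition, the unsqueeze → broadcast → collapse sequence emitted by `repeatTensorElementsForDim` computes the same thing as `torch.repeat_interleave` along a single dimension. A minimal eager-mode sketch of that equivalence (shapes borrowed from the e2e test added below; illustrative only, not code from this patch):

import torch

key = torch.randn(4, 8, 3, 8)  # (batch, kv_heads, seq_len, head_dim)
repeats = 4                    # query heads (32) // kv heads (8)

# unsqueeze at dim + 1, broadcast the new axis, then collapse it away,
# mirroring the AtenBroadcastToOp + PrimsCollapseOp pair in the lowering
expanded = key.unsqueeze(2).expand(-1, -1, repeats, -1, -1)
collapsed = expanded.reshape(4, 8 * repeats, 3, 8)

# identical to repeating each head's elements consecutively along dim 1
assert torch.equal(collapsed, key.repeat_interleave(repeats, dim=1))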
3 changes: 3 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -925,6 +925,7 @@
"BernoulliFloatModule_basic",
"UniformModule_basic",
"UniformStaticShapeModule_basic",
"ScaledDotProductAttentionGQAModule_basic",
}

FX_IMPORTER_STABLEHLO_CRASHING_SET = {
@@ -3352,6 +3353,7 @@
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"ScaledDotProductAttentionGQAModule_basic",
}

if torch_version_for_comparison() < version.parse("2.3.0.dev"):
@@ -3893,6 +3895,7 @@
"ScaledDotProductAttentionSameCausalModule_basic",
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScaledDotProductAttentionSameModule_basic",
"ScaledDotProductAttentionGQAModule_basic",
}

ONNX_TOSA_CRASHING_SET = {
27 changes: 27 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -5742,6 +5742,33 @@ def ScaledDotProductAttentionBoolMaskModule_basic(module, tu: TestUtils):
module.forward(query, key, value, mask)


class ScaledDotProductAttentionGQAModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([4, 32, 3, 8], torch.float32, True),
([4, 8, 3, 8], torch.float32, True),
([4, 8, 3, 8], torch.float32, True),
]
)
def forward(self, query, key, value):
return torch.ops.aten.scaled_dot_product_attention(
query, key, value, enable_gqa=True
)


@register_test_case(module_factory=lambda: ScaledDotProductAttentionGQAModule())
def ScaledDotProductAttentionGQAModule_basic(module, tu: TestUtils):
query = torch.randn(4, 32, 3, 8, dtype=torch.float32)
key = torch.randn(4, 8, 3, 8, dtype=torch.float32)
value = torch.randn(4, 8, 3, 8, dtype=torch.float32)
module.forward(query, key, value)


# ==============================================================================


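As an eager-mode cross-check of the semantics the new test exercises, `enable_gqa=True` should agree with manually repeating the key/value heads before a plain SDPA call. A hedged sketch, assuming torch >= 2.5 (the first release with the `enable_gqa` keyword):

import torch
import torch.nn.functional as F

q = torch.randn(4, 32, 3, 8)
k = torch.randn(4, 8, 3, 8)
v = torch.randn(4, 8, 3, 8)

out_gqa = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)

# manual expansion: each kv head is repeated q_heads // kv_heads times
rep = q.shape[1] // k.shape[1]
out_manual = F.scaled_dot_product_attention(
    q, k.repeat_interleave(rep, dim=1), v.repeat_interleave(rep, dim=1)
)

torch.testing.assert_close(out_gqa, out_manual)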