Skip to content

Commit

Permalink
Fix for incorrect for loop dims
Browse files Browse the repository at this point in the history
  • Loading branch information
arpitj1 committed Aug 28, 2024
1 parent 77c8168 commit 98f0119
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 9 deletions.
21 changes: 20 additions & 1 deletion lib/polygeist/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5944,7 +5944,26 @@ OpFoldResult mlir::polygeist::SubmapOp::fold(mlir::polygeist::SubmapOp::FoldAdap
return nullptr;
}


class DimSubMap final : public OpRewritePattern<memref::DimOp> {
public:
using OpRewritePattern<memref::DimOp>::OpRewritePattern;

LogicalResult matchAndRewrite(memref::DimOp op,
PatternRewriter &rewriter) const override {
auto subMapOp = op.getSource().getDefiningOp<polygeist::SubmapOp>();
if (!subMapOp) return failure();

auto idx = op.getIndex().getDefiningOp<arith::ConstantIndexOp>();
if (!idx) return failure();

rewriter.replaceOp(op, subMapOp.getSizes()[idx.value()]);

return success();
}
};

void polygeist::SubmapOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<LoadSubMap, StoreSubMap>(context);
results.insert<LoadSubMap, StoreSubMap, DimSubMap>(context);
}
20 changes: 12 additions & 8 deletions lib/polygeist/Passes/RaiseToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ AffineMap shiftDimsDown1(AffineMap expr, unsigned numDim, unsigned offset) {
// check_reduction is set true, when passed from store/linalg.generic's output variable.
// And it is returned true, only if index was not encountered in oldmap operands and check_reduction was set true.
Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap,
Value memref_val, Value index, Value bound, int firstNDims, ValueRange oldmap_operands, bool &check_reduction) {
Value memref_val, Value index, Value bound, int firstNDims, ValueRange oldmap_operands, Value origmemref, bool &check_reduction) {
assert(oldmap_operands.size() == oldmap.getNumSymbols() + oldmap.getNumDims());
//Operands which don't correspond to indices
SmallVector<Value> operands_without_indices;
Expand Down Expand Up @@ -193,7 +193,11 @@ Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap,
SmallVector<Value> idx_sizes;
for (size_t i=0; i<firstNDims; i++) {
//memref.dimOp captures the size of the memref
idx_sizes.push_back(builder.create<memref::DimOp>(memref_val.getLoc(), memref_val, i));
if (auto submap = origmemref.getDefiningOp<polygeist::SubmapOp>())
idx_sizes.push_back(submap.getSizes()[i]);
else
llvm_unreachable("Won't reach this case");
//idx_sizes.push_back(builder.create<memref::DimOp>(origmemref.getLoc(), origmemref, i));
}
idx_sizes.push_back(bound);

Expand Down Expand Up @@ -621,7 +625,7 @@ struct AffineForOpRaising : public OpRewritePattern<affine::AffineForOp> {

int idx = 0;
// Iterate over input arguments
for (Value input : lg.getInputs()) {
for (const Value input : lg.getInputs()) {
// Is this needed?
if (conds.size() != 0)
return failure();
Expand Down Expand Up @@ -673,7 +677,7 @@ struct AffineForOpRaising : public OpRewritePattern<affine::AffineForOp> {
check_reduction = false;
auto newMemref = remap_in_affine_dim(
legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize,
firstNDims, ValueRange(lgOperands), check_reduction);
firstNDims, ValueRange(lgOperands), input, check_reduction);


if (!legal)
Expand All @@ -688,7 +692,7 @@ struct AffineForOpRaising : public OpRewritePattern<affine::AffineForOp> {
}

// Iterate over output arguments
for (Value output : lg.getOutputs()) {
for (const Value output : lg.getOutputs()) {
// Is this needed?
if (conds.size() != 0)
return failure();
Expand All @@ -712,7 +716,7 @@ struct AffineForOpRaising : public OpRewritePattern<affine::AffineForOp> {
size_t firstNDims = lgMap.getNumDims();
check_reduction = true;
auto newMemref = remap_in_affine_dim(
legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, firstNDims, ValueRange(lgOperands), check_reduction);
legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, firstNDims, ValueRange(lgOperands), output, check_reduction);
if (!legal)
return failure();

Expand Down Expand Up @@ -741,7 +745,7 @@ struct AffineForOpRaising : public OpRewritePattern<affine::AffineForOp> {
auto newMemref = remap_in_affine_dim(
legal, rewriter, load.getAffineMap(), load.getMemref(),
loop.getInductionVar(), loopSize, firstNDims,
load.getMapOperands(), check_reduction);
load.getMapOperands(), load.getMemref(), check_reduction);

if (!legal)
return failure();
Expand Down Expand Up @@ -769,7 +773,7 @@ struct AffineForOpRaising : public OpRewritePattern<affine::AffineForOp> {
auto newMemref = remap_in_affine_dim(
legal, rewriter, store.getAffineMap(), store.getMemref(),
loop.getInductionVar(), loopSize, firstNDims,
store.getMapOperands(), check_reduction);
store.getMapOperands(), store.getMemref(), check_reduction);

if (!legal)
return failure();
Expand Down

0 comments on commit 98f0119

Please sign in to comment.