diff --git a/lib/polygeist/Ops.cpp b/lib/polygeist/Ops.cpp index cb486026112b..0f1104f237ba 100644 --- a/lib/polygeist/Ops.cpp +++ b/lib/polygeist/Ops.cpp @@ -5944,7 +5944,26 @@ OpFoldResult mlir::polygeist::SubmapOp::fold(mlir::polygeist::SubmapOp::FoldAdap return nullptr; } + +class DimSubMap final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::DimOp op, + PatternRewriter &rewriter) const override { + auto subMapOp = op.getSource().getDefiningOp(); + if (!subMapOp) return failure(); + + auto idx = op.getIndex().getDefiningOp(); + if (!idx) return failure(); + + rewriter.replaceOp(op, subMapOp.getSizes()[idx.value()]); + + return success(); + } +}; + void polygeist::SubmapOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.insert(context); + results.insert(context); } diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index 03bda7dbba02..c0bd0fe7feef 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -120,7 +120,7 @@ AffineMap shiftDimsDown1(AffineMap expr, unsigned numDim, unsigned offset) { // check_reduction is set true, when passed from store/linalg.generic's output variable. // And it is returned true, only if index was not encountered in oldmap operands and check_reduction was set true. Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, - Value memref_val, Value index, Value bound, int firstNDims, ValueRange oldmap_operands, bool &check_reduction) { + Value memref_val, Value index, Value bound, int firstNDims, ValueRange oldmap_operands, Value origmemref, bool &check_reduction) { assert(oldmap_operands.size() == oldmap.getNumSymbols() + oldmap.getNumDims()); //Operands which don't correspond to indices SmallVector operands_without_indices; @@ -193,7 +193,11 @@ Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, SmallVector idx_sizes; for (size_t i=0; i(memref_val.getLoc(), memref_val, i)); + if (auto submap = origmemref.getDefiningOp()) + idx_sizes.push_back(submap.getSizes()[i]); + else + llvm_unreachable("Won't reach this case"); + //idx_sizes.push_back(builder.create(origmemref.getLoc(), origmemref, i)); } idx_sizes.push_back(bound); @@ -621,7 +625,7 @@ struct AffineForOpRaising : public OpRewritePattern { int idx = 0; // Iterate over input arguments - for (Value input : lg.getInputs()) { + for (const Value input : lg.getInputs()) { // Is this needed? if (conds.size() != 0) return failure(); @@ -673,7 +677,7 @@ struct AffineForOpRaising : public OpRewritePattern { check_reduction = false; auto newMemref = remap_in_affine_dim( legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, - firstNDims, ValueRange(lgOperands), check_reduction); + firstNDims, ValueRange(lgOperands), input, check_reduction); if (!legal) @@ -688,7 +692,7 @@ struct AffineForOpRaising : public OpRewritePattern { } // Iterate over output arguments - for (Value output : lg.getOutputs()) { + for (const Value output : lg.getOutputs()) { // Is this needed? if (conds.size() != 0) return failure(); @@ -712,7 +716,7 @@ struct AffineForOpRaising : public OpRewritePattern { size_t firstNDims = lgMap.getNumDims(); check_reduction = true; auto newMemref = remap_in_affine_dim( - legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, firstNDims, ValueRange(lgOperands), check_reduction); + legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, firstNDims, ValueRange(lgOperands), output, check_reduction); if (!legal) return failure(); @@ -741,7 +745,7 @@ struct AffineForOpRaising : public OpRewritePattern { auto newMemref = remap_in_affine_dim( legal, rewriter, load.getAffineMap(), load.getMemref(), loop.getInductionVar(), loopSize, firstNDims, - load.getMapOperands(), check_reduction); + load.getMapOperands(), load.getMemref(), check_reduction); if (!legal) return failure(); @@ -769,7 +773,7 @@ struct AffineForOpRaising : public OpRewritePattern { auto newMemref = remap_in_affine_dim( legal, rewriter, store.getAffineMap(), store.getMemref(), loop.getInductionVar(), loopSize, firstNDims, - store.getMapOperands(), check_reduction); + store.getMapOperands(), store.getMemref(), check_reduction); if (!legal) return failure();