Skip to content

Commit

Permalink
Add OutlineTensorRTOpPass
Browse files Browse the repository at this point in the history
  • Loading branch information
yizhuoz004 committed Jan 21, 2025
1 parent 7965ac4 commit 250f6f4
Show file tree
Hide file tree
Showing 7 changed files with 315 additions and 30 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# No extra arguments are handed to mlir-tblgen; the variable expands to nothing
# in the mlir_tablegen() call below.
set(_TABLEGEN_ARGS )
# Passes.td is the TableGen input consumed by the mlir_tablegen() call below.
set(LLVM_TARGET_DEFINITIONS Passes.td)
# Emit Passes.h.inc containing the pass declarations, grouped under the name
# "TensorRTToExecutable".
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TensorRTToExecutable ${_TABLEGEN_ARGS})
# Expose the generation step as a target that libraries can list under DEPENDS.
add_public_tablegen_target(MLIRTensorRTTensorRTToExecutableIncGen)
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,18 @@

namespace mlirtrt::compiler {

// TODO: Does this also need Tablegen'd pass?
//===----------------------------------------------------------------------===//
// Add Tablegen'd pass declarations and registration methods.
//===----------------------------------------------------------------------===//
#define GEN_PASS_DECL
#define GEN_PASS_REGISTRATION
#include "mlir-tensorrt/Compiler/TensorRTToExecutable/Passes.h.inc"

//===----------------------------------------------------------------------===//
// Pipeline Registrations
//===----------------------------------------------------------------------===//

/// Register the TensorRT clustering and compilation pipelines.
// TODO (pranavm): How to do pipeline registration?
void registerTensorRTToExecutablePipelines();

} // namespace mlirtrt::compiler
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//===- Passes.td ----------------------------------------------------------===//
//
// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES.
// All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TENSORRT_COMPILER_TENSORRTTOEXECUTABLE_PASSES
#define MLIR_TENSORRT_COMPILER_TENSORRTTOEXECUTABLE_PASSES

include "mlir/Pass/PassBase.td"

//===----------------------------------------------------------------------===//
// OutlineTensorRTOpPass
//===----------------------------------------------------------------------===//
// TODO: decide which options (if any) this pass should expose.

def OutlineTensorRTOpPass : Pass<"outline-tensorrt-op",
                                 "::mlir::ModuleOp"> {
  let summary = "Outline all tensorrt ops into a tensorrt module";

  let description = [{
    Moves the body of each cluster region found in the module into a function
    nested in a TensorRT module and replaces the cluster with a call to that
    function.
  }];

  // The pass materializes `func` dialect operations (the outlined function
  // and its return) and TensorRT dialect operations (the TensorRT module and
  // the call), so both dialects must be loaded before the pass runs.
  let dependentDialects = [
    "::mlir::func::FuncDialect",
    "::mlir::tensorrt::TensorRTDialect"
  ];
}

#endif // MLIR_TENSORRT_COMPILER_TENSORRTTOEXECUTABLE_PASSES
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,16 @@
// won't need it.
#ifdef MLIR_TRT_TARGET_TENSORRT
#include "mlir-tensorrt-dialect/Target/TranslateToTensorRT.h"

#include "mlir-executor/Runtime/API/API.h"
#include "mlir-executor/Support/Status.h"
#include "mlir-tensorrt-dialect/Utils/Options.h"
#include "mlir-tensorrt-dialect/Utils/OptionsBundle.h"
#include "mlir-tensorrt/Compiler/Client.h"
#include "mlir-tensorrt/Compiler/Extension.h"
#include "mlir-tensorrt/Compiler/OptionsProviders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/TypeID.h"

namespace mlirtrt::compiler {
Expand All @@ -38,12 +43,17 @@ namespace mlirtrt::compiler {
// TensorRTToExecutableOptions
//===----------------------------------------------------------------------===//

class TensorRTToExecutableTask;

// TODO (pranavm): Figure out a better way to reuse TRT translation options -
// maybe move to options providers?
struct TensorRTOptions
: public mlirtrt::compiler::OptionsProvider<TensorRTOptions> {
struct TensorRTOptions : public OptionsProvider<TensorRTOptions> {
public:
using OptionsProvider::OptionsProvider;
mlir::tensorrt::TensorRTTranslationOptions options;

TensorRTOptions(mlir::OptionsContext &ctx) : OptionsProvider(ctx) {}

void addToOptions(mlir::OptionsContext &context) {
options.addToOptions(context);
}
Expand All @@ -52,12 +62,10 @@ struct TensorRTOptions
struct TensorRTToExecutableOptions
: public mlir::OptionsBundle<DeviceOptions, DebugOptions, ExecutorOptions,
TensorRTOptions> {
// Default initialization does not require any extensions.
TensorRTToExecutableOptions() = default;

TensorRTToExecutableOptions(TaskExtensionRegistry extensions);

/// Initializes the options using a default extension set (TensorRT
/// extension).
StablehloToExecutableOptions();

Option<std::string> entrypoint{this, "entrypoint", llvm::cl::init("main"),
llvm::cl::desc("entrypoint function name")};
Expand All @@ -71,6 +79,8 @@ class TensorRTToExecutableTask
: public CompilationTask<TensorRTToExecutableTask,
TensorRTToExecutableOptions> {
public:
TensorRTToExecutableTask(mlir::MLIRContext *ctx,
const TensorRTToExecutableOptions &options);

/// Build the clustering pipeline that occurs on TensorRT Ops.
static void
Expand All @@ -84,13 +94,6 @@ class TensorRTToExecutableTask

static void populatePassManager(mlir::PassManager &pm,
const TensorRTToExecutableOptions &options);

/// Compile a TensorRT module into a MLIR-TensorRT Runtime executable.
/// This is the "functional" entrypoint that will allocate a new PassManager
/// for a single run.
// static mlirtrt::StatusOr<std::unique_ptr<runtime::Executable>>
// compileTensorRTToExecutable(CompilerClient &client, mlir::ModuleOp module,
// const TensorRTToExecutableOptions &options);
};

/// Register the task/options with the client's registry.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
add_mlir_tensorrt_library(MLIRTensorRTCompilerTensorRTToExecutable
TensorRTToExecutable.cpp
TensorRTExtension.cpp
Passes.cpp

PARTIAL_SOURCES_INTENDED

DEPENDS
MLIRTensorRTStablehloToExecutableIncGen
MLIRTensorRTTensorRTToExecutableIncGen

LINK_LIBS PUBLIC
MLIRIR
Expand Down
189 changes: 189 additions & 0 deletions mlir-tensorrt/compiler/lib/Compiler/TensorRTToExecutable/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,203 @@
#include "mlir-tensorrt/Compiler/TensorRTToExecutable/TensorRTToExecutable.h"
#include "mlir-tensorrt/Conversion/Passes.h"
#include "mlir-tensorrt/Dialect/Plan/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassOptions.h"

#ifdef MLIR_TRT_ENABLE_HLO

namespace mlirtrt::compiler {
#define GEN_PASS_DEF_OUTLINETENSORRTOPPASS
#include "mlir-tensorrt/Compiler/TensorRTToExecutable/Passes.h.inc"
} // namespace mlirtrt::compiler

using namespace mlirtrt;
using namespace mlirtrt::compiler;
using namespace mlir;

namespace {

//===----------------------------------------------------------------------===//
// OutlineTensorRTOpPass
//===----------------------------------------------------------------------===//

/// Build the `ClusteringOpts` that identify groups of TensorRT operations to be
/// clustered into one TensorRT function (which is eventually translated to an
/// engine).
///
/// The returned options:
///  - always allow independent clusters to be merged,
///  - use a null `clusterTarget` attribute, and
///  - consider an operation clusterable iff it belongs to the TensorRT
///    dialect.
///
/// `op` is currently unused. The result is wrapped in `FailureOr`, but this
/// implementation never returns failure.
static FailureOr<ClusteringOpts>
getTensorRTClusteringOptions(Operation *op) {
  ClusteringOpts opts;
  opts.mergeIndependentClusters = [](Operation *, ClusterRange, Operation *,
                                     ClusterRange) { return true; };
  opts.clusterTarget = Attribute{};
  opts.isClusterableOp = [](Operation *candidate) {
    // `getDialect()` yields null when the operation's dialect is not loaded
    // (e.g. unregistered operations), and plain `isa<>` asserts on null.
    return llvm::isa_and_present<tensorrt::TensorRTDialect>(
        candidate->getDialect());
  };

  return opts;
}

/// Build a `func.func` with the signature `(funcArgTypes) -> funcResultTypes`
/// and insert it into the symbol table of `module`, at the end of the first
/// block of `module`'s first region. The function is requested under the name
/// `nameBase`; numbers may be appended on insertion to make the name unique.
///
/// The new function gets an entry block containing only an operand-less
/// `func.return`. The insertion point of `rewriter` is restored before
/// returning. This implementation always succeeds.
static FailureOr<FunctionOpInterface>
createOutlinedFunc(RewriterBase &rewriter, Location loc, Operation *module,
                   StringRef nameBase, TypeRange funcArgTypes,
                   TypeRange funcResultTypes) {
  // Put the caller's insertion point back once we are done.
  OpBuilder::InsertionGuard restoreInsertionPoint(rewriter);

  // Materialize a detached function with the requested signature.
  auto signature =
      FunctionType::get(rewriter.getContext(), funcArgTypes, funcResultTypes);
  func::FuncOp newFunc = func::FuncOp::create(loc, nameBase, signature, {});

  // Give it an entry block that holds just an empty terminator.
  Block *entryBlock = newFunc.addEntryBlock();
  rewriter.setInsertionPointToEnd(entryBlock);
  rewriter.create<func::ReturnOp>(loc);

  // Register the function in the symbol table of `module`.
  Block &moduleBody = module->getRegions().front().front();
  SymbolTable symbols(module);
  symbols.insert(newFunc, moduleBody.end());

  return cast<FunctionOpInterface>(newFunc.getOperation());
}

/// Return the `tensorrt.module` associated with the module that encloses `op`.
///
/// The enclosing module is `op` itself when `op` is a `ModuleOp`, otherwise the
/// closest `ModuleOp` ancestor of `op`. If that module already holds a
/// `tensorrt.module` directly in its body, the first one is returned.
/// Otherwise a new `tensorrt.module` named "trt_engines" is created and
/// inserted through the module's symbol table: right before `op` when `op` is
/// an immediate child of the module, and at the position `SymbolTable::insert`
/// picks for a default-constructed iterator in all other cases.
///
/// Returns a null op when no enclosing `ModuleOp` exists.
static tensorrt::TensorRTModuleOp getOrCreateTensorRTModuleOp(Operation *op) {
  // `getParentOfType` only inspects ancestors, so it would miss the module
  // when the caller passes the `ModuleOp` itself (as the outlining pass does).
  auto moduleOp = dyn_cast<ModuleOp>(op);
  if (!moduleOp)
    moduleOp = op->getParentOfType<ModuleOp>();
  if (!moduleOp)
    return nullptr;

  // Reuse the first `tensorrt.module` that already lives in the module body.
  auto existingModules =
      moduleOp.getBody()->getOps<tensorrt::TensorRTModuleOp>();
  if (!existingModules.empty())
    return *existingModules.begin();

  // Create the TensorRT module. Symbol name de-duplication occurs with insert
  // into the symbol table, which is only built once we know it is needed.
  tensorrt::TensorRTModuleOp result =
      tensorrt::TensorRTModuleOp::create(moduleOp.getLoc(), "trt_engines");
  SymbolTable symbolTable(moduleOp);
  symbolTable.insert(result, op->getParentOp() == moduleOp
                                 ? Block::iterator(op)
                                 : Block::iterator{});
  return result;
}

/// Helper function to call the `makeRegionIsolatedFromAbove` to capture all
/// required arguments into the InlineGroupOp region.
// static LogicalResult
// makeIsolatedFromAboveImpl(RewriterBase &rewriter, plan::InlineGroupOp regionOp,
// llvm::function_ref<bool(Operation *)> callBack) {
// Region &region = regionOp.getRegion();
// SmallVector<Value> capturedValues =
// makeRegionIsolatedFromAbove(rewriter, region, callBack);
// SmallVector<Value> operands = regionOp.getOperands();
// operands.append(capturedValues);
// auto isolatedRegionOp =
// rewriter.create<test::IsolatedOneRegionOp>(regionOp.getLoc(), operands);
// rewriter.inlineRegionBefore(region, isolatedRegionOp.getRegion(),
// isolatedRegionOp.getRegion().begin());
// rewriter.eraseOp(regionOp);
// return success();
// }

/// Outline the region of the `plan::InlineGroupOp` `op` into a new function
/// created inside `trtModule` (requested name: "tensorrt_cluster") and replace
/// `op` with a `tensorrt::CallAllocOp` that calls the new function.
///
/// The values captured while isolating the region become the call operands,
/// and the region's yield is rewritten into a `func.return` with the same
/// operands. Returns the created call, or failure if the function could not be
/// created. The insertion point of `rewriter` is not restored.
static FailureOr<tensorrt::CallAllocOp>
outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
          plan::InlineGroupOp op) {
  // Isolate the region from above; the captured values are the call inputs.
  SmallVector<Value> capturedValues =
      makeRegionIsolatedFromAbove(rewriter, op.getRegion());

  // Create the function that will receive the region body.
  FailureOr<FunctionOpInterface> outlinedFunc = createOutlinedFunc(
      rewriter, op.getLoc(), trtModule, "tensorrt_cluster",
      TypeRange(capturedValues), op->getResultTypes());
  if (failed(outlinedFunc))
    return failure();

  // Emit the call right before the group op; the callee is referenced as a
  // symbol nested under the TensorRT module.
  rewriter.setInsertionPoint(op);
  auto call = rewriter.create<tensorrt::CallAllocOp>(
      op.getLoc(), op.getResultTypes(), capturedValues,
      SymbolRefAttr::get(trtModule.getNameAttr(),
                         {FlatSymbolRefAttr::get(*outlinedFunc)}));

  // Drop the entry block the function was created with; the blocks of the
  // region take its place.
  Region &funcBody = outlinedFunc->getFunctionBody();
  rewriter.eraseBlock(&funcBody.front());

  // Remember the terminator, then move the region into the function body.
  Operation *yield = op.getYield();
  rewriter.inlineRegionBefore(op.getRegion(), funcBody, funcBody.end());

  // Turn the yield into the function's return.
  rewriter.setInsertionPoint(yield);
  rewriter.replaceOpWithNewOp<func::ReturnOp>(yield, yield->getOperands());

  // The call results stand in for the results of the original group op.
  rewriter.replaceOp(op, call);

  return call;
}


/// Module pass that outlines every `plan::InlineGroupOp` nested in the module
/// into a function inside a `tensorrt::TensorRTModuleOp` and replaces each
/// group with a call to its outlined function.
///
/// The pass fails if the clustering options cannot be built, if clusters exist
/// but no TensorRT module could be obtained, or if outlining a cluster fails.
class OutlineTensorRTOpPass
    : public compiler::impl::OutlineTensorRTOpPassBase<
          OutlineTensorRTOpPass> {
public:
  using Base::Base;

  void runOnOperation() override {
    ModuleOp module = getOperation();
    IRRewriter rewriter(&getContext());

    // what are these? are they needed?
    // (re-enabling this also requires a `SymbolTableCollection symbolTable`)
    // DataFlowSolver solver;
    // solver.load<dataflow::DeadCodeAnalysis>();
    // solver.load<dataflow::SparseConstantPropagation>();
    // solver.load<TensorKindAnalysis>(symbolTable);
    // if (failed(solver.initializeAndRun(module)))
    //   return signalPassFailure();

    FailureOr<ClusteringOpts> opts = getTensorRTClusteringOptions(module);
    if (failed(opts)) {
      emitError(module.getLoc()) << "failed to create clustering options";
      return signalPassFailure();
    }
    // What do they do here?
    // patterns.add(*opts, createInlineGroupOp, isOpInClusterRegion,
    //              target.getClusterFilter(),
    //              PatternBenefit(target.getClusterBenefit()));

    // FailureOr<SmallVector<Operation *>> regionOps =
    //     rewrite->findClusterAndCreateRegionOp(module, rewriter);
    // if (failed(regionOps)) {
    //   emitError(module.getLoc())
    //       << "clustering rewrite " << rewrite->getTarget() << " failed ";
    //   return signalPassFailure();
    // }

    tensorrt::TensorRTModuleOp trtModuleOp =
        getOrCreateTensorRTModuleOp(module);

    // Collect the clusters first: outlining erases the group ops, which must
    // not happen while the walk is still visiting them.
    SmallVector<plan::InlineGroupOp> clusters;
    module.walk(
        [&](plan::InlineGroupOp cluster) { clusters.push_back(cluster); });
    if (clusters.empty())
      return;

    // Outlining dereferences the TensorRT module, so report a missing module
    // as a pass failure instead of crashing on a null op.
    if (!trtModuleOp) {
      emitError(module.getLoc())
          << "failed to find or create a tensorrt module for outlining";
      return signalPassFailure();
    }

    for (plan::InlineGroupOp cluster : clusters) {
      if (failed(outlineOp(rewriter, trtModuleOp, cluster)))
        return signalPassFailure();
    }
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Pipeline Registrations
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit 250f6f4

Please sign in to comment.