diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/CMakeLists.txt b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/CMakeLists.txt
index 48758fad0..70216a43a 100644
--- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/CMakeLists.txt
+++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/CMakeLists.txt
@@ -1 +1,2 @@
 add_subdirectory(StablehloToExecutable)
+add_subdirectory(TensorRTToExecutable)
diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/TensorRTToExecutable/CMakeLists.txt b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/TensorRTToExecutable/CMakeLists.txt
new file mode 100644
index 000000000..e549a6d5c
--- /dev/null
+++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/TensorRTToExecutable/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(_TABLEGEN_ARGS )
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name TensorRTToExecutable ${_TABLEGEN_ARGS})
+add_public_tablegen_target(MLIRTensorRTTensorRTToExecutableIncGen)
diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/TensorRTToExecutable/Passes.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/TensorRTToExecutable/Passes.h
new file mode 100644
index 000000000..53d6eb705
--- /dev/null
+++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/TensorRTToExecutable/Passes.h
@@ -0,0 +1,49 @@
+//===- Passes.h ----------------------------------------------===//
+//
+// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES.
+// All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+///
+/// Declarations for opt tool pipeline command-line registration for pipelines
+/// related to "tensorrt-to-executable".
+///
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_TENSORRT_COMPILER_TENSORRTTOEXECUTABLE_PASSES
+#define MLIR_TENSORRT_COMPILER_TENSORRTTOEXECUTABLE_PASSES
+
+#include <memory>
+#include <mlir/Pass/Pass.h>
+
+namespace mlirtrt::compiler {
+
+//===----------------------------------------------------------------------===//
+// Add Tablegen'd pass declarations and registration methods.
+//===----------------------------------------------------------------------===//
+#define GEN_PASS_DECL
+#define GEN_PASS_REGISTRATION
+#include "mlir-tensorrt/Compiler/TensorRTToExecutable/Passes.h.inc"
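+// Note: because the mlir_tablegen invocation in CMakeLists.txt passes
+// `-name TensorRTToExecutable`, the GEN_PASS_REGISTRATION block above
+// provides `registerTensorRTToExecutablePasses()`, which registers every
+// pass declared in Passes.td (currently just `outline-tensorrt-op`).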
+
+//===----------------------------------------------------------------------===//
+// Pipeline Registrations
+//===----------------------------------------------------------------------===//
+
+/// Register the TensorRT clustering and compilation pipelines.
+void registerTensorRTToExecutablePipelines();
+
+} // namespace mlirtrt::compiler
+
+#endif // MLIR_TENSORRT_COMPILER_TENSORRTTOEXECUTABLE_PASSES
diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/TensorRTToExecutable/Passes.td b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/TensorRTToExecutable/Passes.td
new file mode 100644
index 000000000..a49940af0
--- /dev/null
+++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/TensorRTToExecutable/Passes.td
@@ -0,0 +1,35 @@
+//===- Passes.td ----------------------------------------------------------===//
+//
+// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES.
+// All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_TENSORRT_COMPILER_TENSORRTTOEXECUTABLE_PASSES
+#define MLIR_TENSORRT_COMPILER_TENSORRTTOEXECUTABLE_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+//===----------------------------------------------------------------------===//
+// OutlineTensorRTOpPass
+//===----------------------------------------------------------------------===//
+// TODO: Determine the dependent dialects (the pass creates `func` and
+// `tensorrt` ops) and whether any pass options are needed.
+
+def OutlineTensorRTOpPass : Pass<"outline-tensorrt-op",
+    "::mlir::ModuleOp"> {
+  let summary = "Outline all TensorRT ops into a `tensorrt.module`";
+}
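+
+// Illustrative before/after (a sketch; the IR syntax is abbreviated). A
+// clustered region such as
+//
+//   %r = plan.inline_group target(...) -> tensor<4xf32> {
+//     %0 = tensorrt.element_wise ... : tensor<4xf32>
+//     yield %0 : tensor<4xf32>
+//   }
+//
+// is outlined into a `func.func` nested inside `tensorrt.module @trt_engines`
+// and replaced by a call to that function.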
+
+#endif // MLIR_TENSORRT_COMPILER_TENSORRTTOEXECUTABLE_PASSES
diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/TensorRTToExecutable/TensorRTToExecutable.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/TensorRTToExecutable/TensorRTToExecutable.h
new file mode 100644
index 000000000..c0d204940
--- /dev/null
+++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/TensorRTToExecutable/TensorRTToExecutable.h
@@ -0,0 +1,107 @@
+//===- TensorRTToExecutable.h -----------------------------------*- C++ -*-===//
+//
+// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES.
+// All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_TENSORRT_COMPILER_TENSORRTTOEXECUTABLE
+#define MLIR_TENSORRT_COMPILER_TENSORRTTOEXECUTABLE
+
+// TODO (pranavm): MLIR_TRT_TARGET_TENSORRT is only needed because we pull in
+// the TranslateToTensorRT.h header. If we move the translation options, we
+// won't need it.
+#ifdef MLIR_TRT_TARGET_TENSORRT
+#include "mlir-tensorrt-dialect/Target/TranslateToTensorRT.h"
+
+#include "mlir-executor/Runtime/API/API.h"
+#include "mlir-executor/Support/Status.h"
+#include "mlir-tensorrt-dialect/Utils/Options.h"
+#include "mlir-tensorrt-dialect/Utils/OptionsBundle.h"
+#include "mlir-tensorrt/Compiler/Client.h"
+#include "mlir-tensorrt/Compiler/Extension.h"
+#include "mlir-tensorrt/Compiler/OptionsProviders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/TypeID.h"
+
+namespace mlirtrt::compiler {
+
+//===----------------------------------------------------------------------===//
+// TensorRTToExecutableOptions
+//===----------------------------------------------------------------------===//
+
+class TensorRTToExecutableTask;
+
+// TODO (pranavm): Figure out a better way to reuse TRT translation options -
+// maybe move to options providers?
+struct TensorRTOptions : public OptionsProvider<TensorRTOptions> {
+public:
+  using OptionsProvider::OptionsProvider;
+  mlir::tensorrt::TensorRTTranslationOptions options;
+
+  TensorRTOptions(mlir::OptionsContext &ctx) : OptionsProvider(ctx) {}
+
+  void addToOptions(mlir::OptionsContext &context) {
+    options.addToOptions(context);
+  }
+};
+
+struct TensorRTToExecutableOptions
+    : public mlir::OptionsBundle<DebugOptions, ExecutorOptions,
+                                 TensorRTOptions> {
+  // Default initialization does not require any extensions.
+  TensorRTToExecutableOptions() = default;
+
+  TensorRTToExecutableOptions(TaskExtensionRegistry extensions);
+
+  Option<std::string> entrypoint{this, "entrypoint", llvm::cl::init("main"),
+                                 llvm::cl::desc("entrypoint function name")};
+};
+
+//===----------------------------------------------------------------------===//
+// TensorRTToExecutableTask
+//===----------------------------------------------------------------------===//
+
+class TensorRTToExecutableTask
+    : public CompilationTask<TensorRTToExecutableTask,
+                             TensorRTToExecutableOptions> {
+public:
+  TensorRTToExecutableTask(mlir::MLIRContext *ctx,
+                           const TensorRTToExecutableOptions &options);
+
+  /// Build the clustering pipeline that runs on TensorRT ops.
+  static void
+  buildTensorRTClusteringPipeline(mlir::OpPassManager &pm,
+                                  const TensorRTToExecutableOptions &options);
+
+  /// Build the compilation pipeline that runs after clustering.
+  static void
+  buildPostClusteringPipeline(mlir::OpPassManager &pm,
+                              const TensorRTToExecutableOptions &options);
+
+  static void populatePassManager(mlir::PassManager &pm,
+                                  const TensorRTToExecutableOptions &options);
+};
+
+/// Register the task/options with the client's registry.
+void registerTensorRTToExecutableTask();
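+
+// Usage sketch (illustrative; the exact client API spelling may differ): once
+// registerTensorRTToExecutableTask() has run, a CompilerClient can look up
+// the task registered under the "tensorrt-to-executable" mnemonic with an
+// option string such as {"--entrypoint=main"} and run the resulting pass
+// manager on a ModuleOp.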
+
+} // namespace mlirtrt::compiler
+
+MLIR_DECLARE_EXPLICIT_TYPE_ID(mlirtrt::compiler::TensorRTToExecutableTask)
+
+#endif
+#endif // MLIR_TENSORRT_COMPILER_TENSORRTTOEXECUTABLE
diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Registration/RegisterMlirTensorRtPasses.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Registration/RegisterMlirTensorRtPasses.h
index 7ac779ec8..ac87e2a26 100644
--- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Registration/RegisterMlirTensorRtPasses.h
+++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Registration/RegisterMlirTensorRtPasses.h
@@ -23,6 +23,7 @@
 #define REGISTRATION_REGISTERMLIRTENSORRTPASSES_H
 
 #include "mlir-tensorrt-dialect/TensorRT/Transforms/Passes.h"
+#include "mlir-tensorrt/Compiler/TensorRTToExecutable/TensorRTToExecutable.h"
 #include "mlir-tensorrt/Conversion/Passes.h"
 #include "mlir-tensorrt/Transforms/Passes.h"
 #include "mlir/Conversion/Passes.h"
@@ -53,6 +54,12 @@ inline void registerAllMlirTensorRtPasses() {
   mlir::registerTransformsPasses();
   mlir::registerConvertPDLToPDLInterp();
 
+  // TODO (pranavm): Check if this needs to be conditional - the TRT passes
+  // above are not.
+#ifdef MLIR_TRT_TARGET_TENSORRT
+  mlirtrt::compiler::registerTensorRTToExecutableTask();
+#endif
+
 #ifdef MLIR_TRT_ENABLE_HLO
   mlirtrt::compiler::registerStablehloToExecutablePasses();
   mlirtrt::compiler::registerStablehloToExecutablePipelines();
diff --git a/mlir-tensorrt/compiler/lib/CAPI/Compiler/Registration/CMakeLists.txt b/mlir-tensorrt/compiler/lib/CAPI/Compiler/Registration/CMakeLists.txt
index 4462ce99c..644817c0b 100644
--- a/mlir-tensorrt/compiler/lib/CAPI/Compiler/Registration/CMakeLists.txt
+++ b/mlir-tensorrt/compiler/lib/CAPI/Compiler/Registration/CMakeLists.txt
@@ -6,4 +6,5 @@ add_mlir_tensorrt_public_c_api_library(MLIRTensorRTCAPIRegisterAllDialects
   MLIRTensorRTRegistration
   MLIRCAPIIR
   MLIRTensorRTCompilerStableHloToExecutable
+  MLIRTensorRTCompilerTensorRTToExecutable
 )
diff --git a/mlir-tensorrt/compiler/lib/Compiler/CMakeLists.txt b/mlir-tensorrt/compiler/lib/Compiler/CMakeLists.txt
index 68e186012..f4f98a46c 100644
--- a/mlir-tensorrt/compiler/lib/Compiler/CMakeLists.txt
+++ b/mlir-tensorrt/compiler/lib/Compiler/CMakeLists.txt
@@ -15,4 +15,5 @@ add_mlir_tensorrt_library(MLIRTensorRTCompilerClient
   MLIRTensorRTSupportDeviceInfo
 )
 
-add_subdirectory(StablehloToExecutable)
\ No newline at end of file
+add_subdirectory(StablehloToExecutable)
+add_subdirectory(TensorRTToExecutable)
diff --git a/mlir-tensorrt/compiler/lib/Compiler/TensorRTToExecutable/CMakeLists.txt b/mlir-tensorrt/compiler/lib/Compiler/TensorRTToExecutable/CMakeLists.txt
new file mode 100644
index 000000000..63ab3d933
--- /dev/null
+++ b/mlir-tensorrt/compiler/lib/Compiler/TensorRTToExecutable/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_tensorrt_library(MLIRTensorRTCompilerTensorRTToExecutable
+  TensorRTToExecutable.cpp
+  Passes.cpp
+
+  PARTIAL_SOURCES_INTENDED
+
+  DEPENDS
+  MLIRTensorRTTensorRTToExecutableIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRTensorRTRegistration
+  MLIRTensorRTTargetLua
+  MLIRTensorRTOptionUtils
+  MLIRTensorRTTargetTensorRT
+  MLIRTensorRTCompilerClient
+  )
diff --git a/mlir-tensorrt/compiler/lib/Compiler/TensorRTToExecutable/Passes.cpp b/mlir-tensorrt/compiler/lib/Compiler/TensorRTToExecutable/Passes.cpp
new file mode 100644
index 000000000..fd31cf3fe
--- /dev/null
+++ b/mlir-tensorrt/compiler/lib/Compiler/TensorRTToExecutable/Passes.cpp
@@ -0,0 +1,250 @@
+//===- Passes.cpp ---------------------------------------------------------===//
+//
+// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES.
+// All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir-tensorrt/Compiler/TensorRTToExecutable/Passes.h"
+#include "mlir-executor/Executor/Transforms/Passes.h"
+#include "mlir-tensorrt/Compiler/TensorRTToExecutable/TensorRTToExecutable.h"
+#include "mlir-tensorrt/Conversion/Passes.h"
+#include "mlir-tensorrt/Dialect/Plan/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Pass/PassOptions.h"
+
+#ifdef MLIR_TRT_ENABLE_HLO
+
+namespace mlirtrt::compiler {
+#define GEN_PASS_DEF_OUTLINETENSORRTOPPASS
+#include "mlir-tensorrt/Compiler/TensorRTToExecutable/Passes.h.inc"
+} // namespace mlirtrt::compiler
+
+using namespace mlirtrt;
+using namespace mlirtrt::compiler;
+using namespace mlir;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// OutlineTensorRTOpPass
+//===----------------------------------------------------------------------===//
+
+/// Build ClusteringOpts that identify groups of TensorRT operations to be
+/// clustered into a single TensorRT function (which is eventually translated
+/// into an engine).
+static FailureOr<ClusteringOpts>
+getTensorRTClusteringOptions(Operation *op) {
+  ClusteringOpts opts;
+  opts.mergeIndependentClusters = [](Operation *, ClusterRange, Operation *,
+                                     ClusterRange) { return true; };
+  opts.clusterTarget = Attribute{};
+  opts.isClusterableOp = [](Operation *op) {
+    if (llvm::isa<tensorrt::TensorRTDialect>(op->getDialect()))
+      return true;
+    return false;
+  };
+
+  return opts;
+}
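+
+// Note on the options above (explanatory, not normative): returning `true`
+// unconditionally from `mergeIndependentClusters` merges every pair of
+// independent TensorRT clusters, so a module whose TensorRT ops are not
+// separated by ops from other dialects will typically collapse into a single
+// cluster (and therefore a single engine).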
+
+/// Create a `func.func` operation that represents `regionOp` and insert it
+/// into the `module` SymbolTable. The function is given a name starting with
+/// `nameBase`, with numbers appended as needed to make the name unique. The
+/// created function has argument/result types as indicated by the parameters.
+static FailureOr<func::FuncOp>
+createOutlinedFunc(RewriterBase &rewriter, Location loc, Operation *module,
+                   StringRef nameBase, TypeRange funcArgTypes,
+                   TypeRange funcResultTypes) {
+  OpBuilder::InsertionGuard g(rewriter);
+
+  // Create the func for outlining the region body.
+  FunctionType type =
+      FunctionType::get(rewriter.getContext(), funcArgTypes, funcResultTypes);
+  auto outlinedFunc = func::FuncOp::create(loc, nameBase, type, {});
+  Block *funcBody = outlinedFunc.addEntryBlock();
+
+  // Add an empty terminator.
+  rewriter.setInsertionPointToEnd(funcBody);
+  rewriter.create<func::ReturnOp>(loc);
+
+  // Insert into the module.
+  SymbolTable(module).insert(outlinedFunc,
+                             module->getRegions().front().front().end());
+
+  return cast<func::FuncOp>(outlinedFunc.getOperation());
+}
+
+/// Given `op`, find the closest ModuleOp ancestor and check whether it already
+/// contains a `tensorrt.module` operation. If so, return the existing
+/// `tensorrt.module`; otherwise, create a new one.
+static tensorrt::TensorRTModuleOp getOrCreateTensorRTModuleOp(Operation *op) {
+  auto moduleOp = op->getParentOfType<ModuleOp>();
+  if (!moduleOp)
+    return nullptr;
+  SymbolTable symbolTable(moduleOp);
+  tensorrt::TensorRTModuleOp result = nullptr;
+  for (auto trtModuleOp :
+       moduleOp.getBody()->getOps<tensorrt::TensorRTModuleOp>()) {
+    result = trtModuleOp;
+    break;
+  }
+  if (result)
+    return result;
+
+  // Create the `tensorrt.module`. Symbol name de-duplication occurs on
+  // insertion into the symbol table.
+  result = tensorrt::TensorRTModuleOp::create(moduleOp.getLoc(), "trt_engines");
+  symbolTable.insert(result, op->getParentOp() == moduleOp
+                                 ? Block::iterator(op)
+                                 : Block::iterator{});
+  return result;
+}
+
+// TODO: An earlier draft had a `makeIsolatedFromAboveImpl` helper that called
+// `makeRegionIsolatedFromAbove` and re-created the `plan.inline_group` with
+// the captured values appended to its operands; decide whether that behavior
+// is still needed, since `outlineOp` below isolates the region directly.
+
+static FailureOr<tensorrt::CallAllocOp>
+outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
+          plan::InlineGroupOp op) {
+  // Make the region isolated from above. This captures the input operands.
+  SmallVector<Value> inputs =
+      makeRegionIsolatedFromAbove(rewriter, op.getRegion());
+
+  // Create the outlined function.
+  FailureOr<func::FuncOp> func =
+      createOutlinedFunc(rewriter, op.getLoc(), trtModule, "tensorrt_cluster",
+                         TypeRange(inputs), op->getResultTypes());
+  if (failed(func))
+    return failure();
+
+  rewriter.setInsertionPoint(op);
+  auto callOp = rewriter.create<tensorrt::CallAllocOp>(
+      op.getLoc(), op.getResultTypes(), inputs,
+      SymbolRefAttr::get(trtModule.getNameAttr(),
+                         {FlatSymbolRefAttr::get(*func)}));
+
+  // Erase the placeholder entry block created by `createOutlinedFunc`; the
+  // cluster body replaces it below.
+  rewriter.eraseBlock(&func->getFunctionBody().front());
+
+  // Move region op operations to the func body.
+  Operation *regionYieldOp = op.getYield();
+  rewriter.inlineRegionBefore(op.getRegion(), func->getFunctionBody(),
+                              func->getFunctionBody().end());
+  rewriter.setInsertionPoint(regionYieldOp);
+  rewriter.replaceOpWithNewOp<func::ReturnOp>(regionYieldOp,
+                                              regionYieldOp->getOperands());
+
+  // Replace the original region results with the call results.
+  rewriter.replaceOp(op, callOp);
+
+  return callOp;
+}
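+
+// Note: the callee above is addressed with a nested symbol reference of the
+// form `@trt_engines::@tensorrt_cluster` (the outlined function's symbol may
+// carry a numeric suffix after uniquing), i.e. the symbol of the enclosing
+// `tensorrt.module` followed by the function's own symbol.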
+
+class OutlineTensorRTOpPass
+    : public compiler::impl::OutlineTensorRTOpPassBase<
+          OutlineTensorRTOpPass> {
+public:
+  using Base::Base;
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+
+    SymbolTableCollection symbolTable;
+    IRRewriter rewriter(&getContext());
+    // TODO: Determine whether the DataFlowSolver analyses used by the
+    // StableHLO clustering path (e.g. dead-code analysis and sparse constant
+    // propagation) are needed here as well.
+
+    FailureOr<ClusteringOpts> opts = getTensorRTClusteringOptions(module);
+    if (failed(opts)) {
+      emitError(module.getLoc()) << "failed to create clustering options";
+      return signalPassFailure();
+    }
+    // TODO: Port the cluster-formation rewrite (building `plan.inline_group`
+    // regions from these clustering options) from the StableHLO path; this
+    // pass currently assumes the clusters have already been formed.
+
+    tensorrt::TensorRTModuleOp trtModuleOp =
+        getOrCreateTensorRTModuleOp(module);
+
+    SmallVector<plan::InlineGroupOp> clusters;
+    module.walk(
+        [&](plan::InlineGroupOp cluster) { clusters.push_back(cluster); });
+
+    for (plan::InlineGroupOp cluster : clusters) {
+      if (failed(outlineOp(rewriter, trtModuleOp, cluster)))
+        return signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pipeline Registrations
+//===----------------------------------------------------------------------===//
+
+namespace {
+class TensorRTToExecutablePassPipelineOptions
+    : public PassPipelineOptionsAdaptor<
+          TensorRTToExecutablePassPipelineOptions,
+          TensorRTToExecutableOptions> {};
+} // namespace
+
+void mlirtrt::compiler::registerTensorRTToExecutablePipelines() {
+  PassPipelineRegistration<TensorRTToExecutablePassPipelineOptions>(
+      "tensorrt-clustering-pipeline", "apply clustering to TensorRT IR",
+      [](OpPassManager &pm,
+         const TensorRTToExecutablePassPipelineOptions &opts) {
+        TensorRTToExecutableTask::buildTensorRTClusteringPipeline(pm, opts);
+      });
+
+  PassPipelineRegistration<TensorRTToExecutablePassPipelineOptions>(
+      "tensorrt-compilation-pipeline", "apply compilation post-clustering",
+      [](OpPassManager &pm,
+         const TensorRTToExecutablePassPipelineOptions &opts) {
+        TensorRTToExecutableTask::buildPostClusteringPipeline(pm, opts);
+      });
+}
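+
+// Usage sketch (hypothetical invocation via this project's opt tool):
+//
+//   mlir-tensorrt-opt input.mlir \
+//     -tensorrt-clustering-pipeline \
+//     -tensorrt-compilation-pipeline="entrypoint=main"
+//
+// Each pipeline forwards its parsed options to the corresponding
+// TensorRTToExecutableTask::build*Pipeline function registered above.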
+
+#endif // MLIR_TRT_ENABLE_HLO
\ No newline at end of file
diff --git a/mlir-tensorrt/compiler/lib/Compiler/TensorRTToExecutable/TensorRTToExecutable.cpp b/mlir-tensorrt/compiler/lib/Compiler/TensorRTToExecutable/TensorRTToExecutable.cpp
new file mode 100644
index 000000000..af39858d7
--- /dev/null
+++ b/mlir-tensorrt/compiler/lib/Compiler/TensorRTToExecutable/TensorRTToExecutable.cpp
@@ -0,0 +1,191 @@
+//===- TensorRTToExecutable.cpp ---------------------------------*- C++ -*-===//
+//
+// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES.
+// All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+#ifdef MLIR_TRT_TARGET_TENSORRT
+
+#include "mlir-tensorrt/Compiler/TensorRTToExecutable/TensorRTToExecutable.h"
+#include "mlir-executor/Conversion/Passes.h"
+#include "mlir-executor/Executor/Transforms/Passes.h"
+#include "mlir-tensorrt-dialect/TensorRT/Transforms/Passes.h"
+#include "mlir-tensorrt/Compiler/OptionsProviders.h"
+#include "mlir-tensorrt/Compiler/OptionsRegistry.h"
+#include "mlir-tensorrt/Compiler/TensorRTToExecutable/Passes.h"
+#include "mlir-tensorrt/Conversion/Passes.h"
+#include "mlir-tensorrt/Dialect/Plan/Transforms/Passes.h"
+#include "mlir-tensorrt/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace mlirtrt::compiler;
+
+//===----------------------------------------------------------------------===//
+// TensorRTToExecutableOptions
+//===----------------------------------------------------------------------===//
+
+TensorRTToExecutableOptions::TensorRTToExecutableOptions(
+    TaskExtensionRegistry extensions) {
+  // TODO (pranavm): We don't need extensions - remove from constructor and add
+  // `setExtensions` to base class.
+  assert(extensions.extensions.size() == 0);
+}
+
+//===----------------------------------------------------------------------===//
+// TensorRTToExecutableTask
+//===----------------------------------------------------------------------===//
+
+TensorRTToExecutableTask::TensorRTToExecutableTask(
+    MLIRContext *ctx, const TensorRTToExecutableOptions &options)
+    : CompilationTask(ctx, options) {
+  options.get<DebugOptions>().applyToPassManager(*this);
+}
+
+void TensorRTToExecutableTask::buildTensorRTClusteringPipeline(
+    OpPassManager &pm, const TensorRTToExecutableOptions &opts) {
+  pm.addPass(createOutlineTensorRTOpPass());
+}
+
+void TensorRTToExecutableTask::buildPostClusteringPipeline(
+    OpPassManager &pm, const TensorRTToExecutableOptions &options) {
+  // Post-clustering
+  pm.addPass(createConvertTensorRTToTensorRTRuntimePass());
+
+  pm.addNestedPass<func::FuncOp>(plan::createPostClusteringValidationPass());
+
+  pm.addPass(createCanonicalizerPass());
+
+  pm.addPass(createInlinerPass());
+  pm.addNestedPass<func::FuncOp>(createCSEPass());
+  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+
+  // We then perform some final simplification on the top-level func.func ops
+  // (e.g. public entrypoint functions).
+  pm.addNestedPass<func::FuncOp>(createSCFDetensorizeLoopsPass());
+  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+
+  // Pre-bufferization
+  // Simplify and translate functions nested in `tensorrt.module` ops.
+  auto &trtPM = pm.nest<tensorrt::TensorRTModuleOp>();
+  tensorrt::buildTensorRTModuleTransformationPipeline(
+      trtPM, options.get<TensorRTOptions>().options.enableStronglyTyped);
+  trtPM.addPass(tensorrt::createTranslateTensorRTPass(
+      nullptr, options.get<TensorRTOptions>().options));
+
+  pm.addPass(createMemRefCastEliminationPass());
+  pm.addPass(plan::createPlanAllocTensorsPass());
+  pm.addPass(plan::createPlanBufferizePass());
+  pm.addPass(createMemRefCastEliminationPass());
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(bufferization::createDropEquivalentBufferResultsPass());
+  plan::buildPlanBufferOptimizationPipeline(pm);
+  plan::buildPlanBufferDeallocationPipeline(
+      pm, bufferization::DeallocationOptions{
+              /*privateFuncDynamicOwnership=*/false});
+
+  // Post-bufferization
+  pm.addPass(createConvertMemRefToCUDAPass());
+  pm.addPass(createConvertPlanToExecutorPass());
+  pm.addPass(executor::createExecutorAllocsToGlobalsPass());
+  pm.addNestedPass<func::FuncOp>(
+      executor::createExecutorPopulateFunctionMetadataPass());
+
+  // Executor lowering
+  ConvertTensorRTRuntimeToExecutorPassOptions toExecutorOpts;
+  toExecutorOpts.indexBitwidth = options.get<ExecutorOptions>().indexBitwidth;
+  toExecutorOpts.usePackedMemRefCConv =
+      options.get<ExecutorOptions>().usePackedMemRefCConv;
+  pm.addPass(createConvertTensorRTRuntimeToExecutorPass(toExecutorOpts));
+
+  ConvertCUDAToExecutorPassOptions cudaToExecutorOpts;
+  cudaToExecutorOpts.indexBitwidth =
+      options.get<ExecutorOptions>().indexBitwidth;
+  cudaToExecutorOpts.usePackedMemRefCConv =
+      options.get<ExecutorOptions>().usePackedMemRefCConv;
+  pm.addPass(createConvertCUDAToExecutorPass(cudaToExecutorOpts));
+
+  pm.addPass(createDropNestedModulesPass());
+}
+
+void TensorRTToExecutableTask::populatePassManager(
+    mlir::PassManager &pm, const TensorRTToExecutableOptions &options) {
+  buildTensorRTClusteringPipeline(pm, options);
+
+  buildPostClusteringPipeline(pm, options);
+
+  mlir::executor::ConvertStdToExecutorPassOptions stdToExecOpts;
+  stdToExecOpts.indexBitwidth = options.get<ExecutorOptions>().indexBitwidth;
+  stdToExecOpts.usePackedMemRefCConv = true;
+  mlir::executor::buildExecutorLoweringPipeline(pm, stdToExecOpts);
+}
+
+void mlirtrt::compiler::registerTensorRTToExecutableTask() {
+  registerOption(
+      "tensorrt-to-executable",
+      [](MLIRContext *ctx, ArrayRef<StringRef> opts)
+          -> StatusOr<std::unique_ptr<mlir::OptionsContext>> {
+        auto task =
+            optionsCreateFromArgs<TensorRTToExecutableOptions>(ctx, opts);
+        if (!task.isOk())
+          return task.getStatus();
+        return std::unique_ptr<mlir::OptionsContext>(std::move(*task));
+      });
+
+  registerCompilationTask(
+      "tensorrt-to-executable",
+      [](CompilerClient &client, llvm::ArrayRef<llvm::StringRef> options)
+          -> StatusOr<CompilationTaskBase *> {
+        TensorRTToExecutableOptions result;
+        std::string err;
+        if (failed(result.parse(options, err)))
+          return getInvalidArgStatus(
+              "failed to parse options string \"{0:$[ ]}\" due to error {1}",
+              llvm::iterator_range(options), err);
+
+        llvm::Error finalizeStatus = result.finalize();
+        std::optional<std::string> errMsg{};
+        llvm::handleAllErrors(std::move(finalizeStatus),
+                              [&errMsg](const llvm::StringError &err) {
+                                errMsg = err.getMessage();
+                              });
+
+        if (errMsg)
+          return getInvalidArgStatus(
+              "failed to parse options due to error {0}", errMsg);
+
+        std::optional<llvm::hash_code> hashCode = result.getHash();
+        if (!hashCode)
+          return getInvalidArgStatus("failed to hash options");
+
+        CompilationTaskBase *cached = client.lookupCachedCompilationTask(
+            mlir::TypeID::get<TensorRTToExecutableTask>(), *hashCode);
+        if (cached)
+          return cached;
+
+        auto newPM = std::make_unique<TensorRTToExecutableTask>(
+            client.getContext(), result);
+        auto ptr = newPM.get();
+        client.updateCachedCompilationTask(*hashCode, std::move(newPM));
+        return ptr;
+      });
+}
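+
+// Note: both registrations above share the "tensorrt-to-executable" mnemonic.
+// The first exposes the option bundle to option-parsing clients; the second
+// constructs the CompilationTask and caches it keyed by the options hash, so
+// repeated compilations with identical options reuse one pass manager.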
+
+MLIR_DEFINE_EXPLICIT_TYPE_ID(mlirtrt::compiler::TensorRTToExecutableTask)
+
+#endif