From b80b459179a07938dbf6f75825e454f499a408e8 Mon Sep 17 00:00:00 2001 From: Jhalak Patel Date: Mon, 30 Sep 2024 09:56:10 -0700 Subject: [PATCH 1/2] Enable end-to-end non-DPS testing Implement Python binding changes to allow the execute function to return multiple results. Update tests to use the non-DPS-style calling convention. Also, enable end-to-end lowering by converting the closed alloc group op to the tensorrt dialect. Miscellaneous fixes: 1. Add missing handling of `CallAllocOp` in the EliminateShapeOps pass. 2. Skip function arguments that are not ranked tensor types while collecting host tensor arguments. 3. Temporarily add a pass to remove clone operations in the MemRefToExecutor dialect conversion. 4. Relax memref creation for empty-shape tensors. 5. Fix the lifetime of memrefs returned from Lua function results; this requires the session allocator to track the returned memrefs. Also: fix incorrect indexing into output memref results, return an error status instead of silently failing during TensorRT weight conversion, and address review comments. --- .../lib/Compiler/StableHloToExecutable.cpp | 15 +- .../TensorRTRuntimeToExecutor.cpp | 28 +- .../Dialect/Plan/Transforms/CMakeLists.txt | 1 + .../Plan/Transforms/EliminateShapeOps.cpp | 39 ++- .../Plan/Transforms/OutlineClusters.cpp | 145 ++++---- .../lib/Dialect/Plan/Transforms/Passes.cpp | 32 +- .../include/mlir-executor-c/Runtime/Runtime.h | 41 ++- .../include/mlir-executor/Runtime/API/API.h | 2 +- .../Runtime/Backend/Lua/LuaRuntime.h | 7 - .../mlir-executor/Support/Allocators.h | 6 + .../executor/lib/CAPI/Common/Common.cpp | 2 +- .../executor/lib/CAPI/Runtime/Runtime.cpp | 101 +++++- .../lib/Conversion/MemRefToExecutor.cpp | 2 + .../executor/lib/Runtime/API/API.cpp | 40 ++- .../lib/Runtime/Backend/Lua/LuaRuntime.cpp | 87 +++-- .../Lua/Modules/TensorRT/TensorRTModule.cpp | 8 +- .../executor/lib/Support/Allocators.cpp | 21 +- .../test/lib/BufferizationTestPass.cpp | 1 + .../python/bindings/Runtime/RuntimePyBind.cpp | 68 +++- .../NetworkEncoder.cpp | 27 +- .../tensorrt-runtime-to-executor.mlir | 108 ++++-- .../IntegrationTests/test_non_dps_cconv.py | 323 ++++++++++++++++++ 22 files changed, 888 insertions(+), 216 deletions(-) create mode 100644 mlir-tensorrt/test/python/IntegrationTests/test_non_dps_cconv.py diff --git a/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp b/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp index 3b8308449..45357387c 100644 --- a/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp +++ b/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp @@ -223,6 +223,10 @@ StableHLOToExecutableOptions::StableHLOToExecutableOptions( disallowHostTensorsInTensorRTClusters, llvm::cl::init(false), llvm::cl::desc("Don't allow TensorRt clusters to contain host tensor " "calculations (but they can still be inputs)")); + addOption( + "enable-non-dps-returns", enableNonDPSReturns, llvm::cl::init(false), + llvm::cl::desc( + "allow tensorrt based output allocations using output allocator")); addOption("executor-index-bitwidth", executorIndexBitwidth, llvm::cl::init(64)); addOption("device-compute-capability", deviceComputeCapability, @@ -307,6 +311,7 @@ void StableHloToExecutableTask::buildStablehloClusteringPipeline( plan::StablehloClusteringPassOptions clusteringOpts{}; clusteringOpts.disallowHostTensorsInTensorRTClusters = opts.disallowHostTensorsInTensorRTClusters; + clusteringOpts.enableNonDPSReturns = opts.enableNonDPSReturns; clusteringOpts.entrypoint = opts.entrypoint; plan::buildPlanSegmentationPipeline(pm, 
clusteringOpts); @@ -340,7 +345,9 @@ void StableHloToExecutableTask::buildPostClusteringPipeline( // Perform bufferization. pm.addPass(createMemRefCastEliminationPass()); - pm.addPass(plan::createPlanAllocTensorsPass()); + plan::PlanAllocTensorsPassOptions allocTensorsOpts{}; + allocTensorsOpts.enableNonDPSReturns = opts.enableNonDPSReturns; + pm.addPass(plan::createPlanAllocTensorsPass(allocTensorsOpts)); pm.addPass(plan::createPlanBufferizePass()); pm.addPass(createMemRefCastEliminationPass()); pm.addPass(createCanonicalizerPass()); @@ -529,6 +536,11 @@ struct ClusteringPipelineCliOpts *this, "device-compute-capability", llvm::cl::desc("target device compute capability (SM version)"), llvm::cl::init(60)}; + Option enableNonDPSReturns{ + *this, "enable-non-dps-returns", + llvm::cl::desc( + "allow tensorrt based output allocations using output allocator"), + llvm::cl::init(false)}; Option deviceMaxSharedMemoryPerBlockKb{ *this, "device-max-smem-per-block", llvm::cl::desc("max shared memory per block (in kilobytes)"), @@ -556,6 +568,7 @@ static StableHLOToExecutableOptions populateStablehloClusteringPipelineOpts( opts.deviceComputeCapability = cliOpts.deviceComputeCapability; opts.deviceMaxSharedMemoryPerBlockKb = cliOpts.deviceMaxSharedMemoryPerBlockKb; + opts.enableNonDPSReturns = cliOpts.enableNonDPSReturns; opts.shouldInferDeviceOptionsFromHost = cliOpts.inferDeviceOptionsFromHost; opts.entrypoint = cliOpts.entrypoint; return opts; diff --git a/mlir-tensorrt/compiler/lib/Conversion/TensorRTRuntimeToExecutor/TensorRTRuntimeToExecutor.cpp b/mlir-tensorrt/compiler/lib/Conversion/TensorRTRuntimeToExecutor/TensorRTRuntimeToExecutor.cpp index 2a585a43d..8c10b35d4 100644 --- a/mlir-tensorrt/compiler/lib/Conversion/TensorRTRuntimeToExecutor/TensorRTRuntimeToExecutor.cpp +++ b/mlir-tensorrt/compiler/lib/Conversion/TensorRTRuntimeToExecutor/TensorRTRuntimeToExecutor.cpp @@ -379,20 +379,22 @@ struct ConvertEnqueueAllocToCall // Create output memrefs from output descriptors SmallVector results; + // Initialize output descriptor offset to skip number of results. + // `outputDescOffset` is used to retrieve rank, ptr, shapes, and strides per + // result. 
+ outputDescOffset = 1; for (unsigned i = 0; i < op->getNumResults(); ++i) { unsigned rank = cast(op->getResult(i).getType()).getRank(); - unsigned offset = - 1 + - i * (2 * rank + 2); // num res, (i * (rank, ptr, [shape], [stride])) - Value rankOffset = b.create( b.getI64Type(), structType, - ArrayRef{this->createIndexConstant(b, 0), - rewriter.getI64IntegerAttr(offset++)}); + ArrayRef{ + this->createIndexConstant(b, 0), + rewriter.getI64IntegerAttr(outputDescOffset++)}); Value devicePtrOffset = b.create( b.getI64Type(), structType, - ArrayRef{this->createIndexConstant(b, 0), - rewriter.getI64IntegerAttr(offset++)}); + ArrayRef{ + this->createIndexConstant(b, 0), + rewriter.getI64IntegerAttr(outputDescOffset++)}); [[maybe_unused]] Value rankValue = b.create( b.getI64Type(), outputDescriptors, rankOffset); @@ -406,8 +408,9 @@ struct ConvertEnqueueAllocToCall for (unsigned r = 0; r < rank; ++r) { Value shapeOffset = b.create( b.getI64Type(), structType, - ArrayRef{this->createIndexConstant(b, 0), - rewriter.getI64IntegerAttr(offset++)}); + ArrayRef{ + this->createIndexConstant(b, 0), + rewriter.getI64IntegerAttr(outputDescOffset++)}); Value shape = b.create( b.getI64Type(), outputDescriptors, shapeOffset); shapes.push_back(shape); @@ -416,8 +419,9 @@ struct ConvertEnqueueAllocToCall for (unsigned r = 0; r < rank; ++r) { Value strideOffset = b.create( b.getI64Type(), structType, - ArrayRef{this->createIndexConstant(b, 0), - rewriter.getI64IntegerAttr(offset++)}); + ArrayRef{ + this->createIndexConstant(b, 0), + rewriter.getI64IntegerAttr(outputDescOffset++)}); Value shape = b.create( b.getI64Type(), outputDescriptors, strideOffset); shapes.push_back(shape); diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/CMakeLists.txt b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/CMakeLists.txt index 65add5f43..2f0bb5c11 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/CMakeLists.txt +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/CMakeLists.txt @@ -37,6 +37,7 @@ add_mlir_tensorrt_library(MLIRTensorRTPlanTransforms MLIRTensorRTStablehloScalarToArith MLIRTensorRTStablehloToTensorRT MLIRTensorRTTensorRTRuntimeDialect + MLIRBufferizationToMemRef MLIRTransforms StablehloOps ) diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp index eb29494c9..84d67b037 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp @@ -65,17 +65,26 @@ struct RemoveWithValuesRewriter : public OpRewritePattern { } // namespace /// Get a map from `tensorrt.func` functions to associated `tensorrt.call` -/// operations. -static llvm::DenseMap> +/// and `tensorrt.call_alloc` operations. 
+static llvm::DenseMap> getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) { - llvm::DenseMap> map; - op->walk([&](tensorrt::CallOp callOp) { - func::FuncOp func = callOp.getFuncCallee(collection); - if (map.contains(func)) { - map[func].push_back(callOp); + llvm::DenseMap> map; + op->walk([&](Operation *callOp) { + if (!isa(callOp)) return; + + func::FuncOp func; + if (auto call = dyn_cast(callOp)) { + func = call.getFuncCallee(collection); + } else { + auto callAlloc = dyn_cast(callOp); + func = callAlloc.getFuncCallee(collection); } - map.insert(std::make_pair(func, SmallVector{callOp})); + + if (map.count(func)) + map[func].push_back(callOp); + else + map.insert({func, SmallVector{callOp}}); }); return map; } @@ -84,7 +93,7 @@ getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) { /// `tensorrt.call` operations. static LogicalResult removeUnusedArgs(SymbolTableCollection &collection, ModuleOp op, func::FuncOp funcOp, - ArrayRef callOps) { + ArrayRef callOps) { llvm::SmallBitVector unusedArgs(funcOp.getNumArguments(), 0); for (BlockArgument arg : funcOp.getArguments()) { if (arg.use_empty()) @@ -99,10 +108,16 @@ static LogicalResult removeUnusedArgs(SymbolTableCollection &collection, funcOp.eraseArgument(i); // Update the call ops. - for (tensorrt::CallOp callOp : callOps) - callOp.getInputsMutable().erase(i); + for (Operation *callOp : callOps) { + if (auto call = dyn_cast(callOp)) + call.getInputsMutable().erase(i); + else if (auto callAlloc = dyn_cast(callOp)) + callAlloc.getInputsMutable().erase(i); + else + return emitError(funcOp->getLoc()) + << "Unexpected operation type in callOps"; + } } - return success(); } diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineClusters.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineClusters.cpp index d593393f6..391d86578 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineClusters.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineClusters.cpp @@ -156,28 +156,32 @@ getTensorRTShapeProfile(plan::BoundsAttr attr, Value v) { return getProfileAttr(attr.getMinShape(), attr.getMaxShape()); } -static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter, - plan::InlineClosedGroupOp op) { - tensorrt::TensorRTModuleOp trtModuleOp = getOrCreateTensorRTModuleOp(op); - auto funcArgTypes = llvm::to_vector(TypeRange(op.getInputs())); - FailureOr func = createOutlinedFunc( - rewriter, op.getLoc(), op, trtModuleOp, "tensorrt_cluster", - "cluster.tensorrt", TypeRange(op.getInputs()), - op.getYield()->getOperandTypes()); - if (failed(func)) - return failure(); - assert(func->getFunctionBody().getBlocks().size() == 1 && - "expected body with one block"); - func->setPublic(); - - rewriter.setInsertionPoint(op); - auto callOp = rewriter.create( - op.getLoc(), op.getResultTypes(), op.getInputs(), op.getOuts(), - SymbolRefAttr::get(trtModuleOp.getNameAttr(), - {FlatSymbolRefAttr::get(*func)})); +template +static auto createCallOp(RewriterBase &rewriter, OpType op, + tensorrt::TensorRTModuleOp trtModuleOp, + FunctionOpInterface func) { + static_assert( + std::is_same_v || + std::is_same_v, + "OpType must be either InlineClosedGroupOp or InlineClosedAllocGroupOp"); + if constexpr (std::is_same_v) + return rewriter.create( + op.getLoc(), op.getResultTypes(), op.getInputs(), op.getOuts(), + SymbolRefAttr::get(trtModuleOp.getNameAttr(), + {FlatSymbolRefAttr::get(func)})); + else if constexpr (std::is_same_v) + return rewriter.create( + op.getLoc(), 
op.getResultTypes(), op.getInputs(), + SymbolRefAttr::get(trtModuleOp.getNameAttr(), + {FlatSymbolRefAttr::get(func)})); +} +template +static LogicalResult populateFunctionAttributes(RewriterBase &rewriter, + OpType op, + FunctionOpInterface *func) { // Populate the function arguments attributes. - for (unsigned i = 0; i < (*func).getNumArguments(); i++) { + for (unsigned i = 0; i < func->getNumArguments(); i++) { BoundsAttr srcAttr = cast(op.getInputAttrs()[i]); // We may have scalar (index|signless int)-typed values since we haven't // eliminated `plan.(with_shape|with_values)` ops yet. @@ -202,30 +206,57 @@ static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter, func->setArgAttr(i, mlir::getHostTensorArgAttrName(), rewriter.getUnitAttr()); } - // Populate the function result attributes. - for (unsigned i = 0; i < (*func).getNumResults(); i++) { - BoundsAttr srcAttr = cast(op.getResAttrs()[i]); - if (srcAttr.isNone()) - continue; - FailureOr boundsAttr = - getTensorRTShapeProfile(srcAttr, op.getResults()[i]); - if (failed(boundsAttr)) - return op->emitOpError("failed to create TensorRT shape profile " - "attribute from Plan BoundsAttr for result #") - << i << " (" << srcAttr << ")"; - if (srcAttr.isShapeBound()) { + // Populate the function result attributes for DPS call op. + if constexpr (std::is_same_v) { + for (unsigned i = 0; i < func->getNumResults(); i++) { + BoundsAttr srcAttr = cast(op.getResAttrs()[i]); + if (srcAttr.isNone()) + continue; + FailureOr boundsAttr = + getTensorRTShapeProfile(srcAttr, op.getResults()[i]); + if (failed(boundsAttr)) + return op->emitOpError("failed to create TensorRT shape profile " + "attribute from Plan BoundsAttr for result #") + << i << " (" << srcAttr << ")"; + if (srcAttr.isShapeBound()) { + func->setResultAttr( + i, tensorrt::TensorRTDialect::getShapeProfileArgAttrName(), + *boundsAttr); + continue; + } + assert(srcAttr.isValueBound() && "expected value bound or shape bound"); func->setResultAttr( - i, tensorrt::TensorRTDialect::getShapeProfileArgAttrName(), + i, tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName(), *boundsAttr); - continue; + func->setResultAttr(i, mlir::getHostTensorArgAttrName(), + rewriter.getUnitAttr()); } - assert(srcAttr.isValueBound() && "expected value bound or shape bound"); - func->setResultAttr( - i, tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName(), - *boundsAttr); - func->setResultAttr(i, mlir::getHostTensorArgAttrName(), - rewriter.getUnitAttr()); } + return success(); +} + +template +static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter, OpType op) { + tensorrt::TensorRTModuleOp trtModuleOp = getOrCreateTensorRTModuleOp(op); + auto funcArgTypes = llvm::to_vector(TypeRange(op.getInputs())); + + FailureOr func = createOutlinedFunc( + rewriter, op.getLoc(), op, trtModuleOp, "tensorrt_cluster", + "cluster.tensorrt", TypeRange(op.getInputs()), + op.getYield()->getOperandTypes()); + + if (failed(func)) + return failure(); + + assert(func->getFunctionBody().getBlocks().size() == 1 && + "expected body with one block"); + func->setPublic(); + + rewriter.setInsertionPoint(op); + auto callOp = createCallOp(rewriter, op, trtModuleOp, *func); + + if (failed(populateFunctionAttributes(rewriter, op, &(*func)))) + return failure(); // Populate the function entry block. rewriter.eraseBlock(&func->getFunctionBody().front()); @@ -234,14 +265,14 @@ static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter, // ops to the `tensorrt.module` op. 
This is needed since `tensorrt.module` op // has its own symbol table. SymbolTableCollection symbolTable; - for (auto compositeOp : op.getBody().getOps()) { + for (auto compositeOp : + op.getBody().template getOps()) { auto decompositionFunc = dyn_cast_if_present( - symbolTable.lookupSymbolIn(op->getParentOfType(), + symbolTable.lookupSymbolIn(op->template getParentOfType(), compositeOp.getDecompositionAttr())); if (!decompositionFunc) return emitError(compositeOp.getLoc()) - << "failed to lookup stablehlo.composite decomposition " - "function: " + << "failed to lookup stablehlo.composite decomposition function: " << compositeOp.getDecompositionAttr(); rewriter.moveOpAfter(decompositionFunc, func->getOperation()); } @@ -254,24 +285,20 @@ static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter, rewriter.replaceOpWithNewOp(regionYieldOp, regionYieldOp->getOperands()); - // Erase the DPS arugments, which now should be unused. - if (llvm::any_of(func->getArguments().take_back(op.getOuts().size()), - [](BlockArgument arg) { return !arg.use_empty(); })) - return failure(); - func->getFunctionBody().front().eraseArguments(op.getInputs().size(), - op.getOuts().size()); + if constexpr (std::is_same_v) { + // Erase the DPS arugments, which now should be unused. + if (llvm::any_of(func->getArguments().take_back(op.getOuts().size()), + [](BlockArgument arg) { return !arg.use_empty(); })) + return failure(); + func->getFunctionBody().front().eraseArguments(op.getInputs().size(), + op.getOuts().size()); + } // replace the original region results. rewriter.replaceOp(op, callOp); return success(); } -static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter, - plan::InlineClosedAllocGroupOp op) { - return op.emitError("outlinining inline closed alloc group ops to tensorrt " - "dialect is not yet implemented"); -} - /// Create outlined functions for each `scf.execute_region` operation within /// `region`. 
static FailureOr> @@ -302,12 +329,14 @@ createFunctionsFromRegions(RewriterBase &rewriter, Region ®ion, } if (auto group = dyn_cast(op)) { - if (failed(outlineTensorRTRegion(rewriter, group))) + if (failed(outlineTensorRTRegion(rewriter, + group))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto allocGroup = dyn_cast(op)) { - if (failed(outlineTensorRTRegion(rewriter, allocGroup))) + if (failed(outlineTensorRTRegion( + rewriter, allocGroup))) return WalkResult::interrupt(); return WalkResult::advance(); } diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/Passes.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/Passes.cpp index 753bea88e..1a0a6415d 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/Passes.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/Passes.cpp @@ -24,6 +24,7 @@ //===----------------------------------------------------------------------===// #include "mlir-tensorrt/Dialect/Plan/Transforms/Passes.h" #include "mlir-tensorrt/Transforms/Passes.h" +#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h" #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" #include "mlir/Dialect/Bufferization/Pipelines/Passes.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" @@ -80,6 +81,7 @@ void plan::buildPlanBufferDeallocationPipeline( pm.addPass(createCanonicalizerPass()); pm.addPass(bufferization::createBufferDeallocationSimplificationPass()); pm.addPass(bufferization::createLowerDeallocationsPass()); + pm.addPass(mlir::createBufferizationToMemRefPass()); pm.addPass(createCSEPass()); pm.addPass(createCanonicalizerPass()); } @@ -103,31 +105,21 @@ struct ClusteringPipelineCliOpts llvm::cl::init(NV_TENSORRT_MAJOR)}; }; -struct PlanBufferizationPipelineCliOpts - : public PassPipelineOptions { - Option enableNonDPSReturns{ - *this, "enable-non-dps-returns", - llvm::cl::desc("allow backend clusters to directly allocate outputs"), - llvm::cl::init(false)}; -}; - } // namespace // Register pipelines. 
void plan::registerPlanDialectPipelines() { - PassPipelineRegistration - executorBufferizationPipeline( - "plan-bufferize-pipeline", - "perform bufferization and standard pre/post processing passes", - [](OpPassManager &pm, const PlanBufferizationPipelineCliOpts &opts) { - PlanAllocTensorsPassOptions allocTensorOpts{}; - allocTensorOpts.enableNonDPSReturns = opts.enableNonDPSReturns; - buildPlanBufferizationPipeline(pm, allocTensorOpts); - buildPlanBufferOptimizationPipeline(pm); - buildPlanBufferDeallocationPipeline( - pm, bufferization::DeallocationOptions{false}); - }); + PassPipelineRegistration<> executorBufferizationPipeline( + "plan-bufferize-pipeline", + "perform bufferization and standard pre/post processing passes", + [](OpPassManager &pm) { + PlanAllocTensorsPassOptions allocTensorOpts{}; + buildPlanBufferizationPipeline(pm, allocTensorOpts); + buildPlanBufferOptimizationPipeline(pm); + buildPlanBufferDeallocationPipeline( + pm, bufferization::DeallocationOptions{false}); + }); PassPipelineRegistration<> bufferOptPipeline( "plan-buffer-opt-pipeline", "perform post-bufferization optimizations", diff --git a/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h b/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h index 807d3dc23..35041e312 100644 --- a/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h +++ b/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h @@ -53,7 +53,7 @@ extern "C" { /// caller must be sure to delete errors via mtrtStatusDestroy. //===----------------------------------------------------------------------===// -typedef struct MTRT_RuntimeClient MTRT_Runtimeclient; +typedef struct MTRT_RuntimeClient MTRT_RuntimeClient; //===----------------------------------------------------------------------===// // MTRT_GlobalDebug @@ -87,7 +87,7 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtStreamCreate(MTRT_Stream *stream); static inline bool mtrtStreamIsNull(MTRT_Stream stream) { return !stream.ptr; } /// Returns null stream. -static inline MTRT_Stream mtrtStreamGetNull() { return MTRT_Stream{nullptr}; } +static inline MTRT_Stream mtrtStreamGetNull() { return MTRT_Stream{NULL}; } /// Synchronizes `MTRT_Stream` MLIR_CAPI_EXPORTED MTRT_Status mtrtStreamSynchronize(MTRT_Stream stream); @@ -108,7 +108,7 @@ static inline bool mtrtDeviceIsNull(MTRT_Device device) { return !device.ptr; } /// Return a null MTRT_Device. This should be used where MTRT_Device input /// arguments are optional in functions below. -static inline MTRT_Device mtrtDeviceGetNull() { return MTRT_Device{nullptr}; } +static inline MTRT_Device mtrtDeviceGetNull() { return MTRT_Device{NULL}; } //===----------------------------------------------------------------------===// // MTRT_MemRefValue @@ -215,6 +215,11 @@ static inline bool mtrtRuntimeClientIsNull(MTRT_RuntimeClient client) { return !client.ptr; } +/// Returns null client. +static inline MTRT_RuntimeClient mtrtRuntimeClientGetNull() { + return MTRT_RuntimeClient{NULL}; +} + /// Creates a `MTRT_RuntimeClient`. Client must be alive for the lifetime of the /// program execution. /// The `stream` passed to the client is used by all underlying CUDA methods @@ -308,6 +313,12 @@ static inline bool mtrtRuntimeValueIsNull(MTRT_RuntimeValue value) { return !value.ptr; } +// Returns whether the RuntimeValue is MemRef. +MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value); + +// Returns whether the RuntimeValue is Scalar. 
+MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value); + /// Cast a MTRT_MemRefValue to a generic MTRT_RuntimeValue. MLIR_CAPI_EXPORTED MTRT_RuntimeValue mtrtMemRefCastToRuntimeValue(MTRT_MemRefValue memref); @@ -338,6 +349,9 @@ mtrtScalarValueCastToRuntimeValue(MTRT_ScalarValue v); MLIR_CAPI_EXPORTED MTRT_Status mtrtScalarValueGetType(MTRT_ScalarValue scalar, MTRT_ScalarTypeCode *code); +MLIR_CAPI_EXPORTED MTRT_Status mtrtScalarValueGet(MTRT_ScalarValue scalar, + int64_t *data); + //===----------------------------------------------------------------------===// // MTRT_RuntimeSessionOptions //===----------------------------------------------------------------------===// @@ -391,16 +405,27 @@ static inline bool mtrtRuntimeSessionIsNull(MTRT_RuntimeSession session) { return !session.ptr; } -/// Using `session`, execute the pubic function with the specified name. -/// The `inArgs` and `outArgs` are arrays for input arguments and destination -/// arguments, respectively. Input arguments may be MemRefs or scalars, but -/// destination arguments must be MemRefs. +/// Using `session`, execute the public function with the specified name. +/// The `inArgs`, `outArgs`, and `results` are arrays for input arguments, +/// output arguments, and return values, respectively. Arguments and results +/// can be MemRefs, scalars, or other supported types. Both `outArgs` and +/// `results` can be used simultaneously, allowing for functions that both +/// modify arguments and return values. /// A stream may optionally be specified, otherwise pass the result of /// `mtrtStreamGetNull()`. +/// +/// The `results` array must point to an array with at least the number of +/// elements returned by mtrtRuntimeSessionGetNumResults for the given function. MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionExecuteFunction( MTRT_RuntimeSession session, MTRT_StringView name, const MTRT_RuntimeValue *inArgs, size_t numInArgs, - const MTRT_RuntimeValue *outArgs, size_t numOutArgs, MTRT_Stream stream); + const MTRT_RuntimeValue *outArgs, size_t numOutArgs, + MTRT_RuntimeValue *results, MTRT_Stream stream, MTRT_RuntimeClient client); + +/// Return number of results given a function name. Function name refers +/// to an exported function in the executable. +MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionGetNumResults( + MTRT_RuntimeSession session, MTRT_StringView name, int64_t *numResults); //===----------------------------------------------------------------------===// // DLPack diff --git a/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h b/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h index 96aec51db..672a3f915 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h @@ -427,7 +427,7 @@ class ExecutableView { /// Return a function by name. This asserts that the function with the given /// name exists. 
- FunctionView getFunction(std::string_view name) const; + StatusOr getFunction(std::string_view name) const; ConstantView getConstant(int64_t idx) const { assert(view->constants() && "expected valid constant pointer"); diff --git a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h index 88c616bc7..a58b1d022 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h @@ -104,13 +104,6 @@ executeFunctionWithLuaBackend(LuaRuntimeSession &session, std::string_view name, std::optional stream = {}, std::optional client = {}); -// Parses the results of a function call, handling both scalar and MemRef return -// types -StatusOr>> -parseResults(const sol::protected_function_result &pfr, - const FunctionSignatureView &sig, - std::optional client); - } // namespace mlirtrt::runtime #endif // MLIR_TENSORRT_RUNTIME_BACKEND_LUA_LUARUNTIME_H diff --git a/mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h b/mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h index cf5bd3169..7df229832 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h @@ -104,6 +104,9 @@ class PinnedMemoryAllocator { PinnedMemoryAllocator(); ~PinnedMemoryAllocator(); + /// Untrack the given pointer so that the allocator does not free it. + void untrack(uintptr_t ptr); + StatusOr allocate(size_t size); /// Free the block associated with the given pointer on the given stream. An @@ -114,6 +117,9 @@ class PinnedMemoryAllocator { private: EventPool eventPool; + /// Tracks all the pointers that must not be freed by the allocator. + static std::vector untrackedPtrs; + /// Tracks all blocks allocated by the allocator. 
struct BlockTracker; std::unique_ptr blockTracker; diff --git a/mlir-tensorrt/executor/lib/CAPI/Common/Common.cpp b/mlir-tensorrt/executor/lib/CAPI/Common/Common.cpp index 46bd564e3..d66659c36 100644 --- a/mlir-tensorrt/executor/lib/CAPI/Common/Common.cpp +++ b/mlir-tensorrt/executor/lib/CAPI/Common/Common.cpp @@ -344,7 +344,7 @@ MTRT_Status mtrtBoundsGetMax(MTRT_Bounds bounds, MTRT_ArrayRefI64 *maxBounds) { MTRT_FunctionSignature mtrtGetFunctionSignature(MTRT_Executable exec, const char *name) { auto sig = const_cast( - unwrap(exec)->getFunction(name).getSignature().view); + (*unwrap(exec)->getFunction(name)).getSignature().view); return wrap(sig); } diff --git a/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp b/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp index fa8bc850e..d2a1d8f62 100644 --- a/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp +++ b/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp @@ -37,6 +37,17 @@ #include "cuda_runtime_api.h" #endif +#if defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" +#endif +#include "cuda_bf16.h" +#include "cuda_fp16.h" +#include "cuda_fp8.h" +#if defined(__clang__) +#pragma GCC diagnostic pop +#endif + struct MTRT_StreamImpl; #define DEFINE_C_API_PTR_METHODS(name, cpptype) \ @@ -682,6 +693,16 @@ MTRT_ScalarValue mtrtRuntimeValueDynCastToScalar(MTRT_RuntimeValue v) { return wrap(static_cast(x)); } +bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value) { + RuntimeValue *x = unwrap(value); + return x->getKind() == RuntimeValue::Kind::MemRef; +} + +bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value) { + RuntimeValue *x = unwrap(value); + return x->getKind() == RuntimeValue::Kind::Scalar; +} + //===----------------------------------------------------------------------===// // MTRT_RuntimeSessionOptions //===----------------------------------------------------------------------===// @@ -728,7 +749,8 @@ MTRT_Status mtrtRuntimeSessionDestroy(MTRT_RuntimeSession session) { MTRT_Status mtrtRuntimeSessionExecuteFunction( MTRT_RuntimeSession session, MTRT_StringView name, const MTRT_RuntimeValue *inArgs, size_t numInArgs, - const MTRT_RuntimeValue *outArgs, size_t numOutArgs, MTRT_Stream stream) { + const MTRT_RuntimeValue *outArgs, size_t numOutArgs, + MTRT_RuntimeValue *results, MTRT_Stream stream, MTRT_RuntimeClient client) { LuaRuntimeSession *cppSession = static_cast(unwrap(session)); @@ -738,19 +760,38 @@ MTRT_Status mtrtRuntimeSessionExecuteFunction( llvm::SmallVector outArgValues = llvm::map_to_vector(llvm::ArrayRef(outArgs, numOutArgs), [](MTRT_RuntimeValue arg) { return unwrap(arg); }); - - StatusOr>> result = + StatusOr>> resultValues = executeFunctionWithLuaBackend( *cppSession, std::string_view(name.data, name.length), inArgValues, outArgValues, !mtrtStreamIsNull(stream) ? std::optional(unwrap(stream)->getRawStream()) - : std::nullopt); - if (!result.isOk()) - return wrap(result.getStatus()); + : std::nullopt, + !mtrtRuntimeClientIsNull(client) ? 
std::optional(unwrap(client)) + : std::nullopt); + if (!resultValues.isOk()) + return wrap(resultValues.getStatus()); + + for (size_t i = 0; i < resultValues->size(); ++i) + results[i] = wrap((*resultValues)[i].release()); return mtrtStatusGetOk(); } + +MTRT_Status mtrtRuntimeSessionGetNumResults(MTRT_RuntimeSession session, + MTRT_StringView name, + int64_t *numResults) { + LuaRuntimeSession *cppSession = + static_cast(unwrap(session)); + StatusOr func = cppSession->getExecutable().getFunction( + std::string_view(name.data, name.length)); + if (func.isError()) { + return wrap(func.getStatus()); + } + *numResults = (*func).getSignature().getNumResults(); + return mtrtStatusGetOk(); +} + //===----------------------------------------------------------------------===// // MTRT_RuntimeClient //===----------------------------------------------------------------------===// @@ -796,3 +837,51 @@ MTRT_Status mtrtScalarValueGetType(MTRT_ScalarValue scalar, *code = static_cast(cppScalar->getType().getCode()); return mtrtStatusGetOk(); } + +MTRT_Status mtrtScalarValueGet(MTRT_ScalarValue scalar, int64_t *data) { + ScalarValue *cppScalar = unwrap(scalar); + ScalarTypeCode code = cppScalar->getType().getCode(); + switch (code) { + case ScalarTypeCode::f8e4m3fn: + *data = static_cast(cppScalar->get<__nv_fp8_e4m3>()); + break; + case ScalarTypeCode::f16: + *data = static_cast(cppScalar->get<__half>()); + break; + case ScalarTypeCode::bf16: + *data = static_cast(cppScalar->get()); + break; + case ScalarTypeCode::f32: + *data = static_cast(cppScalar->get()); + break; + case ScalarTypeCode::f64: + *data = static_cast(cppScalar->get()); + break; + case ScalarTypeCode::i1: + *data = static_cast(cppScalar->get()); + break; + case ScalarTypeCode::i4: + *data = static_cast(cppScalar->get()); + break; + case ScalarTypeCode::i8: + *data = static_cast(cppScalar->get()); + break; + case ScalarTypeCode::ui8: + *data = static_cast(cppScalar->get()); + break; + case ScalarTypeCode::i16: + *data = static_cast(cppScalar->get()); + break; + case ScalarTypeCode::i32: + *data = static_cast(cppScalar->get()); + break; + case ScalarTypeCode::i64: + *data = cppScalar->get(); + break; + default: + return wrap(getInvalidArgStatus( + "function input argument with scalar type {0} is unsupported", + impl::EnumNameScalarTypeCode(code))); + } + return mtrtStatusGetOk(); +} diff --git a/mlir-tensorrt/executor/lib/Conversion/MemRefToExecutor.cpp b/mlir-tensorrt/executor/lib/Conversion/MemRefToExecutor.cpp index 90e4c9681..34e69188d 100644 --- a/mlir-tensorrt/executor/lib/Conversion/MemRefToExecutor.cpp +++ b/mlir-tensorrt/executor/lib/Conversion/MemRefToExecutor.cpp @@ -548,6 +548,7 @@ void executor::populateMemRefToExecutorPatterns( } namespace { + /// Pass to convert `memref` to `executor` dialect operrations. 
class ConvertMemRefToExecutorPass : public mlir::executor::impl::ConvertMemRefToExecutorPassBase< @@ -579,6 +580,7 @@ class ConvertMemRefToExecutorPass RewritePatternSet patterns(ctx); executor::populateMemRefToExecutorPatterns( patterns, typeConverter, allowUncheckedMemrefCastConversion); + if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); diff --git a/mlir-tensorrt/executor/lib/Runtime/API/API.cpp b/mlir-tensorrt/executor/lib/Runtime/API/API.cpp index 661c11b4e..4a13ef80a 100644 --- a/mlir-tensorrt/executor/lib/Runtime/API/API.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/API/API.cpp @@ -149,14 +149,17 @@ static bool isHostVisible(PointerType type) { // ExecutableView //===----------------------------------------------------------------------===// -FunctionView ExecutableView::getFunction(std::string_view name) const { +StatusOr +ExecutableView::getFunction(std::string_view name) const { const flatbuffers::Vector> &functions = *view->functions(); auto it = std::find_if(functions.begin(), functions.end(), [&](const impl::Function *x) { return x->name()->string_view() == name; }); - assert(it != view->functions()->end()); + if (it == view->functions()->end()) + return getStatusWithMsg(StatusCode::InvalidArgument, "Function with name (", + name, ") is not present in the executable"); return FunctionView(*it); } @@ -367,6 +370,7 @@ RuntimeSession::RuntimeSession(RuntimeSessionOptions options, //===----------------------------------------------------------------------===// AllocTracker::~AllocTracker() { + MTRT_DBGF("Destroying alloc tracker %p", static_cast(this)); MTRT_DBGF("checking %u allocations", map.size()); llvm::SmallVector ptrsToFree; ptrsToFree.reserve(map.size()); @@ -452,12 +456,19 @@ void AllocTracker::track(PointerInfo info) { // (e.g. function argument), in which case it may have been deallocated, // allowing an internal allocator to pick up that same address. That case is // not an error. 
- assert((!contains(info.ptr) || get(info.ptr).isExternallyManaged()) && - "an internally managed pointer should not already be tracked"); + if (contains(info.ptr) and get(info.ptr).isInternallyManaged()) { + MTRT_DBGF("Allocator %p: Internally managed pointer 0x%lx should not be " + "already tracked", + static_cast(this), info.ptr); + assert(0 && + "an internally managed pointer should not already be tracked"); + } } - MTRT_DBGF("AllocTracker is now tracking 0x%lx size=%lu space=%s ownership=%s", - info.ptr, info.size, runtime::impl::EnumNamePointerType(info.type), - runtime::impl::EnumNamePointerOwner(info.owner)); + MTRT_DBGF( + "AllocTracker %p is now tracking 0x%lx size=%lx space=%s ownership=%s", + static_cast(this), info.ptr, info.size, + runtime::impl::EnumNamePointerType(info.type), + runtime::impl::EnumNamePointerOwner(info.owner)); auto value = std::make_unique(); value->externalReferenceCount.store(0); value->releasedInternally = false; @@ -487,6 +498,8 @@ void AllocTracker::track(PointerInfo info) { } void AllocTracker::untrack(uintptr_t ptr) { + MTRT_DBGF("AllocTracker %p is now untracking 0x%lx)", + static_cast(this), ptr); assert(llvm::is_contained(map, ptr) && llvm::formatv("Untracked pointer {0}", ptr).str().c_str()); map.erase(map.find(ptr)); @@ -596,7 +609,7 @@ mlirtrt::Status runtime::safeDeallocate(AllocTracker &tracker, uintptr_t ptr, PointerInfo obj = tracker.get(ptr); if (obj.owner == PointerOwner::external) { - MTRT_DBGF("Untracking externally managed pointer 0x%lx", ptr); + MTRT_DBGF("Untracking externally managed 0x%lx", ptr); tracker.untrack(obj.ptr); return mlirtrt::Status::getOk(); } @@ -747,9 +760,16 @@ StatusOr> MemRefValue::create( if (!::getFootprintInBytes(shape, strides, bitsPerElement).isOk()) return getInvalidArgStatus( "only memrefs with non-negative strides are allowed"); - if (!ptr) + + auto isEmptyTensor = [](llvm::ArrayRef shape) -> bool { + return std::any_of(shape.begin(), shape.end(), + [](int64_t s) { return s == 0; }); + }; + + if (!ptr && !isEmptyTensor(shape)) return getInvalidArgStatus( - "MemRef objects must be created with a valid pointer"); + "MemRef objects must be created with a valid pointer for a non-empty " + "tensor"); if (isDeviceVisible(addressSpace) && (!device || !*device)) return getInvalidArgStatus("a specific device must be provided for MemRefs " diff --git a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp index aea894610..0ff8f0406 100644 --- a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp @@ -565,37 +565,45 @@ getScalarValue(const sol::protected_function_result &pfr, int index, } } -// Parses the results of a function call, handling both scalar and MemRef return -// types -StatusOr>> -runtime::parseResults(const sol::protected_function_result &pfr, - const FunctionSignatureView &sig, - std::optional client) { +/// Parses the results of a function call, handling both scalar and MemRef +/// return types. +/// +/// @param pfr The protected function result to parse. +/// @param sig The function signature view. +/// @param sessionAllocTracker The allocation tracker for the current session. +/// @param client Optional runtime client pointer. +/// @return A vector of unique pointers to RuntimeValue, or an error status. 
+static StatusOr>> +parseResults(const sol::protected_function_result &pfr, + const FunctionSignatureView &sig, LuaRuntimeSession &session, + std::optional client) { llvm::SmallVector> results; + results.reserve(sig.getNumResults()); + + for (unsigned i = 0; i < sig.getNumResults(); ++i) { + const auto &resultType = sig.getResult(i); - if (sig.getResult(i).isa()) { - auto scalar = getScalarValue(pfr, i, sig); - if (!scalar.isOk()) - return scalar.getStatus(); - results.push_back(std::move(*scalar)); + if (resultType.isa()) { + auto scalarValue = getScalarValue(pfr, i, sig); + if (!scalarValue.isOk()) + return scalarValue.getStatus(); + results.push_back(std::move(*scalarValue)); continue; } - MemRefTableReader reader(pfr, i); - - if (!sig.getResult(i).isa()) + if (!resultType.isa()) return getInvalidArgStatus("Result can only be a memref or scalar"); // Handle MemRef return values - const auto &resultView = sig.getResult(i).get(); - unsigned rank = resultView.getRank(); + const auto &memRefView = resultType.get(); + MemRefTableReader reader(pfr, i); // Extract MemRef metadata uintptr_t allocPtr = reader.getNextValue(); [[maybe_unused]] uintptr_t alignedPtr = reader.getNextValue(); int64_t offset = reader.getNextValue(); + unsigned rank = memRefView.getRank(); llvm::SmallVector shape(rank); llvm::SmallVector strides(rank); @@ -608,15 +616,43 @@ runtime::parseResults(const sol::protected_function_result &pfr, if (!client) return getInvalidArgStatus("Runtime client cannot be nullptr"); - // Create MemRefValue from extracted data - auto memref = (*client)->createExternalMemRef( - resultView.getAddressSpace(), resultView.getElementType().getBitWidth(), - allocPtr, offset, shape, strides, (*client)->getDevices()[0].get(), - resultView.getElementType()); + // Create an external MemRef and track it in both the session and client + // allocation trackers. + MTRT_DBGF("Creating external MemRef for ptr 0x%lx: " + "Session alloc tracker: %p, Session pinned memory allocator: %p, " + "Client: %p, Client tracker: %p. " + "This ptr is registered with the session and will now be tracked " + "by the client as well.", + allocPtr, static_cast(&session.getAllocTracker()), + static_cast(&session.getPinnedMemorAllocator()), + static_cast(*client), + static_cast(&(*client)->getAllocTracker())); + + // What we need here is to release the pointer from session ownership and + // have the client assume ownership of it. + PointerInfo info = session.getAllocTracker().get(allocPtr); + session.getAllocTracker().untrack(info.ptr); + + // It is possible that the pinned memory allocator also tracks this memory + // for deallocation. + session.getPinnedMemorAllocator().untrack(info.ptr); + + AllocTracker &allocator = (*client)->getAllocTracker(); + // if (!allocator.contains(info.ptr)) + allocator.track(info); + + // Create a memref so that the client now tracks it. 
+ auto memref = MemRefValue::create( + *client, memRefView.getAddressSpace(), + memRefView.getElementType().getBitWidth(), allocPtr, offset, shape, + strides, (*client)->getDevices()[0].get(), memRefView.getElementType()); if (!memref.isOk()) return memref.getStatus(); + // Increment external reference count since we are returning a memref + allocator.incrementExternalCount(info.ptr); + results.push_back(std::move(*memref)); } @@ -630,8 +666,11 @@ runtime::executeFunctionWithLuaBackend( llvm::ArrayRef outputArgs, std::optional stream, std::optional client) { - FunctionView meta = session.getExecutable().getFunction(name); - FunctionSignatureView sig = meta.getSignature(); + StatusOr func = session.getExecutable().getFunction(name); + if (func.isError()) + return func.getStatus(); + + FunctionSignatureView sig = (*func).getSignature(); // Call the main function, if present. sol::state &lua = session.getLuaState(); @@ -715,5 +754,5 @@ runtime::executeFunctionWithLuaBackend( "\": ", err.what()); } - return parseResults(pfr, sig, client); + return parseResults(pfr, sig, session, client); } diff --git a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp index b24355eb3..943493f99 100644 --- a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp @@ -141,9 +141,12 @@ class OutputAllocatorImpl : public nvinfer1::IOutputAllocator { size = std::max(size, static_cast(1)); if (size > mOutputSize) { size = roundUp(size, alignment); - if (mOutputPtr) + if (mOutputPtr) { + MTRT_DBGF("tensorrt module output allocator deallocating 0x%lx", + mOutputPtr); mlirtrt::runtime::safeDeallocate(*mTracker, mOutputPtr, CudaStreamPtr(stream)); + } mOutputPtr = 0; mOutputSize = 0; StatusOr memory = @@ -152,6 +155,9 @@ class OutputAllocatorImpl : public nvinfer1::IOutputAllocator { if (memory.isOk()) { mOutputPtr = (*memory).ptr; mOutputSize = memory->size; + MTRT_DBGF( + "tensorrt module output allocator allocating %lu bytes at 0x%lx", + mOutputSize, mOutputPtr); } return reinterpret_cast(mOutputPtr); } diff --git a/mlir-tensorrt/executor/lib/Support/Allocators.cpp b/mlir-tensorrt/executor/lib/Support/Allocators.cpp index ce7310ad7..1a542e187 100644 --- a/mlir-tensorrt/executor/lib/Support/Allocators.cpp +++ b/mlir-tensorrt/executor/lib/Support/Allocators.cpp @@ -206,6 +206,8 @@ static void cudaFreeHostWrapper(uintptr_t ptr) { #endif } +std::vector PinnedMemoryAllocator::untrackedPtrs; + struct PinnedMemoryAllocator::BlockTracker { std::set blocks; llvm::DenseMap pointerToBlock; @@ -216,9 +218,13 @@ struct PinnedMemoryAllocator::BlockTracker { "[PinnedMemoryAllocator] Releasing block tracker that has %lu blocks", blocks.size()); for (Block *block : blocks) { - ALLOC_DBGF("[PinnedMemoryAllocator] releasing block %lu of size %lu", - block->ptr, block->size); - cudaFreeHostWrapper(block->ptr); + if (std::find(PinnedMemoryAllocator::untrackedPtrs.begin(), + PinnedMemoryAllocator::untrackedPtrs.end(), + block->ptr) == PinnedMemoryAllocator::untrackedPtrs.end()) { + ALLOC_DBGF("[PinnedMemoryAllocator] releasing block %lu of size %lu", + block->ptr, block->size); + cudaFreeHostWrapper(block->ptr); + } } } }; @@ -269,6 +275,13 @@ StatusOr PinnedMemoryAllocator::allocate(size_t size) { #endif } +// Free the given block. 
+void PinnedMemoryAllocator::untrack(uintptr_t ptr) { + if (!llvm::is_contained(untrackedPtrs, ptr)) { + untrackedPtrs.emplace_back(ptr); + } +} + // Free the given block. Status PinnedMemoryAllocator::freeAsync(uintptr_t ptr, CudaStream stream) { #ifdef MLIR_EXECUTOR_ENABLE_CUDA @@ -296,4 +309,4 @@ Status PinnedMemoryAllocator::freeAsync(uintptr_t ptr, CudaStream stream) { return getInternalErrorStatus( "MLIR-Executor was not built with CUDA enabled"); #endif -} \ No newline at end of file +} diff --git a/mlir-tensorrt/executor/test/lib/BufferizationTestPass.cpp b/mlir-tensorrt/executor/test/lib/BufferizationTestPass.cpp index 9e8a9e50a..869373d44 100644 --- a/mlir-tensorrt/executor/test/lib/BufferizationTestPass.cpp +++ b/mlir-tensorrt/executor/test/lib/BufferizationTestPass.cpp @@ -53,6 +53,7 @@ class ExecutorBufferizationTestPass } } }; + } // namespace namespace mlir::executor { diff --git a/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp b/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp index ddb0c74e2..f35085b85 100644 --- a/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp +++ b/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp @@ -244,7 +244,7 @@ class PyRuntimeClient using Base::Base; DECLARE_WRAPPER_CONSTRUCTORS(PyRuntimeClient); - static constexpr auto kMethodTable = CAPITable{ + static constexpr auto kMethodTable = CAPITable{ mtrtRuntimeClientIsNull, mtrtRuntimeClientDestroy}; }; @@ -600,6 +600,15 @@ static MTRT_RuntimeValue convertArgType(py::object obj) { throw std::runtime_error("argument must be MemRef or scalar"); } +/// Convert Runtime value to PyMemRefValue or PyScalarValue object. +static py::object convertGenericArgToPyObject(MTRT_RuntimeValue value) { + if (mtrtRuntimeValueIsMemRef(value)) + return py::cast(mtrtRuntimeValueDynCastToMemRef(value)); + if (mtrtRuntimeValueIsScalar(value)) + return py::cast(mtrtRuntimeValueDynCastToScalar(value)); + throw std::runtime_error("argument must be MemRef or scalar"); +} + //===----------------------------------------------------------------------===// // Declare the bindings. 
//===----------------------------------------------------------------------===// @@ -615,11 +624,19 @@ PYBIND11_MODULE(_api, m) { py::buffer_protocol()) .def_property_readonly(MTRT_PYTHON_CAPI_PTR_ATTR, &PyScalarValue::getCapsule) - .def_property_readonly("type", [](PyScalarValue &self) { - MTRT_ScalarTypeCode code; - MTRT_Status s = mtrtScalarValueGetType(self, &code); + .def_property_readonly("type", + [](PyScalarValue &self) { + MTRT_ScalarTypeCode code; + MTRT_Status s = + mtrtScalarValueGetType(self, &code); + THROW_IF_MTRT_ERROR(s); + return code; + }) + .def_property_readonly("data", [](PyScalarValue &self) { + int64_t data; + MTRT_Status s = mtrtScalarValueGet(self, &data); THROW_IF_MTRT_ERROR(s); - return code; + return data; }); py::class_(m, "MemRefValue", py::module_local(), py::buffer_protocol()) @@ -950,22 +967,45 @@ PYBIND11_MODULE(_api, m) { .def( "execute_function", [](PyRuntimeSession &self, std::string name, - std::vector inArgs, std::vector outArgs, - std::optional stream) { + std::vector inArgs, + std::optional> outArgs, + std::optional stream, + PyRuntimeClient *client = nullptr) { MTRT_StringView nameRef{name.data(), name.size()}; + int64_t numResults; + MTRT_Status s = + mtrtRuntimeSessionGetNumResults(self, nameRef, &numResults); + THROW_IF_MTRT_ERROR(s); + auto inArgsGeneric = llvm::map_to_vector(inArgs, convertArgType); - auto outArgsGeneric = llvm::map_to_vector(outArgs, convertArgType); + auto outArgsGeneric = + outArgs ? llvm::map_to_vector(*outArgs, convertArgType) + : llvm::SmallVector{}; + + std::vector resultsGeneric(numResults); - MTRT_Status s = mtrtRuntimeSessionExecuteFunction( + s = mtrtRuntimeSessionExecuteFunction( self, nameRef, inArgsGeneric.data(), inArgsGeneric.size(), outArgsGeneric.data(), outArgsGeneric.size(), - stream ? *stream : mtrtStreamGetNull()); + resultsGeneric.data(), stream ? *stream : mtrtStreamGetNull(), + client ? MTRT_RuntimeClient(*client) + : mtrtRuntimeClientGetNull()); THROW_IF_MTRT_ERROR(s); - }, - py::arg("name"), py::arg("in_args"), py::arg("out_args"), - py::arg("stream") = py::none()); + std::vector resultPyObject; + if (numResults > 0) { + for (const auto &arg : resultsGeneric) + resultPyObject.push_back(convertGenericArgToPyObject(arg)); + } + + return resultPyObject; + }, + py::arg("name"), py::arg("in_args"), py::arg("out_args") = py::none(), + py::arg("stream") = py::none(), py::arg("client") = nullptr, + "Execute a function given input and optional output arguments. 
" + "Return optional results as a Python object if output arguments are " + "not present."); py::class_(m, "GlobalDebug", py::module_local()) .def_property_static("flag", &PyGlobalDebugFlag::get, &PyGlobalDebugFlag::set, "LLVM-wide debug flag") @@ -977,4 +1017,4 @@ PYBIND11_MODULE(_api, m) { py::overload_cast &>( &PyGlobalDebugFlag::set_types), "Sets specific debug types to be produced by LLVM"); -} \ No newline at end of file +} diff --git a/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp b/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp index 6ae989205..bea4391b1 100644 --- a/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp +++ b/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp @@ -520,25 +520,25 @@ static void packNonSplatInt4Tensor(ElementsAttr values, int64_t count, } } -static void serializeSplatElements(DenseIntOrFPElementsAttr values, - std::vector &data) { +static LogicalResult serializeSplatElements(DenseIntOrFPElementsAttr values, + std::vector &data) { assert(values.isSplat() && "expected SplatElementsAttr"); auto rtt = cast(values.getType()); if (rtt.getElementType().isInteger(32)) { std::fill_n(reinterpret_cast(data.data()), values.getNumElements(), values.getSplatValue()); - return; + return llvm::success(); } if (rtt.getElementType().isInteger(8)) { std::fill_n(reinterpret_cast(data.data()), values.getNumElements(), values.getSplatValue()); - return; + return llvm::success(); } if (rtt.getElementType().isF32()) { std::fill_n(reinterpret_cast(data.data()), values.getNumElements(), values.getSplatValue()); - return; + return llvm::success(); } if (rtt.getElementType().isF16() || rtt.getElementType().isBF16()) { APInt tmp = values.getSplatValue().bitcastToAPInt(); @@ -546,7 +546,7 @@ static void serializeSplatElements(DenseIntOrFPElementsAttr values, uint16_t fillValue = *reinterpret_cast(tmp.getRawData()); std::fill_n(reinterpret_cast(data.data()), values.getNumElements(), fillValue); - return; + return llvm::success(); } if (rtt.getElementType().isFloat8E4M3FN()) { APInt tmp = values.getSplatValue().bitcastToAPInt(); @@ -554,7 +554,7 @@ static void serializeSplatElements(DenseIntOrFPElementsAttr values, uint8_t fillValue = *reinterpret_cast(tmp.getRawData()); std::fill_n(reinterpret_cast(data.data()), values.getNumElements(), fillValue); - return; + return llvm::success(); } if (rtt.getElementType().isInteger(4)) { APInt tmp = values.getSplatValue(); @@ -566,11 +566,12 @@ static void serializeSplatElements(DenseIntOrFPElementsAttr values, packed |= ((value & 0x0F) << 4); // Fill `data` vector with `packed` std::fill_n(reinterpret_cast(data.data()), data.size(), packed); - return; + return llvm::success(); } - llvm_unreachable("unsupported data type to convert MLIR splat attribute to " - "TensorRT weights!"); + return emitError(UnknownLoc::get(values.getContext())) + << "unsupported data type to convert MLIR splat attribute to TensorRT " + "weights!"; } FailureOr @@ -615,8 +616,10 @@ NvInferNetworkEncoder::getNvInferWeights(ElementsAttr values) { weights.values = data.data(); if (values.isSplat() && isa(values)) { - serializeSplatElements(cast(values), - weightsMap[values]); + LogicalResult status = serializeSplatElements( + cast(values), weightsMap[values]); + if (failed(status)) + return failure(); return weights; } diff --git a/mlir-tensorrt/test/Conversion/TensorRTRuntimeToExecutor/tensorrt-runtime-to-executor.mlir 
b/mlir-tensorrt/test/Conversion/TensorRTRuntimeToExecutor/tensorrt-runtime-to-executor.mlir index 55e7535d0..28b8900c8 100644 --- a/mlir-tensorrt/test/Conversion/TensorRTRuntimeToExecutor/tensorrt-runtime-to-executor.mlir +++ b/mlir-tensorrt/test/Conversion/TensorRTRuntimeToExecutor/tensorrt-runtime-to-executor.mlir @@ -102,35 +102,93 @@ func.func @main(%arg0: memref<1x3x256x256xf32, #executor.memory_type>) - // CHECK: %[[v8:.*]] = cuda.stream.create : !cuda.stream // CHECK: %[[v9:.*]] = builtin.unrealized_conversion_cast %[[v8]] : !cuda.stream to !executor.ptr // CHECK: %[[v10:.*]] = executor.alloca %[[c1]] x !executor.table : (i64) -> !executor.ptr -// CHECK: %[[v11:.*]] = executor.getoffset[0, 0] : () -> i64, !executor.table -// CHECK: executor.store %[[c1]] to %[[v10]] + %[[v11]] : i64, !executor.ptr, i64 -// CHECK: %[[v12:.*]] = executor.getoffset[0, 1] : () -> i64, !executor.table -// CHECK: executor.store %[[c4]] to %[[v10]] + %[[v12]] : i64, !executor.ptr, i64 +// CHECK: %[[v11:.*]] = executor.getoffset[0, 0] +// CHECK: executor.store %[[c1]] to %[[v10]] + %[[v11]] +// CHECK: %[[v12:.*]] = executor.getoffset[0, 1] +// CHECK: executor.store %[[c4]] to %[[v10]] + %[[v12]] // CHECK: %[[v13:.*]] = executor.table.get %[[v6]][1] : , !executor.ptr, i64, i64, i64, i64, i64, i64, i64, i64, i64> // CHECK: %[[v14:.*]] = executor.table.create(%[[v13]], %[[c0]], %[[c4]], %[[c1]], %[[c3]], %[[c256]], %[[c256]] : !executor.ptr, i64, i64, i64, i64, i64, i64) : , i64, i64, i64, i64, i64, i64> // CHECK: executor.call @_trtrt_enqueue_alloc(%[[v7]], %[[v9]], %[[v10]], %[[v14]]) : (!executor.opaque<"trtrt_context">, !executor.ptr, !executor.ptr, !executor.table, i64, i64, i64, i64, i64, i64>) -> () -// CHECK: %[[v15:.*]] = executor.getoffset[0, 2] : () -> i64, !executor.table -// CHECK: %[[v16:.*]] = executor.load %[[v10]] + %[[v12]] : (!executor.ptr, i64) -> i64 -// CHECK: %[[v17:.*]] = executor.load %[[v10]] + %[[v15]] : (!executor.ptr, i64) -> i64 -// CHECK: %[[v18:.*]] = executor.inttoptr %[[v17]] : (i64) -> !executor.ptr -// CHECK: %[[v19:.*]] = executor.getoffset[0, 3] : () -> i64, !executor.table -// CHECK: %[[v20:.*]] = executor.load %[[v10]] + %[[v19]] : (!executor.ptr, i64) -> i64 -// CHECK: %[[v21:.*]] = executor.getoffset[0, 4] : () -> i64, !executor.table -// CHECK: %[[v22:.*]] = executor.load %[[v10]] + %[[v21]] : (!executor.ptr, i64) -> i64 -// CHECK: %[[v23:.*]] = executor.getoffset[0, 5] : () -> i64, !executor.table -// CHECK: %[[v24:.*]] = executor.load %[[v10]] + %[[v23]] : (!executor.ptr, i64) -> i64 -// CHECK: %[[v25:.*]] = executor.getoffset[0, 6] : () -> i64, !executor.table -// CHECK: %[[v26:.*]] = executor.load %[[v10]] + %[[v25]] : (!executor.ptr, i64) -> i64 -// CHECK: %[[v27:.*]] = executor.getoffset[0, 7] : () -> i64, !executor.table -// CHECK: %[[v28:.*]] = executor.load %[[v10]] + %[[v27]] : (!executor.ptr, i64) -> i64 -// CHECK: %[[v29:.*]] = executor.getoffset[0, 8] : () -> i64, !executor.table -// CHECK: %[[v30:.*]] = executor.load %[[v10]] + %[[v29]] : (!executor.ptr, i64) -> i64 -// CHECK: %[[v31:.*]] = executor.getoffset[0, 9] : () -> i64, !executor.table -// CHECK: %[[v32:.*]] = executor.load %[[v10]] + %[[v31]] : (!executor.ptr, i64) -> i64 -// CHECK: %[[v33:.*]] = executor.getoffset[0, 10] : () -> i64, !executor.table -// CHECK: %[[v34:.*]] = executor.load %[[v10]] + %[[v33]] : (!executor.ptr, i64) -> i64 +// CHECK: %[[v15:.*]] = executor.getoffset[0, 2] +// CHECK: %[[v16:.*]] = executor.load %[[v10]] + %[[v12]] +// CHECK: %[[v17:.*]] = executor.load %[[v10]] + 
%[[v15]] +// CHECK: %[[v18:.*]] = executor.inttoptr %[[v17]] +// CHECK: %[[v19:.*]] = executor.getoffset[0, 3] +// CHECK: %[[v20:.*]] = executor.load %[[v10]] + %[[v19]] +// CHECK: %[[v21:.*]] = executor.getoffset[0, 4] +// CHECK: %[[v22:.*]] = executor.load %[[v10]] + %[[v21]] +// CHECK: %[[v23:.*]] = executor.getoffset[0, 5] +// CHECK: %[[v24:.*]] = executor.load %[[v10]] + %[[v23]] +// CHECK: %[[v25:.*]] = executor.getoffset[0, 6] +// CHECK: %[[v26:.*]] = executor.load %[[v10]] + %[[v25]] +// CHECK: %[[v27:.*]] = executor.getoffset[0, 7] +// CHECK: %[[v28:.*]] = executor.load %[[v10]] + %[[v27]] +// CHECK: %[[v29:.*]] = executor.getoffset[0, 8] +// CHECK: %[[v30:.*]] = executor.load %[[v10]] + %[[v29]] +// CHECK: %[[v31:.*]] = executor.getoffset[0, 9] +// CHECK: %[[v32:.*]] = executor.load %[[v10]] + %[[v31]] +// CHECK: %[[v33:.*]] = executor.getoffset[0, 10] +// CHECK: %[[v34:.*]] = executor.load %[[v10]] + %[[v33]] // CHECK: %[[v35:.*]] = executor.table.create(%[[v18]], %[[v18]], %[[c0]], %[[v20]], %[[v22]], %[[v24]], %[[v26]], %[[v28]], %[[v30]], %[[v32]], %[[v34]] : !executor.ptr, !executor.ptr, i64, i64, i64, i64, i64, i64, i64, i64, i64) : , !executor.ptr, i64, i64, i64, i64, i64, i64, i64, i64, i64> // CHECK: %[[v36:.*]] = builtin.unrealized_conversion_cast %[[v35]] : !executor.table, !executor.ptr, i64, i64, i64, i64, i64, i64, i64, i64, i64> to memref> // CHECK: cuda.stream.sync %[[v8]] : !cuda.stream // CHECK: return %[[v36]] : memref> -// CHECK: } \ No newline at end of file +// CHECK: } + +// ----- + +func.func @main(%arg0: memref, %arg1: memref, %context: !trtrt.context, %stream: !cuda.stream) -> (memref, memref) attributes {executor.function_metadata = #executor.func_meta<[memref {#executor.dim_bounds}, memref {#executor.dim_bounds}], [memref {#executor.dim_bounds}, memref {#executor.dim_bounds}], num_output_args = 0>} { + %2:2 = trtrt.enqueue_alloc %context stream(%stream) (%arg1, %arg0) : (memref, memref) -> (memref, memref) + return %2#0, %2#1 : memref, memref +} + +// CHECK-LABEL: module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry, #dlti.dl_entry, 64 : i64>, #dlti.dl_entry, 64 : i64>>} { +// executor.func private @_trtrt_enqueue_alloc(!executor.opaque<"trtrt_context">, !executor.ptr, !executor.ptr, ...) 
+// CHECK-LABEL: func.func @main +// CHECK-SAME: (%[[arg0:.+]]: memref, %[[arg1:.+]]: memref, %[[arg2:.+]]: !trtrt.context, %[[arg3:.+]]: !cuda.stream) -> (memref, memref) attributes {executor.function_metadata = #executor.func_meta<[memref {#executor.dim_bounds}, memref {#executor.dim_bounds}], [memref {#executor.dim_bounds}, memref {#executor.dim_bounds}], num_output_args = 0>} { +// CHECK-DAG: %[[c2:.+]] = executor.constant 2 : i64 +// CHECK-DAG: %[[c0:.+]] = executor.constant 0 : i64 +// CHECK-DAG: %[[c1:.+]] = executor.constant 1 : i64 +// CHECK: %[[v0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] +// CHECK: %[[v1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] +// CHECK: %[[v2:.+]] = builtin.unrealized_conversion_cast %[[arg3]] +// CHECK: %[[v3:.+]] = builtin.unrealized_conversion_cast %[[arg2]] +// CHECK: %[[v4:.+]] = executor.alloca %[[c1]] x !executor.table +// CHECK: %[[v5:.+]] = executor.getoffset[0, 0] +// CHECK: executor.store %[[c2]] to %[[v4]] + %[[v5]] +// CHECK: %[[v6:.+]] = executor.getoffset[0, 1] +// CHECK: executor.store %[[c1]] to %[[v4]] + %[[v6]] +// CHECK: %[[v7:.+]] = executor.getoffset[0, 5] +// CHECK: executor.store %[[c2]] to %[[v4]] + %[[v7]] +// CHECK: %[[v8:.+]] = executor.table.get %[[v1]][1] +// CHECK: %[[v9:.+]] = executor.table.get %[[v1]][3] +// CHECK: %[[v10:.+]] = executor.table.get %[[v1]][4] +// CHECK: %[[v11:.+]] = executor.table.get %[[v0]][1] +// CHECK: %[[v12:.+]] = executor.table.get %[[v0]][3] +// CHECK: %[[v13:.+]] = executor.table.create(%[[v8]], %[[c0]], %[[c2]], %[[v9]], %[[v10]], %[[v11]], %[[c0]], %[[c1]], %[[v12]] : !executor.ptr, i64, i64, i64, i64, !executor.ptr, i64, i64, i64) +// CHECK: executor.call @_trtrt_enqueue_alloc(%[[v3]], %[[v2]], %[[v4]], %[[v13]]) +// CHECK: %[[v14:.+]] = executor.getoffset[0, 2] +// CHECK: %[[v15:.+]] = executor.load %[[v4]] + %[[v6]] +// CHECK: %[[v16:.+]] = executor.load %[[v4]] + %[[v14]] +// CHECK: %[[v17:.+]] = executor.inttoptr %[[v16]] +// CHECK: %[[v18:.+]] = executor.getoffset[0, 3] +// CHECK: %[[v19:.+]] = executor.load %[[v4]] + %[[v18]] +// CHECK: %[[v20:.+]] = executor.getoffset[0, 4] +// CHECK: %[[v21:.+]] = executor.load %[[v4]] + %[[v20]] +// CHECK: %[[v22:.+]] = executor.table.create(%[[v17]], %[[v17]], %[[c0]], %[[v19]], %[[v21]] : !executor.ptr, !executor.ptr, i64, i64, i64) +// CHECK: %[[v23:.+]] = executor.getoffset[0, 6] +// CHECK: %[[v24:.+]] = executor.load %[[v4]] + %[[v7]] +// CHECK: %[[v25:.+]] = executor.load %[[v4]] + %[[v23]] +// CHECK: %[[v26:.+]] = executor.inttoptr %[[v25]] +// CHECK: %[[v27:.+]] = executor.getoffset[0, 7] +// CHECK: %[[v28:.+]] = executor.load %[[v4]] + %[[v27]] +// CHECK: %[[v29:.+]] = executor.getoffset[0, 8] +// CHECK: %[[v30:.+]] = executor.load %[[v4]] + %[[v29]] +// CHECK: %[[v31:.+]] = executor.getoffset[0, 9] +// CHECK: %[[v32:.+]] = executor.load %[[v4]] + %[[v31]] +// CHECK: %[[v33:.+]] = executor.getoffset[0, 10] +// CHECK: %[[v34:.+]] = executor.load %[[v4]] + %[[v33]] +// CHECK: %[[v35:.+]] = executor.table.create(%[[v26]], %[[v26]], %[[c0]], %[[v28]], %[[v30]], %[[v32]], %[[v34]] : !executor.ptr, !executor.ptr, i64, i64, i64, i64, i64) +// CHECK: %[[v36:.+]] = builtin.unrealized_conversion_cast %[[v35]] +// CHECK: %[[v37:.+]] = builtin.unrealized_conversion_cast %[[v22]] +// CHECK: return %[[v37]], %[[v36]] : memref, memref \ No newline at end of file diff --git a/mlir-tensorrt/test/python/IntegrationTests/test_non_dps_cconv.py b/mlir-tensorrt/test/python/IntegrationTests/test_non_dps_cconv.py new file mode 100644 index 
000000000..242275c29 --- /dev/null +++ b/mlir-tensorrt/test/python/IntegrationTests/test_non_dps_cconv.py @@ -0,0 +1,323 @@ +# RUN: %PYTHON %s +import time + +import mlir_tensorrt.compiler.api as compiler +import mlir_tensorrt.compiler.ir as ir +import mlir_tensorrt.runtime.api as runtime +import numpy as np + +single_return = """ +func.func @main(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { + %1 = stablehlo.add %arg0, %arg0 : (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32> + func.return %1 : tensor<2x3x4xf32> +} +""" + +scalar_return = """ +func.func @main(%arg0: tensor<2x3x4xf32>) -> index { + %1 = tensor.rank %arg0 : tensor<2x3x4xf32> + func.return %1 : index +} +""" + +mixed_return = """ +func.func @main(%arg0: tensor<2x3x4xf32>) -> (tensor<2x3x4xf32>, index) { + %1 = stablehlo.add %arg0, %arg0 : (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32> + %2 = tensor.rank %1 : tensor<2x3x4xf32> + func.return %1, %2 : tensor<2x3x4xf32>, index +} +""" + +multiple_return = """ +func.func @main(%arg0: tensor<2x3x4xf32>) -> (tensor<2x3x4xf32>, tensor<2x3x4xf32>) { + %1 = stablehlo.add %arg0, %arg0 : (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32> + %2 = stablehlo.add %arg0, %1 : (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32> + func.return %1, %2 : tensor<2x3x4xf32>, tensor<2x3x4xf32> +} +""" + +dynamic_shape = """ +func.func @main(%arg0: tensor {tensorrt.shape_profile = #tensorrt.shape_profile}, + %arg1: tensor {tensorrt.shape_profile = #tensorrt.shape_profile}) + -> tensor { + %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor) -> tensor + %1 = stablehlo.reshape %0 : (tensor) -> tensor<1xi32> + %2 = stablehlo.get_dimension_size %arg0, dim = 1 : (tensor) -> tensor + %3 = stablehlo.reshape %2 : (tensor) -> tensor<1xi32> + %4 = stablehlo.concatenate %1, %3, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %5 = stablehlo.get_dimension_size %arg1, dim = 0 : (tensor) -> tensor + %6 = stablehlo.reshape %5 : (tensor) -> tensor<1xi32> + %7 = stablehlo.get_dimension_size %arg1, dim = 1 : (tensor) -> tensor + %8 = stablehlo.reshape %7 : (tensor) -> tensor<1xi32> + %9 = stablehlo.concatenate %6, %8, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %10 = stablehlo.maximum %4, %9 : tensor<2xi32> + %11 = stablehlo.dynamic_broadcast_in_dim %arg0, %10, dims = [0, 1] : (tensor, tensor<2xi32>) -> tensor + %12 = stablehlo.dynamic_broadcast_in_dim %arg1, %10, dims = [0, 1] : (tensor, tensor<2xi32>) -> tensor + %13 = stablehlo.add %11, %12 : tensor + return %13 : tensor +} +""" + +session_tracking_h2h = """ +func.func @main() -> (tensor> {tensorrt.host_tensor}) { + %c = stablehlo.constant dense<[1, 2]> : tensor<2xi32> + %0 = bufferization.alloc_tensor() {memory_space = #plan.memory_space} : tensor<2xi32, #plan.memory_space> + %1 = bufferization.materialize_in_destination %c in %0 : (tensor<2xi32>, tensor<2xi32, #plan.memory_space>) -> tensor<2xi32, #plan.memory_space> + %cast = tensor.cast %1 : tensor<2xi32, #plan.memory_space> to tensor> + return %cast : tensor> +} +""" + +empty_shape_tensor = """ +func.func @main() -> (tensor> {tensorrt.host_tensor}) { + %c = stablehlo.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32> + %c_0 = stablehlo.constant dense<2> : tensor + %c_1 = stablehlo.constant dense<1> : tensor<1xi32> + %c_2 = stablehlo.constant dense<2> : tensor<1xi32> + %c_3 = stablehlo.constant dense<2> : tensor + %c_4 = stablehlo.constant dense<2> : tensor<1xi32> + %0 = stablehlo.concatenate %c_2, %c_4, %c_1, dim = 0 : (tensor<1xi32>, 
tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> + %1 = stablehlo.dynamic_reshape %c, %0 : (tensor<2x2xi32>, tensor<3xi32>) -> tensor + %c_5 = stablehlo.constant dense<2> : tensor + %c_6 = stablehlo.constant dense<2> : tensor<1xi32> + %c_7 = stablehlo.constant dense<2> : tensor + %c_8 = stablehlo.constant dense<2> : tensor<1xi32> + %c_9 = stablehlo.constant dense<0> : tensor + %c_10 = stablehlo.constant dense<0> : tensor<1xi32> + %2 = stablehlo.concatenate %c_6, %c_8, %c_10, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> + %3 = stablehlo.dynamic_broadcast_in_dim %1, %2, dims = [0, 1, 2] : (tensor, tensor<3xi32>) -> tensor + %c_11 = stablehlo.constant dense<2> : tensor<1xi32> + %c_12 = stablehlo.constant dense<> : tensor<0xi32> + %c_13 = stablehlo.constant dense<> : tensor<0xi32> + %4 = stablehlo.compare EQ, %c_12, %c_13 : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi1> + %5 = stablehlo.select %4, %c_12, %c_12 : tensor<0xi1>, tensor<0xi32> + %6 = stablehlo.dynamic_broadcast_in_dim %c_7, %5, dims = [] : (tensor, tensor<0xi32>) -> tensor + %7 = stablehlo.dynamic_broadcast_in_dim %c_9, %5, dims = [] : (tensor, tensor<0xi32>) -> tensor + %8 = stablehlo.multiply %6, %7 : tensor + %9 = stablehlo.reshape %8 : (tensor) -> tensor<1xi32> + %10 = stablehlo.concatenate %c_11, %9, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %11 = stablehlo.dynamic_reshape %3, %10 : (tensor, tensor<2xi32>) -> tensor + %c0 = arith.constant 0 : index + %dim = tensor.dim %11, %c0 : tensor + %c1 = arith.constant 1 : index + %dim_14 = tensor.dim %11, %c1 : tensor + %12 = bufferization.alloc_tensor(%dim, %dim_14) {memory_space = #plan.memory_space} : tensor> + %13 = bufferization.materialize_in_destination %11 in %12 : (tensor, tensor>) -> tensor> + %cast = tensor.cast %13 : tensor> to tensor> + return %cast : tensor> +} +""" + + +# The RuntimeClient can and should persist across multiple Executables, RuntimeSessions, etc. +# It is primarily an interface for creating and manipulating buffers. +client = runtime.RuntimeClient() +stream = client.create_stream() +devices = client.get_devices() + + +def compile_executable(program, debug=False): + # Build/parse the main function. + with ir.Context() as context: + m = ir.Module.parse(program) + + # Use the compiler API to compile to executable. 
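# (Editor's note, illustrative only and not part of the patch.) The
# "--enable-non-dps-returns" flag passed below opts this test into the
# non-DPS calling convention: result buffers are allocated by the runtime's
# output allocator and handed back from execute_function, rather than being
# written into caller-provided output arguments.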
+ client = compiler.CompilerClient(context) + c_opts = [ + "--tensorrt-builder-opt-level=3", + "--tensorrt-strongly-typed=false", + "--entrypoint=main", + "--enable-non-dps-returns", + ] + opts = compiler.StableHLOToExecutableOptions(client, c_opts) + if debug: + opts.set_debug_options(False, [], "tmp") + exe = compiler.compiler_stablehlo_to_executable(client, m.operation, opts) + return exe + + +def test_single_return(): + exe = compile_executable(single_return) + session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0) + session = runtime.RuntimeSession(session_options, exe) + arg0 = client.create_memref( + np.arange(0.0, 24.0, dtype=np.float32).reshape(2, 3, 4).data, + device=devices[0], + stream=stream, + ) + results = session.execute_function( + "main", in_args=[arg0], stream=stream, client=client + ) + + output = np.asarray(client.copy_to_host(results[0], stream=stream)) + stream.sync() + + print(output) + + +def test_scalar_return(): + exe = compile_executable(scalar_return) + session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0) + session = runtime.RuntimeSession(session_options, exe) + arg0 = client.create_memref( + np.arange(0.0, 24.0, dtype=np.float32).reshape(2, 3, 4).data, + device=devices[0], + stream=stream, + ) + results = session.execute_function( + "main", in_args=[arg0], stream=stream, client=client + ) + + print(results[0].data) + + +def test_mixed_return(): + exe = compile_executable(mixed_return) + session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0) + session = runtime.RuntimeSession(session_options, exe) + arg0 = client.create_memref( + np.arange(0.0, 24.0, dtype=np.float32).reshape(2, 3, 4).data, + device=devices[0], + stream=stream, + ) + results = session.execute_function( + "main", in_args=[arg0], stream=stream, client=client + ) + + assert type(results[0]) == runtime.MemRefValue + assert type(results[1]) == runtime.ScalarValue + + output = np.asarray(client.copy_to_host(results[0], stream=stream)) + stream.sync() + + print(output) + print(results[1].data) + + +def test_multiple_return(): + exe = compile_executable(multiple_return) + session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0) + session = runtime.RuntimeSession(session_options, exe) + arg0 = client.create_memref( + np.arange(0.0, 24.0, dtype=np.float32).reshape(2, 3, 4).data, + device=devices[0], + stream=stream, + ) + results = session.execute_function( + "main", in_args=[arg0], stream=stream, client=client + ) + + output_0 = np.asarray(client.copy_to_host(results[0], stream=stream)) + output_1 = np.asarray(client.copy_to_host(results[1], stream=stream)) + + stream.sync() + + print(output_0) + print(output_1) + + +def test_dynamic_shape(): + exe = compile_executable(dynamic_shape) + session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0) + session = runtime.RuntimeSession(session_options, exe) + arg0 = client.create_memref( + np.arange(0.0, 8.0, dtype=np.float32).reshape((4, 2)).data, + device=devices[0], + stream=stream, + ) + arg1 = client.create_memref( + np.ones((4, 2), dtype=np.float32).data, device=devices[0], stream=stream + ) + + results = session.execute_function( + "main", in_args=[arg0, arg1], stream=stream, client=client + ) + + output = np.asarray(client.copy_to_host(results[0], stream=stream)) + stream.sync() + + print(output) + + +def test_session_tracking_d2h(): + exe = compile_executable(session_tracking_h2h) + session_options = runtime.RuntimeSessionOptions(num_devices=1, 
device_id=0) + session = runtime.RuntimeSession(session_options, exe) + results = session.execute_function("main", in_args=[], stream=stream, client=client) + stream.sync() + print(np.asarray(results[0])) + + +def test_empty_shape_tensor(): + exe = compile_executable(empty_shape_tensor) + session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0) + session = runtime.RuntimeSession(session_options, exe) + results = session.execute_function("main", in_args=[], stream=stream, client=client) + stream.sync() + print(np.asarray(results[0])) + + +if __name__ == "__main__": + print("Test: single return") + test_single_return() + # CHECK-LABEL: Test: single return + # CHECK: [[[ 0. 2. 4. 6.] + # CHECK: [ 8. 10. 12. 14.] + # CHECK: [16. 18. 20. 22.]] + # CHECK: + # CHECK: [[24. 26. 28. 30.] + # CHECK: [32. 34. 36. 38.] + # CHECK: [40. 42. 44. 46.]]] + + print("Test: multiple return") + test_multiple_return() + # CHECK-LABEL: Test: multiple return + # CHECK: [[[ 0. 2. 4. 6.] + # CHECK: [ 8. 10. 12. 14.] + # CHECK: [16. 18. 20. 22.]] + # CHECK: + # CHECK: [[24. 26. 28. 30.] + # CHECK: [32. 34. 36. 38.] + # CHECK: [40. 42. 44. 46.]]] + # CHECK: [[[ 0. 3. 6. 9.] + # CHECK: [12. 15. 18. 21.] + # CHECK: [24. 27. 30. 33.]] + # CHECK: + # CHECK: [[36. 39. 42. 45.] + # CHECK: [48. 51. 54. 57.] + # CHECK: [60. 63. 66. 69.]]] + + print("Test: dynamic shape") + test_dynamic_shape() + # CHECK-LABEL: Test: dynamic shape + # CHECK: [[1. 2.] + # CHECK: [3. 4.] + # CHECK: [5. 6.] + # CHECK: [7. 8.]] + + print("Test: device to host copy") + test_session_tracking_d2h() + # CHECK-LABEL: Test: device to host copy + # CHECK: [1 2] + + print("Test: empty shape tensor") + test_empty_shape_tensor() + # CHECK-LABEL: Test: empty shape tensor + # CHECK: [] + + print("Test: scalar return") + test_scalar_return() + # CHECK-LABEL: Test: scalar return + # CHECK: 3 + print("Test: mixed return") + + test_mixed_return() + # CHECK-LABEL: Test: mixed return + # CHECK: [[[ 0. 2. 4. 6.] + # CHECK: [ 8. 10. 12. 14.] + # CHECK: [16. 18. 20. 22.]] + # CHECK: + # CHECK: [[24. 26. 28. 30.] + # CHECK: [32. 34. 36. 38.] + # CHECK: [40. 42. 44. 46.]]] + # CHECK: 3 From 8710d9a674eb6c7c1dcc77491528213f96fa6695 Mon Sep 17 00:00:00 2001 From: Jhalak Patel Date: Sat, 9 Nov 2024 11:19:56 -0800 Subject: [PATCH 2/2] Fix memory leak --- .../include/mlir-executor/Runtime/API/API.h | 11 ++++++-- .../mlir-executor/Support/Allocators.h | 10 ++++--- .../executor/lib/CAPI/Runtime/Runtime.cpp | 4 +++ .../executor/lib/Runtime/API/API.cpp | 20 ++++++++++++-- .../lib/Runtime/Backend/Lua/LuaRuntime.cpp | 23 +++++++--------- .../Backend/Lua/Modules/CUDA/CUDAModule.cpp | 27 +++++++++++++++++++ .../Lua/Modules/TensorRT/TensorRTModule.cpp | 6 +++++ .../executor/lib/Support/Allocators.cpp | 18 +++++++------ .../python/bindings/Runtime/RuntimePyBind.cpp | 1 + 9 files changed, 92 insertions(+), 28 deletions(-) diff --git a/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h b/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h index 672a3f915..61b0057a5 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h @@ -795,6 +795,12 @@ class AllocTracker { /// Returns true if the ptr is released internally. 
bool isReleasedInternally(uintptr_t ptr) const; + /// Mark pointer for release after consumption + void markForReleaseAfterConsumption(uintptr_t ptr); + + /// Check if pointer is marked for release after consumption + bool isMarkedForReleaseAfterConsumption(uintptr_t ptr); + private: struct Metadata { std::atomic externalReferenceCount = {0}; @@ -802,6 +808,7 @@ class AllocTracker { // if this is true then it should be truelly released and untracked // when decrementExternalCount causes count to go to zero bool releasedInternally{false}; + bool releaseAfterConsumption{false}; PointerInfo info; }; @@ -870,7 +877,7 @@ class RuntimeSession { ExecutableView getExecutable() const { return executable; } - PinnedMemoryAllocator &getPinnedMemorAllocator() { + PinnedMemoryAllocator &getPinnedMemoryAllocator() { return *pinnedMemoryAllocator; } @@ -968,7 +975,7 @@ class RuntimeClient { ResourceTracker &getResourceTracker() { return resourceTracker; } /// Return the PinnedMemoryAllocator. - PinnedMemoryAllocator &getPinnedMemorAllocator() { + PinnedMemoryAllocator &getPinnedMemoryAllocator() { return pinnedMemoryAllocator; } diff --git a/mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h b/mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h index 7df229832..d31b2de75 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h @@ -104,7 +104,11 @@ class PinnedMemoryAllocator { PinnedMemoryAllocator(); ~PinnedMemoryAllocator(); - /// Untracks + /// Marks a pointer as client-managed, deferring its deallocation + /// This method is used when a pinned memory pointer is returned to the client + /// and its lifecycle is no longer managed by the PinnedMemoryAllocator. + /// Pointers marked this way will not be automatically freed in the + /// allocator's destructor. void untrack(uintptr_t ptr); StatusOr allocate(size_t size); @@ -117,8 +121,8 @@ class PinnedMemoryAllocator { private: EventPool eventPool; - /// Tracks all the pointers which need not to freed up. - static std::vector untrackedPtrs; + /// Stores pointers to memory blocks that are now managed by the client. + static std::vector clientManagedPtrs; /// Tracks all blocks allocated by the allocator. struct BlockTracker; diff --git a/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp b/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp index d2a1d8f62..2939e060d 100644 --- a/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp +++ b/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp @@ -27,6 +27,7 @@ #include "mlir-executor/Runtime/API/API.h" #include "mlir-executor/Runtime/API/ExecutableFlatbuffer.h" #include "mlir-executor/Runtime/Backend/Lua/LuaRuntime.h" +#include "mlir-executor/Runtime/Support/Support.h" #include "mlir-executor/Support/Status.h" #include "mlir/Support/FileUtilities.h" #include "llvm/Support/Debug.h" @@ -325,6 +326,8 @@ MTRT_Status mtrtMemRefCreateExternal( MTRT_Status mtrtMemRefValueDestroyAsync(MTRT_MemRefValue buffer, MTRT_Stream stream) { MemRefValue *memref = unwrap(buffer); + MTRT_DBGF("destroying memref pointer 0x%lx asynchronously", + memref->getMemory()); Status s = memref->getClient()->deallocate( std::unique_ptr(memref), mtrtStreamIsNull(stream) ? 
std::nullopt @@ -336,6 +339,7 @@ MTRT_Status mtrtMemRefValueDestroyAsync(MTRT_MemRefValue buffer, MTRT_Status mtrtMemRefValueDestroy(MTRT_MemRefValue buffer) { MemRefValue *memref = unwrap(buffer); + MTRT_DBGF("destroying memref pointer 0x%lx", memref->getMemory()); Status s = memref->getClient()->deallocate(std::unique_ptr(memref)); if (!s.isOk()) diff --git a/mlir-tensorrt/executor/lib/Runtime/API/API.cpp b/mlir-tensorrt/executor/lib/Runtime/API/API.cpp index 4a13ef80a..e04d242d5 100644 --- a/mlir-tensorrt/executor/lib/Runtime/API/API.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/API/API.cpp @@ -396,6 +396,20 @@ AllocTracker::~AllocTracker() { MTRT_DBGF("freed %zu bytes of unfreed memory", totalSize); } +void AllocTracker::markForReleaseAfterConsumption(uintptr_t ptr) { + assert(llvm::is_contained(map, ptr) && + llvm::formatv("Untracked pointer {0}", ptr).str().c_str()); + std::unique_ptr const &metadata = map.at(ptr); + metadata->releaseAfterConsumption = true; +} + +bool AllocTracker::isMarkedForReleaseAfterConsumption(uintptr_t ptr) { + assert(llvm::is_contained(map, ptr) && + llvm::formatv("Untracked pointer {0}", ptr).str().c_str()); + std::unique_ptr const &metadata = map.at(ptr); + return metadata->releaseAfterConsumption; +} + void AllocTracker::markReleasedInternally(uintptr_t ptr) { assert(llvm::is_contained(map, ptr) && llvm::formatv("Untracked pointer {0}", ptr).str().c_str()); @@ -472,6 +486,7 @@ void AllocTracker::track(PointerInfo info) { auto value = std::make_unique(); value->externalReferenceCount.store(0); value->releasedInternally = false; + value->releaseAfterConsumption = false; value->info = info; if (!contains(info.ptr)) { map.insert(std::make_pair(info.ptr, std::move(value))); @@ -669,6 +684,7 @@ ResourceTracker::~ResourceTracker() { void ResourceTracker::track(uintptr_t ptr, Deleter deleter) { assert(ptr && deleter && "expected valid ptr and deleter"); + MTRT_DBGF("tracking resource at 0x%lx", ptr); tracker.insert(std::make_pair(ptr, deleter)); } @@ -985,7 +1001,7 @@ RuntimeClient::copyToDevice(const MemRefValue &hostBufferImpl, // TODO: Currently, this implementation supports only row major packed // canonical layout (no padding). StatusOr pinnedMemory = - this->getPinnedMemorAllocator().allocate(totalBufferSize); + this->getPinnedMemoryAllocator().allocate(totalBufferSize); if (!pinnedMemory.isOk()) return pinnedMemory.getStatus(); @@ -1004,7 +1020,7 @@ RuntimeClient::copyToDevice(const MemRefValue &hostBufferImpl, reinterpret_cast(*cudaStream))); // Free pinned host memory asynchronously. 
- getPinnedMemorAllocator().freeAsync(pinnedMemory->ptr, *cudaStream); + getPinnedMemoryAllocator().freeAsync(pinnedMemory->ptr, *cudaStream); } else { MTRT_DBG("synchronously copying {0} (host) to {1} (device), size={2} bytes", hostBufferImpl.getVoidPtr(), (*deviceMemRef)->getVoidPtr(), diff --git a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp index 0ff8f0406..34d74a222 100644 --- a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp @@ -98,7 +98,7 @@ static void registerLuaRuntimeMethodsCommon( } void mlirtrt::runtime::registerLuaRuntimeMethods( - lua_State *state, const RuntimeSessionOptions &options, + lua_State *state, const RuntimeSessionOptions &options, PinnedMemoryAllocator *pinnedMemoryAllocator, AllocTracker *allocTracker, ResourceTracker *resourceTracker) { registerLuaRuntimeMethodsCommon(state, pinnedMemoryAllocator, allocTracker, @@ -152,7 +152,7 @@ LuaRuntimeSession::create(RuntimeSessionOptions options, // Register builtin methods. registerLuaRuntimeMethods(lua.lua_state(), session->getOptions(), - &session->getPinnedMemorAllocator(), + &session->getPinnedMemoryAllocator(), &session->getAllocTracker(), &session->getResourceTracker()); @@ -624,7 +624,7 @@ parseResults(const sol::protected_function_result &pfr, "This ptr is registered with the session and will now be tracked " "by the client as well.", allocPtr, static_cast(&session.getAllocTracker()), - static_cast(&session.getPinnedMemorAllocator()), + static_cast(&session.getPinnedMemoryAllocator()), static_cast(*client), static_cast(&(*client)->getAllocTracker())); @@ -632,14 +632,14 @@ parseResults(const sol::protected_function_result &pfr, // ownership and have the client assume PointerInfo info = session.getAllocTracker().get(allocPtr); session.getAllocTracker().untrack(info.ptr); + (*client)->getAllocTracker().track(info); - // It is possible that pinned memory also tracks the memory for - // deallocation. - session.getPinnedMemorAllocator().untrack(info.ptr); - - AllocTracker &allocator = (*client)->getAllocTracker(); - // if (!allocator.contains(info.ptr)) - allocator.track(info); + // Defer deallocation of this pinned memory pointer + // This pointer is likely still in use by the client and should not be + // immediately freed. By untracking it here, we ensure it won't be + // deallocated in the PinnedMemoryAllocator's destructor, allowing + // the client to manage its lifecycle. + session.getPinnedMemoryAllocator().untrack(info.ptr); // Create a memref so that client now tracks it. 
auto memref = MemRefValue::create( @@ -650,9 +650,6 @@ parseResults(const sol::protected_function_result &pfr, if (!memref.isOk()) return memref.getStatus(); - // Increment external reference count since we are returning a memref - allocator.incrementExternalCount(info.ptr); - results.push_back(std::move(*memref)); } diff --git a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/CUDA/CUDAModule.cpp b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/CUDA/CUDAModule.cpp index 912873b93..70914d5bc 100644 --- a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/CUDA/CUDAModule.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/CUDA/CUDAModule.cpp @@ -432,6 +432,15 @@ registerCudaMemoryManagementOps(sol::state_view &lua, cudaMemcpyDeviceToHost, stream), state); + // Check if the source pointer is marked for release after consumption + if (allocTracker->isMarkedForReleaseAfterConsumption(src)) { + // This pointer was allocated by TensorRT and used in a device-device + // or device-host copy operation. It's not wrapped in a memref, so it + // won't be released by external memref destruction. We need to + // explicitly free it. + SET_LUA_ERROR_IF_ERROR(runtime::safeDeallocate(*allocTracker, src), + state); + } }; lua["__cuda_memcpy_host_pinned2device"] = @@ -480,6 +489,15 @@ registerCudaMemoryManagementOps(sol::state_view &lua, cudaMemcpyDeviceToHost, stream), state); + // Check if the source pointer is marked for release after consumption + if (allocTracker->isMarkedForReleaseAfterConsumption(src)) { + // This pointer was allocated by TensorRT and used in a device-device + // or device-host copy operation. It's not wrapped in a memref, so it + // won't be released by external memref destruction. We need to + // explicitly free it. + SET_LUA_ERROR_IF_ERROR(runtime::safeDeallocate(*allocTracker, src), + state); + } }; lua["__cuda_memcpy_device2device"] = [allocTracker]( sol::this_state state, @@ -504,6 +522,15 @@ registerCudaMemoryManagementOps(sol::state_view &lua, cudaMemcpyDeviceToDevice, stream), state); + // Check if the source pointer is marked for release after consumption + if (allocTracker->isMarkedForReleaseAfterConsumption(src)) { + // This pointer was allocated by TensorRT and used in a device-device + // or device-host copy operation. It's not wrapped in a memref, so it + // won't be released by external memref destruction. We need to + // explicitly free it. + SET_LUA_ERROR_IF_ERROR(runtime::safeDeallocate(*allocTracker, src), + state); + } return; }; } diff --git a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp index 943493f99..e33db9923 100644 --- a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp @@ -155,6 +155,12 @@ class OutputAllocatorImpl : public nvinfer1::IOutputAllocator { if (memory.isOk()) { mOutputPtr = (*memory).ptr; mOutputSize = memory->size; + // Mark the output pointer for release after consumption + // This is necessary because TensorRT-allocated pointers used in device-device + // or device-host copies may not be wrapped in a memref and tracked by the client. + // By marking it here, we ensure it will be explicitly freed after it's consumed + // in copy operations, preventing memory leaks. 
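// (Editorial sketch, not part of the patch.) Putting the pieces above together,
// the release-after-consumption flow is roughly:
//
//   uintptr_t out = ...;                          // buffer produced by this output allocator
//   tracker.markForReleaseAfterConsumption(out);  // flag it for deferred release
//   // ... a generated __cuda_memcpy_* call later consumes `out` as its source ...
//   if (tracker.isMarkedForReleaseAfterConsumption(out))
//     safeDeallocate(tracker, out);               // freed right after the copy, avoiding a leak
//
// where `tracker` stands in for the session's AllocTracker and `out` for
// mOutputPtr; the explicit free only matters when the buffer is consumed by a
// copy and never wrapped in a memref. The statement below performs the marking
// step for this allocator's output buffer.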
+ mTracker->markForReleaseAfterConsumption(mOutputPtr); MTRT_DBGF( "tensorrt module output allocator allocating %lu bytes at 0x%lx", mOutputSize, mOutputPtr); diff --git a/mlir-tensorrt/executor/lib/Support/Allocators.cpp b/mlir-tensorrt/executor/lib/Support/Allocators.cpp index 1a542e187..3e1f999e4 100644 --- a/mlir-tensorrt/executor/lib/Support/Allocators.cpp +++ b/mlir-tensorrt/executor/lib/Support/Allocators.cpp @@ -206,7 +206,7 @@ static void cudaFreeHostWrapper(uintptr_t ptr) { #endif } -std::vector PinnedMemoryAllocator::untrackedPtrs; +std::vector PinnedMemoryAllocator::clientManagedPtrs; struct PinnedMemoryAllocator::BlockTracker { std::set blocks; @@ -218,13 +218,14 @@ struct PinnedMemoryAllocator::BlockTracker { "[PinnedMemoryAllocator] Releasing block tracker that has %lu blocks", blocks.size()); for (Block *block : blocks) { - if (std::find(PinnedMemoryAllocator::untrackedPtrs.begin(), - PinnedMemoryAllocator::untrackedPtrs.end(), - block->ptr) == PinnedMemoryAllocator::untrackedPtrs.end()) { + if (std::find(clientManagedPtrs.begin(), clientManagedPtrs.end(), + block->ptr) == clientManagedPtrs.end()) { ALLOC_DBGF("[PinnedMemoryAllocator] releasing block %lu of size %lu", block->ptr, block->size); cudaFreeHostWrapper(block->ptr); } + // Blocks found in clientManagedPtrs are not freed here, as they are now + // managed by the client } } }; @@ -251,7 +252,7 @@ StatusOr PinnedMemoryAllocator::allocate(size_t size) { if (lowerBound != freeBlocks->set.end()) { Block *result = *lowerBound; freeBlocks->set.erase(result); - ALLOC_DBGF("re-using block %lu of size %lu", result->ptr, result->size); + ALLOC_DBGF("re-using block %lx of size %lu", result->ptr, result->size); return PinnedMemoryBlock{result->ptr, result->size}; } @@ -275,10 +276,11 @@ StatusOr PinnedMemoryAllocator::allocate(size_t size) { #endif } -// Free the given block. +std::vector clientManagedPtrs; + void PinnedMemoryAllocator::untrack(uintptr_t ptr) { - if (!llvm::is_contained(untrackedPtrs, ptr)) { - untrackedPtrs.emplace_back(ptr); + if (!llvm::is_contained(clientManagedPtrs, ptr)) { + clientManagedPtrs.emplace_back(ptr); } } diff --git a/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp b/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp index f35085b85..24e75bf19 100644 --- a/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp +++ b/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp @@ -345,6 +345,7 @@ static std::unique_ptr createMemRef( static std::unique_ptr createMemRefViewFromDLPack(PyRuntimeClient &client, py::capsule capsule, std::optional assertCanonicalStrides) { + DLManagedTensor *managedTensor = static_cast( PyCapsule_GetPointer(capsule.ptr(), "dltensor"));