Commit 9118b9a: Fix dynamic shapes and multiple return values

jhalakpatel committed Oct 27, 2024 (parent: 61a1bd9)
Showing 14 changed files with 212 additions and 319 deletions.
@@ -161,49 -161,4 @@ def TensorRTRuntime_EnqueueAllocOp : TensorRTRuntime_Op<"enqueue_alloc", [
}];
}

//===----------------------------------------------------------------------===//
// EnqueueAllocOp
//===----------------------------------------------------------------------===//

def TensorRTRuntime_EnqueueAllocOp : TensorRTRuntime_Op<"enqueue_alloc", [
DeclareOpInterfaceMethods<TensorKindOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
]> {
let description = [{
Asynchronously executes the computation represented by the
`execution_context` on the specified CUDA stream. This operation
can accept inputs of either tensor or memref types and returns
results of either tensor or memref types.
}];

let arguments = (ins
TensorRTRuntime_Context:$execution_context,
CUDA_Stream:$stream,
Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$inputs,
OptionalAttr<DenseI64ArrayAttr>:$host_tensor_args
);

let results = (outs Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$results);

let assemblyFormat = [{
$execution_context `stream` `(` $stream `)` ` `
(`host_tensor_args` $host_tensor_args^ ` ` )?
`(` $inputs `)`
attr-dict `:` functional-type($inputs, $results)
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
/// Return true if the operand is a host tensor argument.
bool isOperandOnHost(OpOperand *operand) {
unsigned operandIdx = operand->getOperandNumber();
if(std::optional<ArrayRef<int64_t>> indices = getHostTensorArgs()) {
return llvm::is_contained(*indices, operandIdx - 2);
}
return false;
}
}];
}

#endif // MLIR_TENSORRT_DIALECT_TENSORRT_IR_TENSORRTRUNTIMEOPS_TD
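The `host_tensor_args` indexing in `isOperandOnHost` above deserves a note: the attribute's indices are relative to `$inputs`, while `OpOperand` numbers start at `$execution_context` (operand #0) and `$stream` (operand #1), hence the `operandIdx - 2` shift. A minimal C++ sketch of the same check, with hypothetical names not taken from this repository:

  #include "llvm/ADT/ArrayRef.h"
  #include "llvm/ADT/STLExtras.h"

  // Sketch: `hostTensorArgs` holds indices into $inputs; operand numbering
  // also counts the execution context (#0) and stream (#1).
  static bool isHostInput(llvm::ArrayRef<int64_t> hostTensorArgs,
                          unsigned operandNumber) {
    return operandNumber >= 2 &&
           llvm::is_contained(hostTensorArgs, operandNumber - 2);
  }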
12 changes: 6 additions & 6 deletions mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp
@@ -223,7 +223,7 @@ StableHLOToExecutableOptions::StableHLOToExecutableOptions(
llvm::cl::desc("Don't allow TensorRt clusters to contain host tensor "
"calculations (but they can still be inputs)"));
addOption(
"use-non-dps-call-conv", useNonDPSCallConv, llvm::cl::init(false),
"enable-non-dps-returns", enableNonDPSReturns, llvm::cl::init(false),
llvm::cl::desc(
"allow tensorrt based output allocations using output allocator"));
addOption("executor-index-bitwidth", executorIndexBitwidth,
@@ -307,7 +307,7 @@ void StableHloToExecutableTask::buildStablehloClusteringPipeline(
plan::StablehloClusteringPassOptions clusteringOpts{};
clusteringOpts.disallowHostTensorsInTensorRTClusters =
opts.disallowHostTensorsInTensorRTClusters;
clusteringOpts.useNonDPSCallConv = opts.useNonDPSCallConv;
clusteringOpts.enableNonDPSReturns = opts.enableNonDPSReturns;
clusteringOpts.entrypoint = opts.entrypoint;
plan::buildPlanSegmentationPipeline(pm, clusteringOpts);

@@ -342,7 +342,7 @@ void StableHloToExecutableTask::buildPostClusteringPipeline(
// Perform bufferization.
pm.addPass(createMemRefCastEliminationPass());
plan::PlanAllocTensorsPassOptions allocTensorsOpts{};
allocTensorsOpts.useNonDPSCallConv = opts.useNonDPSCallConv;
allocTensorsOpts.enableNonDPSReturns = opts.enableNonDPSReturns;
pm.addPass(plan::createPlanAllocTensorsPass(allocTensorsOpts));
pm.addPass(plan::createPlanBufferizePass());
pm.addPass(createMemRefCastEliminationPass());
@@ -532,8 +532,8 @@ struct ClusteringPipelineCliOpts
*this, "device-compute-capability",
llvm::cl::desc("target device compute capability (SM version)"),
llvm::cl::init(60)};
Option<bool> useNonDPSCallConv{
*this, "use-non-dps-call-conv",
Option<bool> enableNonDPSReturns{
*this, "enable-non-dps-returns",
llvm::cl::desc(
"allow tensorrt based output allocations using output allocator"),
llvm::cl::init(false)};
@@ -564,7 +564,7 @@ static StableHLOToExecutableOptions populateStablehloClusteringPipelineOpts(
opts.deviceComputeCapability = cliOpts.deviceComputeCapability;
opts.deviceMaxSharedMemoryPerBlockKb =
cliOpts.deviceMaxSharedMemoryPerBlockKb;
opts.useNonDPSCallConv = cliOpts.useNonDPSCallConv;
opts.enableNonDPSReturns = cliOpts.enableNonDPSReturns;
opts.shouldInferDeviceOptionsFromHost = cliOpts.inferDeviceOptionsFromHost;
opts.entrypoint = cliOpts.entrypoint;
return opts;
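The rename from `use-non-dps-call-conv` to `enable-non-dps-returns` is applied consistently across the CLI option, the clustering options, and the alloc-tensors pass options. As a hedged sketch, assuming only the `StableHLOToExecutableOptions` struct and the `enableNonDPSReturns` field visible in this diff, the knob would be driven programmatically like so:

  // Sketch: request TensorRT-managed result allocation (non-DPS returns)
  // rather than caller-provided output buffers.
  void enableAllocResults(StableHLOToExecutableOptions &opts) {
    opts.enableNonDPSReturns = true;
  }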
@@ -93,6 +93,8 @@ convertCallOp(Operation *op, IRRewriter &rewriter,
SmallVector<int64_t> hostTensorArgs;
for (auto [idx, arg] : llvm::enumerate(trtFunc.getArguments())) {
const TensorKindLattice *kind = solver.lookupState<TensorKindLattice>(arg);
if (!isa<RankedTensorType>(arg.getType()))
continue;
RankedTensorType rtt = cast<RankedTensorType>(arg.getType());
// To be conservative, we only do this if type is i32 and num elements
// <= 8.
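The added `isa` guard keeps the following `cast` from asserting on non-tensor arguments. An equivalent, slightly tighter idiom (a sketch, same behavior):

  // dyn_cast yields a null type on mismatch, folding the check and the cast.
  auto rtt = dyn_cast<RankedTensorType>(arg.getType());
  if (!rtt)
    continue;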
@@ -65,17 +65,26 @@ struct RemoveWithValuesRewriter : public OpRewritePattern<plan::WithValuesOp> {
} // namespace

/// Get a map from `tensorrt.func` functions to associated `tensorrt.call`
/// operations.
static llvm::DenseMap<func::FuncOp, SmallVector<tensorrt::CallOp>>
/// and `tensorrt.call_alloc` operations.
static llvm::DenseMap<func::FuncOp, SmallVector<Operation *>>
getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) {
llvm::DenseMap<func::FuncOp, SmallVector<tensorrt::CallOp>> map;
op->walk([&](tensorrt::CallOp callOp) {
func::FuncOp func = callOp.getFuncCallee(collection);
if (map.contains(func)) {
map[func].push_back(callOp);
llvm::DenseMap<func::FuncOp, SmallVector<Operation *>> map;
op->walk([&](Operation *callOp) {
if (!isa<tensorrt::CallOp, tensorrt::CallAllocOp>(callOp))
return;
}
map.insert(std::make_pair(func, SmallVector<tensorrt::CallOp>{callOp}));

func::FuncOp func;
if (auto call = dyn_cast<tensorrt::CallOp>(callOp))
func = call.getFuncCallee(collection);
else if (auto callAlloc = dyn_cast<tensorrt::CallAllocOp>(callOp))
func = callAlloc.getFuncCallee(collection);
else
return;

if (map.count(func))
map[func].push_back(callOp);
else
map.insert({func, SmallVector<Operation *>{callOp}});
});
return map;
}
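Since both call-like ops expose `getFuncCallee`, the cascade of `dyn_cast`s above could also be written with `llvm::TypeSwitch`; a sketch of that alternative, not the committed code:

  #include "llvm/ADT/TypeSwitch.h"

  func::FuncOp callee =
      llvm::TypeSwitch<Operation *, func::FuncOp>(callOp)
          .Case([&](tensorrt::CallOp op) { return op.getFuncCallee(collection); })
          .Case([&](tensorrt::CallAllocOp op) { return op.getFuncCallee(collection); })
          .Default([](Operation *) { return func::FuncOp(); });
  if (!callee)
    return;

Note also that the trailing `count`/`insert` pair can collapse to `map[func].push_back(callOp);`, since `DenseMap::operator[]` default-constructs a missing entry.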
@@ -84,7 +93,7 @@ getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) {
/// `tensorrt.call` operations.
static LogicalResult removeUnusedArgs(SymbolTableCollection &collection,
ModuleOp op, func::FuncOp funcOp,
ArrayRef<tensorrt::CallOp> callOps) {
ArrayRef<Operation *> callOps) {
llvm::SmallBitVector unusedArgs(funcOp.getNumArguments(), 0);
for (BlockArgument arg : funcOp.getArguments()) {
if (arg.use_empty())
Expand All @@ -99,8 +108,16 @@ static LogicalResult removeUnusedArgs(SymbolTableCollection &collection,
funcOp.eraseArgument(i);

// Update the call ops.
for (tensorrt::CallOp callOp : callOps)
callOp.getInputsMutable().erase(i);
for (Operation *callOp : callOps) {
if (auto call = dyn_cast<tensorrt::CallOp>(callOp)) {
call.getInputsMutable().erase(i);
} else if (auto callAlloc = dyn_cast<tensorrt::CallAllocOp>(callOp)) {
callAlloc.getInputsMutable().erase(i);
} else {
llvm::errs() << "Unexpected operation type in callOps\n";
callOp->dump();
}
}
}

return success();
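Because `getTensorRTFunctionCallMap` only ever inserts `tensorrt.call` and `tensorrt.call_alloc` operations, the `llvm::errs()` fallback above should be unreachable. A sketch of a variant that asserts the invariant instead of logging:

  for (Operation *callOp : callOps) {
    if (auto call = dyn_cast<tensorrt::CallOp>(callOp))
      call.getInputsMutable().erase(i);
    else if (auto callAlloc = dyn_cast<tensorrt::CallAllocOp>(callOp))
      callAlloc.getInputsMutable().erase(i);
    else
      llvm_unreachable("call map should only hold tensorrt.call/call_alloc");
  }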
@@ -268,8 +268,82 @@ static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,

static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,
plan::InlineClosedAllocGroupOp op) {
return op.emitError("outlinining inline closed alloc group ops to tensorrt "
"dialect is not yet implemented");
tensorrt::TensorRTModuleOp trtModuleOp = getOrCreateTensorRTModuleOp(op);
auto funcArgTypes = llvm::to_vector(TypeRange(op.getInputs()));
FailureOr<FunctionOpInterface> func = createOutlinedFunc(
rewriter, op.getLoc(), op, trtModuleOp, "tensorrt_cluster",
"cluster.tensorrt", TypeRange(op.getInputs()),
op.getYield()->getOperandTypes());
if (failed(func))
return failure();
assert(func->getFunctionBody().getBlocks().size() == 1 &&
"expected body with one block");
func->setPublic();

rewriter.setInsertionPoint(op);

auto callOp = rewriter.create<tensorrt::CallAllocOp>(
op.getLoc(), op.getResultTypes(), op.getInputs(),
SymbolRefAttr::get(trtModuleOp.getNameAttr(),
{FlatSymbolRefAttr::get(*func)}));

// Populate the function arguments attributes.
for (unsigned i = 0; i < (*func).getNumArguments(); i++) {
BoundsAttr srcAttr = cast<BoundsAttr>(op.getInputAttrs()[i]);
// We may have scalar (index|signless int)-typed values since we haven't
// eliminated `plan.(with_shape|with_values)` ops yet.
if (!op.argHasTensorType(i) || srcAttr.isNone())
continue;
FailureOr<tensorrt::ShapeProfileAttr> boundAttr =
getTensorRTShapeProfile(srcAttr, op.getInputs()[i]);
if (failed(boundAttr))
return op->emitOpError("failed to create TensorRT shape profile "
"attribute from Plan BoundsAttr for argument #")
<< i << " (" << srcAttr << ")";
if (srcAttr.isShapeBound()) {
func->setArgAttr(i,
tensorrt::TensorRTDialect::getShapeProfileArgAttrName(),
*boundAttr);
continue;
}
assert(srcAttr.isValueBound() && "expected value bound or shape bound");
func->setArgAttr(
i, tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName(),
*boundAttr);
func->setArgAttr(i, mlir::getHostTensorArgAttrName(),
rewriter.getUnitAttr());
}

// Populate the function entry block.
rewriter.eraseBlock(&func->getFunctionBody().front());

// Move private decomposition funcs associated with all `stablehlo.composite`
// ops to the `tensorrt.module` op. This is needed since `tensorrt.module` op
// has its own symbol table.
SymbolTableCollection symbolTable;
for (auto compositeOp : op.getBody().getOps<stablehlo::CompositeOp>()) {
auto decompositionFunc = dyn_cast_if_present<func::FuncOp>(
symbolTable.lookupSymbolIn(op->getParentOfType<ModuleOp>(),
compositeOp.getDecompositionAttr()));
if (!decompositionFunc)
return emitError(compositeOp.getLoc())
<< "failed to lookup stablehlo.composite decomposition "
"function: "
<< compositeOp.getDecompositionAttr();
rewriter.moveOpAfter(decompositionFunc, func->getOperation());
}

// Move region op operations to the func body.
Operation *regionYieldOp = op.getYield();
rewriter.inlineRegionBefore(op.getRegion(), func->getFunctionBody(),
func->getFunctionBody().end());
rewriter.setInsertionPoint(regionYieldOp);
rewriter.replaceOpWithNewOp<func::ReturnOp>(regionYieldOp,
regionYieldOp->getOperands());

// replace the original region results.
rewriter.replaceOp(op, callOp);
return success();
}

/// Create outlined functions for each `scf.execute_region` operation within
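At a high level the new path mirrors the existing DPS outlining above it: the region body becomes a public function inside `tensorrt.module`, and the group op is replaced by a `tensorrt.call_alloc` whose results the runtime allocates. A rough orientation sketch in comments (the op syntax here is abbreviated and hypothetical):

  // Before:  %r = plan.inline_closed_alloc_group inputs(%a, %b) { ...; yield %v }
  // After:   %r = tensorrt.call_alloc @trt_module::@tensorrt_cluster(%a, %b)
  //          with the region body moved into @tensorrt_cluster.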
55 changes: 0 additions & 55 deletions mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h
@@ -128,61 +128,6 @@ class PinnedMemoryAllocator {
std::unique_ptr<BlockEventQueue> pendingBlockEvents;
};

/// Manages output tensor descriptors for TensorRT execution.
class OutputDescriptor {
public:
/// Constructs an OutputDescriptor from a raw pointer.
///
/// \param ptr Raw pointer to the descriptor data.
OutputDescriptor(uintptr_t ptr);

/// Returns the number of results in the descriptor.
int64_t getNumberOfResults() const;

/// Gets the rank of a specific tensor result.
///
/// \param resultIndex Index of the result.
unsigned getRank(int resultIndex) const;

/// Sets the data pointer for a specific tensor result.
///
/// \param resultIndex Index of the result to update.
/// \param ptr New data pointer value.
void setTensorDataPtr(int resultIndex, uintptr_t ptr);

/// Sets the shape for a specific tensor result.
///
/// \param resultIndex Index of the result to update.
/// \param shape Vector containing the shape dimensions.
void setShape(int resultIndex, const std::vector<int64_t> &shape);

/// Sets the stride for a specific tensor result.
///
/// \param resultIndex Index of the result to update.
/// \param stride Vector containing the stride values.
void setStride(int resultIndex, const std::vector<int64_t> &stride);

private:
/// Pointer to the raw descriptor data.
int64_t *mData;

/// Total size of the descriptor data.
size_t mSize;

/// Calculates the index for a specific result in the descriptor.
size_t getIndexForResult(int resultIndex) const;

/// Calculates the total size of the descriptor.
static size_t calculateTotalSize(uintptr_t ptr);

/// Calculates the offset for a specific result in the descriptor.
static size_t calculateOffsetForResult(const int64_t *desc,
int64_t resultIndex);

/// Fixed fields corresponding to rank, data ptr.
static constexpr int OUTPUT_DESC_FIXED_FIELDS = 2;
};

} // namespace mlirtrt

#endif // MLIR_TENSORRT_SUPPORT_ALLOCATORS_H
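The removed `OutputDescriptor` wrapped a packed buffer of `int64_t` words: a result count, then per result two fixed fields (rank and data pointer, matching `OUTPUT_DESC_FIXED_FIELDS = 2`) followed by `rank` shape words and `rank` stride words. A sketch of the offset arithmetic under that layout; the field order is an assumption, not recovered from the deleted implementation:

  #include <cstddef>
  #include <cstdint>

  // Assumed layout: [numResults, {rank, dataPtr, shape..., stride...}*].
  static size_t offsetForResult(const int64_t *desc, int64_t resultIndex) {
    size_t off = 1; // skip numResults
    for (int64_t r = 0; r < resultIndex; ++r) {
      int64_t rank = desc[off]; // fixed field 0: rank (field 1 is dataPtr)
      off += 2 + 2 * rank;      // fixed fields plus shape and stride words
    }
    return off;
  }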
11 changes: 6 additions & 5 deletions mlir-tensorrt/executor/lib/Runtime/API/API.cpp
@@ -410,7 +410,7 @@ void AllocTracker::incrementExternalCount(uintptr_t ptr) {
llvm::formatv("Untracked pointer {0}", ptr).str().c_str());
std::unique_ptr<Metadata> const &metadata = map.at(ptr);
int32_t ref = ++metadata->externalReferenceCount;
MTRT_DBG("Incremented external reference for pointer %d to %d", ptr, ref);
MTRT_DBGF("Incremented external reference for pointer 0x%lx to %d", ptr, ref);
}

void AllocTracker::decrementExternalCount(uintptr_t ptr) {
@@ -422,11 +422,12 @@ void AllocTracker::decrementExternalCount(uintptr_t ptr) {
llvm::formatv("External reference count cannot be negative: {0}", ref)
.str()
.c_str());
MTRT_DBG("Decremented external reference for pointer %d to %d", ptr, ref);
MTRT_DBGF("Decremented external reference for pointer 0x%lx to %d", ptr, ref);
if (ref == 0 && metadata->releasedInternally) {
MTRT_DBG("External reference to an internally released pointer %d is 0, "
"try deallocating pointer memory of size %lu",
ptr, ref, metadata->info.size);
MTRT_DBGF(
"External reference to an internally released pointer 0x%lx is 0, "
"try deallocating pointer memory of size %lu",
ptr, metadata->info.size);
Status s = safeDeallocate(*this, metadata->info.ptr);
if (!s.isOk())
MTRT_DBGF("error while deallocating dangling memory: %s",
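The logging changes above fix real format-string bugs: the old calls printed a `uintptr_t` with `%d`, which truncates and is undefined behavior for a 64-bit argument, and the second call passed one more argument than its format string consumed. A small standalone illustration of portable pointer formatting:

  #include <cinttypes>
  #include <cstdint>
  #include <cstdio>

  // PRIxPTR selects the right conversion for uintptr_t on any platform;
  // the "0x%lx" used in the fix assumes an LP64 target.
  void logPointer(uintptr_t ptr, int32_t refCount) {
    std::printf("pointer 0x%" PRIxPTR " refcount %d\n", ptr, refCount);
  }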