
Commit

add passes
pranavm-nvidia committed Dec 11, 2024
1 parent 4c232ea commit 08f90f3
Showing 2 changed files with 70 additions and 4 deletions.
@@ -25,5 +25,6 @@
// Common helpers
//===----------------------------------------------------------------------===//

-mlir::LogicalResult setupPassManager(mlir::PassManager &pm,
-                                     const DebugOptions &options);
+mlir::LogicalResult
+setupPassManager(mlir::PassManager &pm,
+                 const mlirtrt::compiler::DebugOptions &options);
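
For context, a minimal sketch of how a caller might drive a helper with this signature. This is an illustration only, not code from the repository: the wrapper function and the default construction of DebugOptions are assumptions, and the declaring header is assumed to be the PassManagerUtils.h included in the diff below.

#include "mlir-tensorrt/Compiler/PassManagerUtils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"

// Hypothetical caller: configure a fresh PassManager with the shared debug
// options before a compilation task appends its own passes.
static mlir::LogicalResult runWithDebugOptions(mlir::MLIRContext &ctx) {
  mlir::PassManager pm(&ctx);
  // DebugOptions is assumed to be default-constructible for this sketch.
  mlirtrt::compiler::DebugOptions debugOptions;
  if (mlir::failed(setupPassManager(pm, debugOptions)))
    return mlir::failure();
  // ...task-specific passes would be appended and pm.run(...) invoked here...
  return mlir::success();
}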
69 changes: 67 additions & 2 deletions mlir-tensorrt/compiler/lib/Compiler/TensorRTToExecutable.cpp
@@ -22,14 +22,23 @@
#include "mlir-tensorrt/Compiler/TensorRTToExecutable.h"
#include "mlir-executor/Conversion/Passes.h"
#include "mlir-executor/Executor/Transforms/Passes.h"
#include "mlir-tensorrt-dialect/TensorRT/Transforms/Passes.h"
#include "mlir-tensorrt/Compiler/OptionsRegistry.h"
#include "mlir-tensorrt/Compiler/PassManagerUtils.h"
#include "mlir-tensorrt/Conversion/Passes.h"
#include "mlir-tensorrt/Dialect/Plan/Transforms/Passes.h"
#include "mlir-tensorrt/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
using namespace mlirtrt::compiler;

TensorRTToExecutableOptions::TensorRTToExecutableOptions(
TaskExtensionRegistry extensions) {
-// TODO (pranavm): Do we need to support extensions?
+// TODO (pranavm): We don't need extensions - remove from constructor and add
+// `setExtensions` to base class.
assert(extensions.extensions.size() == 0);
}

void TensorRTToExecutableTask::populatePassManager(
@@ -43,7 +52,63 @@ void TensorRTToExecutableTask::populatePassManager(
/// specifications.
}

-// TODO (pranavm): Which passes go here?
// Post-clustering
pm.addPass(createConvertTensorRTToTensorRTRuntimePass());

pm.addNestedPass<func::FuncOp>(plan::createPostClusteringValidationPass());

pm.addPass(createCanonicalizerPass());

pm.addPass(createInlinerPass());
pm.addNestedPass<func::FuncOp>(createCSEPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());

// We then perform some final simplification on the top-level func.func ops
// (e.g. public entrypoint functions).
pm.addNestedPass<func::FuncOp>(createSCFDetensorizeLoopsPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());

// Pre-bufferization
// Simplify and translate functions nested in `tensorrt.module` ops.
auto &trtPM = pm.nest<tensorrt::TensorRTModuleOp>();
tensorrt::buildTensorRTModuleTransformationPipeline(
trtPM, options.get<TensorRTOptions>().options.enableStronglyTyped);
trtPM.addPass(tensorrt::createTranslateTensorRTPass(
nullptr, nullptr, options.get<TensorRTOptions>().options));

pm.addPass(createMemRefCastEliminationPass());
pm.addPass(plan::createPlanAllocTensorsPass());
pm.addPass(plan::createPlanBufferizePass());
pm.addPass(createMemRefCastEliminationPass());
pm.addPass(createCanonicalizerPass());
pm.addPass(bufferization::createDropEquivalentBufferResultsPass());
plan::buildPlanBufferOptimizationPipeline(pm);
plan::buildPlanBufferDeallocationPipeline(
pm, bufferization::DeallocationOptions{
/*privateFuncDynamicOwnership=*/false});

// Post-bufferization
pm.addPass(createConvertMemRefToCUDAPass());
pm.addPass(createConvertPlanToExecutorPass());
pm.addPass(executor::createExecutorAllocsToGlobalsPass());
pm.addNestedPass<func::FuncOp>(
executor::createExecutorPopulateFunctionMetadataPass());

// Executor lowering
ConvertTensorRTRuntimeToExecutorPassOptions toExecutorOpts;
toExecutorOpts.indexBitwidth = options.get<ExecutorOptions>().indexBitwidth;
toExecutorOpts.usePackedMemRefCConv =
options.get<ExecutorOptions>().usePackedMemRefCConv;
pm.addPass(createConvertTensorRTRuntimeToExecutorPass(toExecutorOpts));

ConvertCUDAToExecutorPassOptions cudaToExecutorOpts;
cudaToExecutorOpts.indexBitwidth =
options.get<ExecutorOptions>().indexBitwidth;
cudaToExecutorOpts.usePackedMemRefCConv =
options.get<ExecutorOptions>().usePackedMemRefCConv;
pm.addPass(createConvertCUDAToExecutorPass(cudaToExecutorOpts));

pm.addPass(createDropNestedModulesPass());

mlir::executor::ConvertStdToExecutorPassOptions stdToExecOpts;
stdToExecOpts.indexBitwidth = options.get<ExecutorOptions>().indexBitwidth;
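
As a rough end-to-end sketch, a pipeline assembled by populatePassManager is ultimately run by a PassManager over the module being compiled. The driver below is an assumption for illustration, not part of this commit; only the core MLIR PassManager API it uses is standard.

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"

// Illustrative driver: the compiler task appends the pipeline shown in the
// diff above (post-clustering cleanup, TensorRT translation, bufferization,
// executor lowering), then the PassManager runs it over the module.
static mlir::LogicalResult compileModule(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // TensorRTToExecutableTask::populatePassManager(pm, options) would be called
  // here; `options` is assumed to carry the ExecutorOptions/TensorRTOptions
  // queried throughout the pipeline.
  if (mlir::failed(pm.run(module))) {
    llvm::errs() << "TensorRT-to-executable pipeline failed\n";
    return mlir::failure();
  }
  return mlir::success();
}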
