Skip to content

Commit 070359d

Browse files
committed
Address review comments
1 parent 6a2b53b commit 070359d

File tree

15 files changed

+48
-103
lines changed

15 files changed

+48
-103
lines changed

mlir-tensorrt/compiler/lib/Conversion/TensorRTRuntimeToExecutor/TensorRTRuntimeToExecutor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ struct ConvertEnqueueAllocToCall
379379

380380
// Create output memrefs from output descriptors
381381
SmallVector<Value> results;
382-
unsigned offset = 1;
382+
unsigned offset = 1; // Skip num of results
383383
for (unsigned i = 0; i < op->getNumResults(); ++i) {
384384
unsigned rank = cast<MemRefType>(op->getResult(i).getType()).getRank();
385385
Value rankOffset = b.create<executor::GetOffsetOp>(

mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ add_mlir_tensorrt_library(MLIRTensorRTPlanTransforms
3737
MLIRTensorRTStablehloScalarToArith
3838
MLIRTensorRTStablehloToTensorRT
3939
MLIRTensorRTTensorRTRuntimeDialect
40+
MLIRBufferizationToMemRef
4041
MLIRTransforms
4142
StablehloOps
4243
)

mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,11 @@ static LogicalResult removeUnusedArgs(SymbolTableCollection &collection,
113113
call.getInputsMutable().erase(i);
114114
else if (auto callAlloc = dyn_cast<tensorrt::CallAllocOp>(callOp))
115115
callAlloc.getInputsMutable().erase(i);
116-
else {
117-
llvm::errs() << "Unexpected operation type in callOps\n";
118-
callOp->dump();
119-
return failure();
120-
}
116+
else
117+
return emitError(funcOp->getLoc())
118+
<< "Unexpected operation type in callOps";
121119
}
122120
}
123-
124121
return success();
125122
}
126123

mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/Passes.cpp

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
//===----------------------------------------------------------------------===//
2525
#include "mlir-tensorrt/Dialect/Plan/Transforms/Passes.h"
2626
#include "mlir-tensorrt/Transforms/Passes.h"
27+
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
2728
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
2829
#include "mlir/Dialect/Bufferization/Pipelines/Passes.h"
2930
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
@@ -80,6 +81,7 @@ void plan::buildPlanBufferDeallocationPipeline(
8081
pm.addPass(createCanonicalizerPass());
8182
pm.addPass(bufferization::createBufferDeallocationSimplificationPass());
8283
pm.addPass(bufferization::createLowerDeallocationsPass());
84+
pm.addPass(mlir::createBufferizationToMemRefPass());
8385
pm.addPass(createCSEPass());
8486
pm.addPass(createCanonicalizerPass());
8587
}
@@ -103,31 +105,21 @@ struct ClusteringPipelineCliOpts
103105
llvm::cl::init(NV_TENSORRT_MAJOR)};
104106
};
105107

106-
struct PlanBufferizationPipelineCliOpts
107-
: public PassPipelineOptions<PlanBufferizationPipelineCliOpts> {
108-
Option<bool> enableNonDPSReturns{
109-
*this, "enable-non-dps-returns",
110-
llvm::cl::desc("allow backend clusters to directly allocate outputs"),
111-
llvm::cl::init(false)};
112-
};
113-
114108
} // namespace
115109

116110
// Register pipelines.
117111

118112
void plan::registerPlanDialectPipelines() {
119-
PassPipelineRegistration<PlanBufferizationPipelineCliOpts>
120-
executorBufferizationPipeline(
121-
"plan-bufferize-pipeline",
122-
"perform bufferization and standard pre/post processing passes",
123-
[](OpPassManager &pm, const PlanBufferizationPipelineCliOpts &opts) {
124-
PlanAllocTensorsPassOptions allocTensorOpts{};
125-
allocTensorOpts.enableNonDPSReturns = opts.enableNonDPSReturns;
126-
buildPlanBufferizationPipeline(pm, allocTensorOpts);
127-
buildPlanBufferOptimizationPipeline(pm);
128-
buildPlanBufferDeallocationPipeline(
129-
pm, bufferization::DeallocationOptions{false});
130-
});
113+
PassPipelineRegistration<> executorBufferizationPipeline(
114+
"plan-bufferize-pipeline",
115+
"perform bufferization and standard pre/post processing passes",
116+
[](OpPassManager &pm) {
117+
PlanAllocTensorsPassOptions allocTensorOpts{};
118+
buildPlanBufferizationPipeline(pm, allocTensorOpts);
119+
buildPlanBufferOptimizationPipeline(pm);
120+
buildPlanBufferDeallocationPipeline(
121+
pm, bufferization::DeallocationOptions{false});
122+
});
131123

132124
PassPipelineRegistration<> bufferOptPipeline(
133125
"plan-buffer-opt-pipeline", "perform post-bufferization optimizations",

mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ extern "C" {
5353
/// caller must be sure to delete errors via mtrtStatusDestroy.
5454
//===----------------------------------------------------------------------===//
5555

56-
typedef struct MTRT_RuntimeClient MTRT_Runtimeclient;
56+
struct MTRT_RuntimeClient; // Forward declaration
5757

5858
//===----------------------------------------------------------------------===//
5959
// MTRT_GlobalDebug

mlir-tensorrt/executor/lib/Conversion/MemRefToExecutor.cpp

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -550,20 +550,6 @@ void executor::populateMemRefToExecutorPatterns(
550550

551551
namespace {
552552

553-
class RemoveNoOpClonePattern : public OpRewritePattern<bufferization::CloneOp> {
554-
public:
555-
using OpRewritePattern<bufferization::CloneOp>::OpRewritePattern;
556-
557-
LogicalResult matchAndRewrite(bufferization::CloneOp op,
558-
PatternRewriter &rewriter) const override {
559-
if (op.getInput().getType() == op.getOutput().getType()) {
560-
rewriter.replaceOp(op, op.getInput());
561-
return success();
562-
}
563-
return failure();
564-
}
565-
};
566-
567553
/// Pass to convert `memref` to `executor` dialect operrations.
568554
class ConvertMemRefToExecutorPass
569555
: public mlir::executor::impl::ConvertMemRefToExecutorPassBase<
@@ -596,9 +582,6 @@ class ConvertMemRefToExecutorPass
596582
executor::populateMemRefToExecutorPatterns(
597583
patterns, typeConverter, allowUncheckedMemrefCastConversion);
598584

599-
// Remove unrealized cast and redundant clone operations.
600-
patterns.add<RemoveNoOpClonePattern>(ctx);
601-
602585
if (failed(applyPartialConversion(getOperation(), target,
603586
std::move(patterns))))
604587
return signalPassFailure();

mlir-tensorrt/executor/test/lib/BufferizationTestPass.cpp

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -54,31 +54,22 @@ class ExecutorBufferizationTestPass
5454
}
5555
};
5656

57-
struct PlanBufferizationPipelineCliOpts
58-
: public PassPipelineOptions<PlanBufferizationPipelineCliOpts> {
59-
Option<bool> enableNonDPSReturns{
60-
*this, "enable-non-dps-returns",
61-
llvm::cl::desc("allow backend clusters to directly allocate outputs"),
62-
llvm::cl::init(false)};
63-
};
64-
6557
} // namespace
6658

6759
namespace mlir::executor {
6860
void registerTestExecutorBufferizePass() {
6961
PassRegistration<ExecutorBufferizationTestPass>();
7062

71-
PassPipelineRegistration<PlanBufferizationPipelineCliOpts>
72-
executorBufferizationPipeline(
73-
"test-executor-bufferization-pipeline",
74-
"Run one-shot-bufferization and buffer deallocation pipelines",
75-
[](OpPassManager &pm, const PlanBufferizationPipelineCliOpts &opts) {
76-
pm.addPass(std::make_unique<ExecutorBufferizationTestPass>());
77-
pm.addPass(bufferization::createDropEquivalentBufferResultsPass());
78-
bufferization::BufferDeallocationPipelineOptions deallocOptions{};
79-
bufferization::buildBufferDeallocationPipeline(pm, deallocOptions);
80-
pm.addPass(createCSEPass());
81-
pm.addPass(createCanonicalizerPass());
82-
});
63+
PassPipelineRegistration<> executorBufferizationPipeline(
64+
"test-executor-bufferization-pipeline",
65+
"Run one-shot-bufferization and buffer deallocation pipelines",
66+
[](OpPassManager &pm) {
67+
pm.addPass(std::make_unique<ExecutorBufferizationTestPass>());
68+
pm.addPass(bufferization::createDropEquivalentBufferResultsPass());
69+
bufferization::BufferDeallocationPipelineOptions deallocOptions{};
70+
bufferization::buildBufferDeallocationPipeline(pm, deallocOptions);
71+
pm.addPass(createCSEPass());
72+
pm.addPass(createCanonicalizerPass());
73+
});
8374
}
8475
} // namespace mlir::executor

mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ class PyRuntimeClient
244244
using Base::Base;
245245
DECLARE_WRAPPER_CONSTRUCTORS(PyRuntimeClient);
246246

247-
static constexpr auto kMethodTable = CAPITable<MTRT_Runtimeclient>{
247+
static constexpr auto kMethodTable = CAPITable<MTRT_RuntimeClient>{
248248
mtrtRuntimeClientIsNull, mtrtRuntimeClientDestroy};
249249
};
250250

@@ -961,7 +961,8 @@ PYBIND11_MODULE(_api, m) {
961961
[](PyRuntimeSession &self, std::string name,
962962
std::vector<py::object> inArgs,
963963
std::optional<std::vector<py::object>> outArgs,
964-
std::optional<MTRT_Stream> stream, PyRuntimeClient &client) {
964+
std::optional<MTRT_Stream> stream,
965+
PyRuntimeClient *client = nullptr) {
965966
MTRT_StringView nameRef{name.data(), name.size()};
966967

967968
int64_t numResults;
@@ -980,7 +981,8 @@ PYBIND11_MODULE(_api, m) {
980981
self, nameRef, inArgsGeneric.data(), inArgsGeneric.size(),
981982
outArgsGeneric.data(), outArgsGeneric.size(),
982983
resultsGeneric.data(), stream ? *stream : mtrtStreamGetNull(),
983-
client);
984+
client ? MTRT_RuntimeClient(*client)
985+
: mtrtRuntimeClientGetNull());
984986
THROW_IF_MTRT_ERROR(s);
985987

986988
std::vector<py::object> resultPyObject;
@@ -992,7 +994,7 @@ PYBIND11_MODULE(_api, m) {
992994
return resultPyObject;
993995
},
994996
py::arg("name"), py::arg("in_args"), py::arg("out_args") = py::none(),
995-
py::arg("stream") = py::none(), py::arg("client"),
997+
py::arg("stream") = py::none(), py::arg("client") = nullptr,
996998
"Execute a function given input and optional output arguments. "
997999
"Return optional results as a Python object if output arguments are "
9981000
"not present.");

mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -568,10 +568,10 @@ static LogicalResult serializeSplatElements(DenseIntOrFPElementsAttr values,
568568
std::fill_n(reinterpret_cast<uint8_t *>(data.data()), data.size(), packed);
569569
return llvm::success();
570570
}
571-
llvm::errs() << "Error: "
572-
<< "unsupported data type to convert MLIR splat attribute to "
573-
"TensorRT weights!";
574-
return llvm::failure();
571+
572+
return emitError(UnknownLoc::get(values.getContext()))
573+
<< "unsupported data type to convert MLIR splat attribute to TensorRT "
574+
"weights!";
575575
}
576576

577577
FailureOr<nvinfer1::Weights>

mlir-tensorrt/test/python/IntegrationTests/TRT10/test_stablehlo_add.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,7 @@ def test_stablehlo_add(
3636
session = runtime.RuntimeSession(session_options, exe)
3737

3838
session.execute_function(
39-
"main",
40-
in_args=test.in_args,
41-
out_args=test.out_args,
42-
stream=stream,
43-
client=runtime_client,
39+
"main", in_args=test.in_args, out_args=test.out_args, stream=stream
4440
)
4541
output = [
4642
(

0 commit comments

Comments
 (0)