Skip to content

Commit

Permalink
review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
pranavm-nvidia committed Dec 10, 2024
1 parent 7b4937b commit 137caf5
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ struct DeviceOptions : public OptionsProvider<DeviceOptions> {
llvm::Error finalizeImpl();
};

struct EntrypointOptions : public OptionsProvider<EntrypointOptions> {
struct CommonCompilationOptions
: public OptionsProvider<CommonCompilationOptions> {
public:
/// Entrypoint function name.
std::string entrypoint = "main";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@
#include "mlir-tensorrt/Compiler/OptionsProviders.h"
#include "mlir/Pass/PassManager.h"

using namespace mlirtrt::compiler;
using namespace mlir;

//===----------------------------------------------------------------------===//
// Common helpers
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class StableHloToExecutableTask;

struct StableHLOToExecutableOptions
: public mlir::OptionsBundle<DebugOptions, ExecutorOptions, DeviceOptions,
EntrypointOptions> {
CommonCompilationOptions> {
/// Initializes the options. The extensions in the provided registry
/// must be extensions for the StableHloToExecutable task.
StableHLOToExecutableOptions(TaskExtensionRegistry extensions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ struct TensorRTOptions

struct TensorRTToExecutableOptions
: public mlir::OptionsBundle<DeviceOptions, DebugOptions, ExecutorOptions,
EntrypointOptions, TensorRTOptions> {
CommonCompilationOptions, TensorRTOptions> {

TensorRTToExecutableOptions(TaskExtensionRegistry extensions);
};
Expand Down
4 changes: 2 additions & 2 deletions mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ void StableHloToExecutableTask::buildStablehloClusteringPipeline(
populateExtensionPasses(pm, opts, Phase::PreClustering);

plan::StablehloClusteringPassOptions clusteringOpts{};
clusteringOpts.entrypoint = opts.get<EntrypointOptions>().entrypoint;
clusteringOpts.entrypoint = opts.get<CommonCompilationOptions>().entrypoint;
plan::buildPlanSegmentationPipeline(pm, clusteringOpts);

// Compile outlined funcs marked with `cluster.host`. The HLO in these
Expand Down Expand Up @@ -451,7 +451,7 @@ static StableHLOToExecutableOptions populateStablehloClusteringPipelineOpts(
cliOpts.deviceMaxSharedMemoryPerBlockKb;
opts.get<DeviceOptions>().shouldInferFromHost =
cliOpts.inferDeviceOptionsFromHost;
opts.get<EntrypointOptions>().entrypoint = cliOpts.entrypoint;
opts.get<CommonCompilationOptions>().entrypoint = cliOpts.entrypoint;
return opts;
}

Expand Down

0 comments on commit 137caf5

Please sign in to comment.