Skip to content

Commit

Permalink
[InputConversion] Switch to tablegen pass generation (#18245)
Browse files Browse the repository at this point in the history
This switches the pass generation definition to tablegen. The cleanup
includes switching passes to follow the `create*Pass` naming convention.
  • Loading branch information
marbre committed Aug 19, 2024
1 parent 10ba28d commit aeda149
Show file tree
Hide file tree
Showing 13 changed files with 71 additions and 152 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/InputConversion/Common/PassDetail.h"
#include "iree/compiler/InputConversion/Common/Passes.h"
#include "iree/compiler/PluginAPI/Client.h"
#include "mlir/IR/BuiltinDialect.h"
Expand All @@ -14,10 +13,17 @@

namespace mlir::iree_compiler::InputConversion {

#define GEN_PASS_DEF_AUTOINPUTCONVERSIONPIPELINEPASS
#include "iree/compiler/InputConversion/Common/Passes.h.inc"

namespace {

struct AutoInputConversionPipelinePass final
: AutoInputConversionPipelineBase<AutoInputConversionPipelinePass> {
class AutoInputConversionPipelinePass final
: public impl::AutoInputConversionPipelinePassBase<
AutoInputConversionPipelinePass> {
public:
using impl::AutoInputConversionPipelinePassBase<
AutoInputConversionPipelinePass>::AutoInputConversionPipelinePassBase;
AutoInputConversionPipelinePass(PipelineExtensions *pipelineExtensions)
: pipelineExtensions(pipelineExtensions) {}
void runOnOperation() override;
Expand Down Expand Up @@ -77,11 +83,6 @@ void AutoInputConversionPipelinePass::getDependentDialects(

} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
createAutoInputConversionPipelinePass() {
return std::make_unique<AutoInputConversionPipelinePass>(nullptr);
}

std::unique_ptr<OperationPass<ModuleOp>>
createAutoInputConversionPipelinePass(PipelineExtensions *pipelineExtensions) {
return std::make_unique<AutoInputConversionPipelinePass>(pipelineExtensions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ iree_gentbl_cc_library(
iree_compiler_cc_library(
name = "PassHeaders",
hdrs = [
"PassDetail.h",
"Passes.h",
"Passes.h.inc",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ iree_cc_library(
NAME
PassHeaders
HDRS
"PassDetail.h"
"Passes.h"
"Passes.h.inc"
DEPS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "iree/compiler/InputConversion/Common/PassDetail.h"
#include "iree/compiler/InputConversion/Common/Passes.h"
#include "iree/compiler/Utils/ConversionUtils.h"
#include "llvm/ADT/APFloat.h"
Expand All @@ -29,6 +28,14 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler::InputConversion {

#define GEN_PASS_DEF_DEMOTEF32TOF16PASS
#define GEN_PASS_DEF_DEMOTEF64TOF32PASS
#define GEN_PASS_DEF_DEMOTEI64TOI32PASS
#define GEN_PASS_DEF_PROMOTEBF16TOF32PASS
#define GEN_PASS_DEF_PROMOTEF16TOF32PASS
#include "iree/compiler/InputConversion/Common/Passes.h.inc"

namespace {

Value convertRankedFloat(OpBuilder &builder, Type type, ValueRange inputs,
Expand Down Expand Up @@ -295,77 +302,58 @@ struct DemoteI64ToI32Converter
return IntegerType::get(type.getContext(), 32, type.getSignedness());
}
};
struct DemoteI64ToI32Pass
: public ConvertTypesPass<DemoteI64ToI32Base<DemoteI64ToI32Pass>,
class DemoteI64ToI32Pass final
: public ConvertTypesPass<impl::DemoteI64ToI32PassBase<DemoteI64ToI32Pass>,
DemoteI64ToI32Converter> {};
} // namespace

std::unique_ptr<OperationPass<mlir::ModuleOp>> createDemoteI64ToI32Pass() {
return std::make_unique<DemoteI64ToI32Pass>();
}

namespace {
struct DemoteF32ToF16Converter
: public PrimitiveTypeConverter<Float32Type, Float16Type> {
Type getTargetType(Float32Type type) override {
return Float16Type::get(type.getContext());
}
};
struct DemoteF32ToF16Pass
: public ConvertTypesPass<DemoteF32ToF16Base<DemoteF32ToF16Pass>,
class DemoteF32ToF16Pass final
: public ConvertTypesPass<impl::DemoteF32ToF16PassBase<DemoteF32ToF16Pass>,
DemoteF32ToF16Converter> {};
} // namespace

std::unique_ptr<OperationPass<mlir::ModuleOp>> createDemoteF32ToF16Pass() {
return std::make_unique<DemoteF32ToF16Pass>();
}

namespace {
struct PromoteF16ToF32Converter
: public PrimitiveTypeConverter<Float16Type, Float32Type> {
Type getTargetType(Float16Type type) override {
return Float32Type::get(type.getContext());
}
};
struct PromoteF16ToF32Pass
: public ConvertTypesPass<PromoteF16ToF32Base<PromoteF16ToF32Pass>,
PromoteF16ToF32Converter> {};
class PromoteF16ToF32Pass final
: public ConvertTypesPass<
impl::PromoteF16ToF32PassBase<PromoteF16ToF32Pass>,
PromoteF16ToF32Converter> {};
} // namespace

std::unique_ptr<OperationPass<mlir::ModuleOp>> createPromoteF16ToF32Pass() {
return std::make_unique<PromoteF16ToF32Pass>();
}

namespace {
struct PromoteBF16ToF32Converter
: public FloatTypeConverter<BFloat16Type, Float32Type> {
Type getTargetType(BFloat16Type type) override {
return Float32Type::get(type.getContext());
}
};
struct PromoteBF16ToF32Pass
: public ConvertTypesPass<PromoteBF16ToF32Base<PromoteBF16ToF32Pass>,
PromoteBF16ToF32Converter> {};
class PromoteBF16ToF32Pass final
: public ConvertTypesPass<
impl::PromoteBF16ToF32PassBase<PromoteBF16ToF32Pass>,
PromoteBF16ToF32Converter> {};
} // namespace

std::unique_ptr<OperationPass<mlir::ModuleOp>> createPromoteBF16ToF32Pass() {
return std::make_unique<PromoteBF16ToF32Pass>();
}

namespace {
struct DemoteF64ToF32Converter
: public PrimitiveTypeConverter<Float64Type, Float32Type> {
Type getTargetType(Float64Type type) override {
return Float32Type::get(type.getContext());
}
};
struct DemoteF64ToF32Pass
: public ConvertTypesPass<DemoteF64ToF32Base<DemoteF64ToF32Pass>,
class DemoteF64ToF32Pass final
: public ConvertTypesPass<impl::DemoteF64ToF32PassBase<DemoteF64ToF32Pass>,
DemoteF64ToF32Converter> {};
} // namespace

std::unique_ptr<OperationPass<mlir::ModuleOp>> createDemoteF64ToF32Pass() {
return std::make_unique<DemoteF64ToF32Pass>();
}

} // namespace mlir::iree_compiler::InputConversion
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "iree/compiler/InputConversion/Common/PassDetail.h"
#include "iree/compiler/InputConversion/Common/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand All @@ -25,10 +24,14 @@

namespace mlir::iree_compiler::InputConversion {

#define GEN_PASS_DEF_IREEIMPORTPUBLICPASS
#include "iree/compiler/InputConversion/Common/Passes.h.inc"

namespace {

struct IREEImportPublicPass
: public IREEImportPublicBase<IREEImportPublicPass> {
class IREEImportPublicPass final
: public impl::IREEImportPublicPassBase<IREEImportPublicPass> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<IREE::Input::IREEInputDialect, IREE::Flow::FlowDialect,
IREE::HAL::HALDialect, IREE::Util::UtilDialect,
Expand Down Expand Up @@ -597,8 +600,4 @@ void IREEImportPublicPass::runOnOperation() {
signalPassFailure();
}

std::unique_ptr<OperationPass<ModuleOp>> createIREEImportPublicPass() {
return std::make_unique<IREEImportPublicPass>();
}

} // namespace mlir::iree_compiler::InputConversion
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "iree/compiler/InputConversion/Common/PassDetail.h"
#include "iree/compiler/InputConversion/Common/Passes.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
Expand All @@ -25,9 +24,14 @@

namespace mlir::iree_compiler::InputConversion {

#define GEN_PASS_DEF_IMPORTMLPROGRAMPASS
#include "iree/compiler/InputConversion/Common/Passes.h.inc"

namespace {

struct ImportMLProgramPass : public ImportMLProgramBase<ImportMLProgramPass> {
class ImportMLProgramPass final
: public impl::ImportMLProgramPassBase<ImportMLProgramPass> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithDialect, IREE::Util::UtilDialect>();
}
Expand Down Expand Up @@ -264,8 +268,4 @@ void ImportMLProgramPass::runOnOperation() {
signalPassFailure();
}

std::unique_ptr<OperationPass<ModuleOp>> createImportMLProgramPass() {
return std::make_unique<ImportMLProgramPass>();
}

} // namespace mlir::iree_compiler::InputConversion
21 changes: 0 additions & 21 deletions compiler/src/iree/compiler/InputConversion/Common/PassDetail.h

This file was deleted.

17 changes: 0 additions & 17 deletions compiler/src/iree/compiler/InputConversion/Common/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#ifndef IREE_COMPILER_INPUTCONVERSION_COMMON_PASSES_H_
#define IREE_COMPILER_INPUTCONVERSION_COMMON_PASSES_H_

#include "iree/compiler/InputConversion/Common/PassDetail.h"
#include "iree/compiler/Pipelines/Options.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
Expand Down Expand Up @@ -38,24 +37,8 @@ void buildCommonInputConversionPassPipeline(
// Passes
//===----------------------------------------------------------------------===//

std::unique_ptr<OperationPass<ModuleOp>>
createAutoInputConversionPipelinePass();
std::unique_ptr<OperationPass<ModuleOp>>
createAutoInputConversionPipelinePass(PipelineExtensions *pipelineExtensions);
std::unique_ptr<OperationPass<ModuleOp>> createIREEImportPublicPass();
std::unique_ptr<OperationPass<ModuleOp>> createImportMLProgramPass();
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLinalgQuantizedConvToConvPass();
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLinalgQuantizedMatmulToMatmulPass();
std::unique_ptr<OperationPass<ModuleOp>> createSanitizeModuleNamesPass();

// Type conversion.
std::unique_ptr<OperationPass<mlir::ModuleOp>> createDemoteI64ToI32Pass();
std::unique_ptr<OperationPass<mlir::ModuleOp>> createDemoteF32ToF16Pass();
std::unique_ptr<OperationPass<mlir::ModuleOp>> createDemoteF64ToF32Pass();
std::unique_ptr<OperationPass<mlir::ModuleOp>> createPromoteF16ToF32Pass();
std::unique_ptr<OperationPass<mlir::ModuleOp>> createPromoteBF16ToF32Pass();

//===----------------------------------------------------------------------===//
// Register all Passes
Expand Down
39 changes: 9 additions & 30 deletions compiler/src/iree/compiler/InputConversion/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -9,83 +9,62 @@

include "mlir/Pass/PassBase.td"

def IREEImportPublic :
def IREEImportPublicPass :
Pass<"iree-import-public", "ModuleOp"> {
let summary = "Imports IREE public dialect to internal implementation.";
let constructor = "mlir::iree_compiler::InputConversion::createIREEImportPublicPass()";
}

def ImportMLProgram :
def ImportMLProgramPass :
Pass<"iree-import-ml-program", "ModuleOp"> {
let summary = "Imports MLProgram dialect to IREE Equivalents.";
let constructor = "mlir::iree_compiler::InputConversion::createImportMLProgramPass()";
}

def LinalgQuantizedConvToConvPass
: InterfacePass<"iree-linalg-quantized-conv-to-conv", "mlir::FunctionOpInterface"> {
let summary = "lower quantized_conv to conv";
let constructor = "mlir::iree_compiler::InputConversion::createLinalgQuantizedConvToConvPass()";
}

def LinalgQuantizedMatmulToMatmulPass
: InterfacePass<"iree-linalg-quantized-matmul-to-matmul", "mlir::FunctionOpInterface"> {
let summary = "lower quantized_matmul to matmul";
let constructor = "mlir::iree_compiler::InputConversion::createLinalgQuantizedMatmulToMatmulPass()";
}

def SanitizeModuleNames :
def SanitizeModuleNamesPass :
Pass<"iree-sanitize-module-names", "ModuleOp"> {
let summary = "Sanitizes module names for uniformity across target implementations.";
let constructor = "mlir::iree_compiler::InputConversion::createSanitizeModuleNamesPass()";
}

def AutoInputConversionPipeline :
def AutoInputConversionPipelinePass :
Pass<"iree-auto-input-conversion", "ModuleOp"> {
let summary = "Analyzes and runs appropriate input pipeline.";
let description = [{
Analyzes the input program to determine which input pipeline, or part of input
conversion to run, then run that conversion.
}];
let constructor = "mlir::iree_compiler::InputConversion::createAutoInputConversionPipelinePass()";
}

//===----------------------------------------------------------------------===//
// Type Conversion
//===----------------------------------------------------------------------===//

def DemoteI64ToI32 : Pass<"iree-input-conversion-demote-i64-to-i32", "mlir::ModuleOp"> {
def DemoteI64ToI32Pass : Pass<"iree-input-conversion-demote-i64-to-i32", "mlir::ModuleOp"> {
let summary = "Demotes i64 types to i32 types.";
let constructor = [{
mlir::iree_compiler::InputConversion::createDemoteI64ToI32Pass()
}];
}

def DemoteF32ToF16 : Pass<"iree-input-conversion-demote-f32-to-f16", "mlir::ModuleOp"> {
def DemoteF32ToF16Pass : Pass<"iree-input-conversion-demote-f32-to-f16", "mlir::ModuleOp"> {
let summary = "Demotes f32 types to f16 types.";
let constructor = [{
mlir::iree_compiler::InputConversion::createDemoteF32ToF16Pass()
}];
}

def DemoteF64ToF32 : Pass<"iree-input-conversion-demote-f64-to-f32", "mlir::ModuleOp"> {
def DemoteF64ToF32Pass : Pass<"iree-input-conversion-demote-f64-to-f32", "mlir::ModuleOp"> {
let summary = "Demotes f64 types to f32 types.";
let constructor = [{
mlir::iree_compiler::InputConversion::createDemoteF64ToF32Pass()
}];
}

def PromoteF16ToF32 : Pass<"iree-input-conversion-promote-f16-to-f32", "mlir::ModuleOp"> {
def PromoteF16ToF32Pass : Pass<"iree-input-conversion-promote-f16-to-f32", "mlir::ModuleOp"> {
let summary = "Promotes f16 types to f32 types.";
let constructor = [{
mlir::iree_compiler::InputConversion::createPromoteF16ToF32Pass()
}];
}

def PromoteBF16ToF32 : Pass<"iree-input-conversion-promote-bf16-to-f32", "mlir::ModuleOp"> {
def PromoteBF16ToF32Pass : Pass<"iree-input-conversion-promote-bf16-to-f32", "mlir::ModuleOp"> {
let summary = "Promotes bf16 types to f32 types.";
let constructor = [{
mlir::iree_compiler::InputConversion::createPromoteBF16ToF32Pass()
}];
}

#endif // IREE_COMPILER_INPUTCONVERSION_COMMON_PASSES
Loading

0 comments on commit aeda149

Please sign in to comment.