Skip to content

Commit

Permalink
Support transform ops in scalehls-opt
Browse files Browse the repository at this point in the history
  • Loading branch information
hanchenye committed Jan 9, 2024
1 parent 02273bb commit 351177c
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 1 deletion.
3 changes: 3 additions & 0 deletions include/scalehls-c/Registration.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ extern "C" {
MLIR_CAPI_EXPORTED void
mlirScaleHLSRegisterAllDialects(MlirDialectRegistry registry);

MLIR_CAPI_EXPORTED void
mlirScaleHLSRegisterAllExtensions(MlirDialectRegistry registry);

MLIR_CAPI_EXPORTED void
mlirScaleHLSRegisterAllInterfaceExternalModels(MlirDialectRegistry registry);

Expand Down
14 changes: 13 additions & 1 deletion include/scalehls/InitAllDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "scalehls/Dialect/HLS/IR/HLS.h"

Expand Down Expand Up @@ -52,7 +56,10 @@ inline void registerAllDialects(mlir::DialectRegistry &registry) {
mlir::scalehls::hls::HLSDialect,
mlir::LLVM::LLVMDialect,
mlir::DLTIDialect,
mlir::ml_program::MLProgramDialect
mlir::ml_program::MLProgramDialect,
mlir::pdl::PDLDialect,
mlir::pdl_interp::PDLInterpDialect,
mlir::transform::TransformDialect
>();
// clang-format on
}
Expand All @@ -70,6 +77,11 @@ registerAllInterfaceExternalModels(mlir::DialectRegistry &registry) {
hls::registerBufferizableOpInterfaceExternalModels(registry);
}

/// Add all required dialect extensions to the provided registry.
inline void registerAllExtensions(DialectRegistry &registry) {
linalg::registerTransformDialectExtension(registry);
}

} // namespace scalehls
} // namespace mlir

Expand Down
1 change: 1 addition & 0 deletions lib/Bindings/Python/ScaleHLSModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ PYBIND11_MODULE(_scalehls, m) {

MlirDialectRegistry registry = mlirDialectRegistryCreate();
mlirScaleHLSRegisterAllDialects(registry);
mlirScaleHLSRegisterAllExtensions(registry);
mlirScaleHLSRegisterAllInterfaceExternalModels(registry);
mlirContextAppendDialectRegistry(context, registry);
mlirContextLoadAllAvailableDialects(context);
Expand Down
4 changes: 4 additions & 0 deletions lib/CAPI/Registration/Registration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ void mlirScaleHLSRegisterAllDialects(MlirDialectRegistry registry) {
registerAllDialects(*unwrap(registry));
}

void mlirScaleHLSRegisterAllExtensions(MlirDialectRegistry registry) {
registerAllExtensions(*unwrap(registry));
}

void mlirScaleHLSRegisterAllInterfaceExternalModels(
MlirDialectRegistry registry) {
registerAllInterfaceExternalModels(*unwrap(registry));
Expand Down
1 change: 1 addition & 0 deletions lib/Transforms/Pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ void scalehls::addLinalgTransformPasses(OpPassManager &pm) {

void scalehls::addConvertLinalgToDataflowPasses(OpPassManager &pm) {
pm.addNestedPass<func::FuncOp>(scalehls::createConvertLinalgToDataflowPass());
pm.addPass(mlir::createLinalgGeneralizationPass());
pm.addPass(mlir::createCanonicalizerPass());
}

Expand Down
1 change: 1 addition & 0 deletions tools/scalehls-opt/scalehls-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ int main(int argc, char **argv) {
mlir::DialectRegistry registry;
mlir::scalehls::registerAllDialects(registry);
mlir::scalehls::registerAllInterfaceExternalModels(registry);
mlir::scalehls::registerAllExtensions(registry);
mlir::scalehls::registerAllPasses();

return mlir::failed(
Expand Down

0 comments on commit 351177c

Please sign in to comment.