diff --git a/include/scalehls-c/Registration.h b/include/scalehls-c/Registration.h index 04cdb1b4..7d5c3967 100644 --- a/include/scalehls-c/Registration.h +++ b/include/scalehls-c/Registration.h @@ -16,6 +16,9 @@ extern "C" { MLIR_CAPI_EXPORTED void mlirScaleHLSRegisterAllDialects(MlirDialectRegistry registry); +MLIR_CAPI_EXPORTED void +mlirScaleHLSRegisterAllExtensions(MlirDialectRegistry registry); + MLIR_CAPI_EXPORTED void mlirScaleHLSRegisterAllInterfaceExternalModels(MlirDialectRegistry registry); diff --git a/include/scalehls/InitAllDialects.h b/include/scalehls/InitAllDialects.h index 77c2f012..92d1ed6d 100644 --- a/include/scalehls/InitAllDialects.h +++ b/include/scalehls/InitAllDialects.h @@ -14,12 +14,16 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "scalehls/Dialect/HLS/IR/HLS.h" @@ -52,7 +56,10 @@ inline void registerAllDialects(mlir::DialectRegistry ®istry) { mlir::scalehls::hls::HLSDialect, mlir::LLVM::LLVMDialect, mlir::DLTIDialect, - mlir::ml_program::MLProgramDialect + mlir::ml_program::MLProgramDialect, + mlir::pdl::PDLDialect, + mlir::pdl_interp::PDLInterpDialect, + mlir::transform::TransformDialect >(); // clang-format on } @@ -70,6 +77,11 @@ registerAllInterfaceExternalModels(mlir::DialectRegistry ®istry) { hls::registerBufferizableOpInterfaceExternalModels(registry); } +/// Add all required dialect extensions to the provided registry. +inline void registerAllExtensions(DialectRegistry ®istry) { + linalg::registerTransformDialectExtension(registry); +} + } // namespace scalehls } // namespace mlir diff --git a/lib/Bindings/Python/ScaleHLSModule.cpp b/lib/Bindings/Python/ScaleHLSModule.cpp index 0f873878..85249fe4 100644 --- a/lib/Bindings/Python/ScaleHLSModule.cpp +++ b/lib/Bindings/Python/ScaleHLSModule.cpp @@ -37,6 +37,7 @@ PYBIND11_MODULE(_scalehls, m) { MlirDialectRegistry registry = mlirDialectRegistryCreate(); mlirScaleHLSRegisterAllDialects(registry); + mlirScaleHLSRegisterAllExtensions(registry); mlirScaleHLSRegisterAllInterfaceExternalModels(registry); mlirContextAppendDialectRegistry(context, registry); mlirContextLoadAllAvailableDialects(context); diff --git a/lib/CAPI/Registration/Registration.cpp b/lib/CAPI/Registration/Registration.cpp index fb2b4bf8..616141e6 100644 --- a/lib/CAPI/Registration/Registration.cpp +++ b/lib/CAPI/Registration/Registration.cpp @@ -16,6 +16,10 @@ void mlirScaleHLSRegisterAllDialects(MlirDialectRegistry registry) { registerAllDialects(*unwrap(registry)); } +void mlirScaleHLSRegisterAllExtensions(MlirDialectRegistry registry) { + registerAllExtensions(*unwrap(registry)); +} + void mlirScaleHLSRegisterAllInterfaceExternalModels( MlirDialectRegistry registry) { registerAllInterfaceExternalModels(*unwrap(registry)); diff --git a/lib/Transforms/Pipelines.cpp b/lib/Transforms/Pipelines.cpp index 553f42db..b2e9f897 100644 --- a/lib/Transforms/Pipelines.cpp +++ b/lib/Transforms/Pipelines.cpp @@ -32,6 +32,7 @@ void scalehls::addLinalgTransformPasses(OpPassManager &pm) { void scalehls::addConvertLinalgToDataflowPasses(OpPassManager &pm) { pm.addNestedPass(scalehls::createConvertLinalgToDataflowPass()); + pm.addPass(mlir::createLinalgGeneralizationPass()); pm.addPass(mlir::createCanonicalizerPass()); } diff --git a/tools/scalehls-opt/scalehls-opt.cpp b/tools/scalehls-opt/scalehls-opt.cpp index 038569a3..d0088820 100644 --- a/tools/scalehls-opt/scalehls-opt.cpp +++ b/tools/scalehls-opt/scalehls-opt.cpp @@ -12,6 +12,7 @@ int main(int argc, char **argv) { mlir::DialectRegistry registry; mlir::scalehls::registerAllDialects(registry); mlir::scalehls::registerAllInterfaceExternalModels(registry); + mlir::scalehls::registerAllExtensions(registry); mlir::scalehls::registerAllPasses(); return mlir::failed(