From 146e4f48423bbd5ddcd1d66bdebf4d08479cb1fc Mon Sep 17 00:00:00 2001 From: Hanchen Ye Date: Fri, 23 Feb 2024 16:30:24 -0600 Subject: [PATCH] Remove the layer of structural dataflow abstraction --- include/scalehls-c/Dialect/HLS/HLS.h | 63 +- include/scalehls/Dialect/HLS/IR/HLS.h | 36 +- .../scalehls/Dialect/HLS/IR/HLSAttributes.td | 17 - .../scalehls/Dialect/HLS/IR/HLSInterfaces.td | 28 - include/scalehls/Dialect/HLS/IR/HLSOps.td | 141 +---- .../scalehls/Dialect/HLS/Transforms/Passes.h | 3 +- .../scalehls/Dialect/HLS/Transforms/Passes.td | 11 +- lib/Bindings/Python/HLSDialect.cpp | 30 +- lib/CAPI/Dialect/HLS/HLS.cpp | 21 +- lib/Dialect/HLS/IR/HLS.cpp | 275 +-------- lib/Dialect/HLS/IR/HLSOps.cpp | 536 ++---------------- .../BufferizableOpInterfaceImpl.cpp | 20 +- lib/Dialect/HLS/Transforms/CMakeLists.txt | 3 +- .../HLS/Transforms/ConvertDataflowToFunc.cpp | 115 ++-- lib/Dialect/HLS/Transforms/LowerDataflow.cpp | 184 ------ .../HLS/Transforms/ScalarizeStream.cpp | 6 +- ...reateDataflow.cpp => ScheduleDataflow.cpp} | 18 +- lib/Pipelines/Pipelines.cpp | 4 +- 18 files changed, 232 insertions(+), 1279 deletions(-) delete mode 100644 lib/Dialect/HLS/Transforms/LowerDataflow.cpp rename lib/Dialect/HLS/Transforms/{CreateDataflow.cpp => ScheduleDataflow.cpp} (72%) diff --git a/include/scalehls-c/Dialect/HLS/HLS.h b/include/scalehls-c/Dialect/HLS/HLS.h index fb472feb..4065a111 100644 --- a/include/scalehls-c/Dialect/HLS/HLS.h +++ b/include/scalehls-c/Dialect/HLS/HLS.h @@ -24,56 +24,31 @@ MLIR_CAPI_EXPORTED void mlirSemanticsInitializeBlockArguments(MlirOperation semantics, const std::vector &ports); -//===----------------------------------------------------------------------===// -// HLS Dialect Types -//===----------------------------------------------------------------------===// - -MLIR_CAPI_EXPORTED bool mlirTypeIsHLSStructType(MlirType type); -MLIR_CAPI_EXPORTED MlirType mlirHLSStructTypeGet(MlirStringRef name, - MlirContext ctx); - -MLIR_CAPI_EXPORTED bool mlirTypeIsHLSTypeType(MlirType type); -MLIR_CAPI_EXPORTED MlirType mlirHLSTypeTypeGet(MlirContext ctx); - -MLIR_CAPI_EXPORTED bool mlirTypeIsHLSIntParamType(MlirType type); -MLIR_CAPI_EXPORTED MlirType mlirHLSIntParamTypeGet(MlirContext ctx); - -MLIR_CAPI_EXPORTED bool mlirTypeIsHLSFloatParamType(MlirType type); -MLIR_CAPI_EXPORTED MlirType mlirHLSFloatParamTypeGet(MlirContext ctx); - -MLIR_CAPI_EXPORTED bool mlirTypeIsHLSPortType(MlirType type); -MLIR_CAPI_EXPORTED MlirType mlirHLSPortTypeGet(MlirContext ctx); - -MLIR_CAPI_EXPORTED bool mlirTypeIsHLSTaskImplType(MlirType type); -MLIR_CAPI_EXPORTED MlirType mlirHLSTaskImplTypeGet(MlirContext ctx); - -MLIR_CAPI_EXPORTED bool mlirTypeIsHLSMemoryKindType(MlirType type); -MLIR_CAPI_EXPORTED MlirType mlirHLSMemoryKindTypeGet(MlirContext ctx); - //===----------------------------------------------------------------------===// // HLS Dialect Attributes //===----------------------------------------------------------------------===// -enum class MlirParamKind : uint32_t { - TILE_SIZE = 0, - PARALLEL_SIZE = 1, - IP_TEMPLATE = 2, - TASK_IMPL = 3, - MEMORY_KIND = 4 +enum class MlirMemoryKind : uint32_t { + UNKNOWN = 0, + LUTRAM_1P = 1, + LUTRAM_2P = 2, + LUTRAM_S2P = 3, + BRAM_1P = 4, + BRAM_2P = 5, + BRAM_S2P = 6, + BRAM_T2P = 7, + URAM_1P = 8, + URAM_2P = 9, + URAM_S2P = 10, + URAM_T2P = 11, + DRAM = 12 }; -MLIR_CAPI_EXPORTED bool mlirAttrIsHLSParamKindAttr(MlirAttribute attr); -MLIR_CAPI_EXPORTED MlirAttribute mlirHLSParamKindAttrGet(MlirContext ctx, - MlirParamKind kind); -MLIR_CAPI_EXPORTED 
MlirParamKind -mlirHLSParamKindAttrGetValue(MlirAttribute attr); - -enum class MlirPortKind : uint32_t { INPUT = 0, OUTPUT = 1, PARAM = 2 }; - -MLIR_CAPI_EXPORTED bool mlirAttrIsHLSPortKindAttr(MlirAttribute attr); -MLIR_CAPI_EXPORTED MlirAttribute mlirHLSPortKindAttrGet(MlirContext ctx, - MlirPortKind kind); -MLIR_CAPI_EXPORTED MlirPortKind mlirHLSPortKindAttrGetValue(MlirAttribute attr); +MLIR_CAPI_EXPORTED bool mlirAttrIsHLSMemoryKindAttr(MlirAttribute attr); +MLIR_CAPI_EXPORTED MlirAttribute mlirHLSMemoryKindAttrGet(MlirContext ctx, + MlirMemoryKind kind); +MLIR_CAPI_EXPORTED MlirMemoryKind +mlirHLSMemoryKindAttrGetValue(MlirAttribute attr); #ifdef __cplusplus } diff --git a/include/scalehls/Dialect/HLS/IR/HLS.h b/include/scalehls/Dialect/HLS/IR/HLS.h index 6d31579b..8dc1be27 100644 --- a/include/scalehls/Dialect/HLS/IR/HLS.h +++ b/include/scalehls/Dialect/HLS/IR/HLS.h @@ -63,7 +63,7 @@ void setRuntimeAttr(Operation *op); /// Wrap the operations in the block with dispatch op. Return a nullptr if /// failed. -DispatchOp dispatchBlock(StringRef name, Block *block, +ScheduleOp scheduleBlock(StringRef name, Block *block, PatternRewriter &rewriter); /// Fuse the given operations into a new task. The new task will be created @@ -72,9 +72,6 @@ DispatchOp dispatchBlock(StringRef name, Block *block, TaskOp fuseOpsIntoTask(ArrayRef ops, PatternRewriter &rewriter, Operation *insertToOp = nullptr); -/// Fuse multiple nodes into a new node. -NodeOp fuseNodeOps(ArrayRef nodes, PatternRewriter &rewriter); - //===----------------------------------------------------------------------===// // Analysis Utils //===----------------------------------------------------------------------===// @@ -88,34 +85,6 @@ bool isRamT2P(MemRefType type); bool isDram(MemRefType type); bool isUnknown(MemRefType type); -/// Get the consumer/producer nodes of the given buffer expect the given op. -SmallVector getConsumersExcept(Value buffer, NodeOp except); -SmallVector getProducersExcept(Value buffer, NodeOp except); -SmallVector getConsumers(Value buffer); -SmallVector getProducers(Value buffer); -SmallVector getDependentConsumers(Value buffer, NodeOp node); - -/// Get the nested consumer/producer nodes of the given buffer expect the given -/// node. The corresponding buffer values are also returned. -SmallVector> getNestedConsumersExcept(Value buffer, - NodeOp except); -SmallVector> getNestedProducersExcept(Value buffer, - NodeOp except); -SmallVector> getNestedConsumers(Value buffer); -SmallVector> getNestedProducers(Value buffer); - -/// Get the depth of a buffer or stream channel. Note that only if the defining -/// operation of the buffer is not a BufferOp or stream types, the returned -/// result will be 1. -unsigned getBufferDepth(Value memref); - -/// Find buffer value or buffer op across the dataflow hierarchy. -Value findBuffer(Value memref); -hls::BufferLikeInterface findBufferOp(Value memref); - -/// Check whether the given buffer is external. -bool isExtBuffer(Value memref); - /// Check whether the given use has read/write semantics. bool isRead(OpOperand &use); bool isWritten(OpOperand &use); @@ -132,9 +101,6 @@ bool isFullyPartitioned(MemRefType memrefType); int64_t getPartitionFactors(MemRefType memrefType, SmallVectorImpl *factors = nullptr); -/// The current op or contained ops have effect on external buffers. 
-bool hasEffectOnExternalBuffer(Operation *op); - } // namespace hls } // namespace scalehls } // namespace mlir diff --git a/include/scalehls/Dialect/HLS/IR/HLSAttributes.td b/include/scalehls/Dialect/HLS/IR/HLSAttributes.td index 0f3e3e20..5f4ae891 100644 --- a/include/scalehls/Dialect/HLS/IR/HLSAttributes.td +++ b/include/scalehls/Dialect/HLS/IR/HLSAttributes.td @@ -19,23 +19,6 @@ def IndexArrayAttr : TypedArrayAttrBase { let constBuilderCall = "$_builder.getIndexArrayAttr($0)"; } -//===----------------------------------------------------------------------===// -// PortKind Attribute -//===----------------------------------------------------------------------===// - -def PortKind: I32EnumAttr<"PortKind", "Kind of a port", [ - I32EnumAttrCase<"INPUT", 0, "input">, - I32EnumAttrCase<"OUTPUT", 1, "output">, - I32EnumAttrCase<"PARAM", 2, "param">]> { - let cppNamespace = "mlir::scalehls::hls"; - let genSpecializedAttr = 0; -} -def PortKindParam: EnumParameter; -def PortKindAttr: EnumAttr { - let mnemonic = "port"; - let assemblyFormat = "`<` $value `>`"; -} - //===----------------------------------------------------------------------===// // MemoryKind Attribute //===----------------------------------------------------------------------===// diff --git a/include/scalehls/Dialect/HLS/IR/HLSInterfaces.td b/include/scalehls/Dialect/HLS/IR/HLSInterfaces.td index 4f8ed70e..387c0ef6 100644 --- a/include/scalehls/Dialect/HLS/IR/HLSInterfaces.td +++ b/include/scalehls/Dialect/HLS/IR/HLSInterfaces.td @@ -55,32 +55,4 @@ def BufferLikeInterface : OpInterface<"BufferLikeInterface"> { ]; } -def ContainerLikeInterface : OpInterface<"ContainerLikeInterface"> { - let description = [{ - This interface is used to represent containers, including dispatch, task, - schedule, and node. - }]; - string cppNamespace = "mlir::scalehls::hls"; - - let methods = [ - InterfaceMethod<"Return body region of the stage", - "mlir::Region &", "getBody", (ins), "return $_op.getBody();">, - InterfaceMethod<"Check whether the stage has hierarchy", - "bool", "hasHierarchy", (ins), [{ - return $_op.walk([&](ContainerLikeInterface stage) { - if (stage != $_op) - return WalkResult::interrupt(); - return WalkResult::advance(); - }).wasInterrupted(); - }]>, - InterfaceMethod<"Return whether the value is a stage livein", - "bool", "isLivein", (ins "mlir::Value":$value)>, - InterfaceMethod<"Return the liveins of the stage", - "llvm::SmallVector", "getLiveins">, - InterfaceMethod<"Return the internal users of a stage livein", - "llvm::SmallVector", "getLiveinUsers", - (ins "mlir::Value":$livein)>, - ]; -} - #endif // SCALEHLS_DIALECT_HLS_HLSINTERFACES_TD diff --git a/include/scalehls/Dialect/HLS/IR/HLSOps.td b/include/scalehls/Dialect/HLS/IR/HLSOps.td index 167457a2..74f5f00f 100644 --- a/include/scalehls/Dialect/HLS/IR/HLSOps.td +++ b/include/scalehls/Dialect/HLS/IR/HLSOps.td @@ -220,13 +220,13 @@ def StreamCastOp : HLSOp<"stream_cast", [Pure, // Functional Dataflow (FDF) Operations //===----------------------------------------------------------------------===// -def DispatchOp : HLSOp<"dispatch", [RecursiveMemoryEffects, +def ScheduleOp : HLSOp<"schedule", [RecursiveMemoryEffects, SingleBlockImplicitTerminator<"YieldOp">, ParentOneOf<["func::FuncOp", "affine::AffineForOp", "scf::ForOp"]>]> { - let summary = "Represent a dataflow dispatch"; + let summary = "Represent a dataflow schedule"; let description = [{ - Dispatch op has a transparent region that contains a list of task ops to be - dispatched. 
This op is designed to organize and manipulate task ops at a + Schedule op has a transparent region that contains a list of task ops to be + scheduleed. This op is designed to organize and manipulate task ops at a high level and will be lowered to schedule op for dataflow scheduling. }]; @@ -242,9 +242,8 @@ def DispatchOp : HLSOp<"dispatch", [RecursiveMemoryEffects, }]; } -def TaskOp : HLSOp<"task", [DeclareOpInterfaceMethods, - RecursiveMemoryEffects, SingleBlockImplicitTerminator<"YieldOp">, - HasParent<"DispatchOp">]> { +def TaskOp : HLSOp<"task", [RecursiveMemoryEffects, HasParent<"ScheduleOp">, + SingleBlockImplicitTerminator<"YieldOp">]> { let summary = "Represent a dataflow task"; let description = [{ Task op has a transparent region that contains a list of ops to be executed @@ -263,17 +262,19 @@ def TaskOp : HLSOp<"task", [DeclareOpInterfaceMethods, let extraClassDeclaration = [{ /// Return true if this task op contains nested sub-tasks. bool hasHierarchy() { - return cast(this->getOperation()).hasHierarchy(); + return walk([&](TaskOp task) { + return task != *this ? WalkResult::interrupt() : WalkResult::advance(); + }).wasInterrupted(); } - DispatchOp getDispatchOp(); + ScheduleOp getScheduleOp(); YieldOp getYieldOp(); }]; } def YieldOp : HLSOp<"yield", [Pure, ReturnLike, Terminator, - ParentOneOf<["DispatchOp", "TaskOp"]>]> { - let summary = "Terminate and yield results of a dispatch or task op"; + ParentOneOf<["ScheduleOp", "TaskOp"]>]> { + let summary = "Terminate and yield results of a schedule or task op"; let arguments = (ins Variadic:$results); let assemblyFormat = "$results attr-dict `:` type($results)"; @@ -281,124 +282,6 @@ def YieldOp : HLSOp<"yield", [Pure, ReturnLike, Terminator, let builders = [OpBuilder<(ins), "build($_builder, $_state, std::nullopt);">]; } -//===----------------------------------------------------------------------===// -// Structural Dataflow (SDF) Operations -//===----------------------------------------------------------------------===// - -def ScheduleOp : HLSOp<"schedule", [ - DeclareOpInterfaceMethods, IsolatedFromAbove, - AffineScope, SingleBlock, NoTerminator, - ParentOneOf<["func::FuncOp", "affine::AffineForOp", "scf::ForOp"]>]> { - let summary = "Represent a dataflow schedule"; - let description = [{ - Schedule op has an isolated region to contain a list of dataflow node ops to - be scheduled. This op can be explicitly marked as legal when all the - dataflow violations have been resolved and all the nodes has been scheduled. - }]; - - let arguments = (ins Variadic:$operands, UnitAttr:$isLegal); - let regions = (region SizedRegion<1>:$body); - let assemblyFormat = [{ - (`legal` $isLegal^)? (`(` $operands^ `)`)? (`:` type($operands)^)? $body - attr-dict - }]; - - let hasVerifier = 1; - let hasCanonicalizer = 1; - - let extraClassDeclaration = [{ - /// FIXME: Check whether the schedule is dependence free. - bool isDependenceFree(); - - /// Update the signature of the schedule op recursively. - void updateSignatureRecursively(); - }]; -} - -def NodeOp : HLSOp<"node", [DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, IsolatedFromAbove, - AffineScope, SingleBlock, NoTerminator, AttrSizedOperandSegments, - HasParent<"ScheduleOp">]> { - let summary = "Represent a dataflow node"; - let description = [{ - Node op has an isolated region to represent the ops contained by the node. - The node can only take buffers or streams as inputs and outputs. 
Meanwhile, - scalar values can be passed into a node as parameters, which will not be - considered in the dataflow. An attribute "inputTaps" is used to represent - the level of buffer or stream channel tap of each input. - }]; - - let arguments = (ins Variadic:$inputs, - Variadic:$outputs, Variadic:$params, - I32ArrayAttr:$inputTaps, OptionalAttr:$level); - let regions = (region SizedRegion<1>:$body); - - let assemblyFormat = [{ - `(` $inputs `)` `->` `(` $outputs `)` (`[` $params^ `]`)? `:` - functional-type($inputs, $outputs) (`[` type($params)^ `]`)? $body attr-dict - }]; - - let hasVerifier = 1; - let hasCanonicalizer = 1; - let builders = [ - OpBuilder<(ins "mlir::ValueRange":$inputs, "mlir::ValueRange":$outputs, - "mlir::ValueRange":$params, "ArrayRef":$inputTaps, - "mlir::IntegerAttr":$level), [{ - auto newInputTaps = SmallVector( - llvm::map_range(inputTaps, [](unsigned a) { return (int32_t)a; })); - build($_builder, $_state, inputs, outputs, params, - $_builder.getI32ArrayAttr(newInputTaps), level); - }]>, - - OpBuilder<(ins "mlir::ValueRange":$inputs, "mlir::ValueRange":$outputs, - "mlir::ValueRange":$params, "ArrayRef":$inputTaps), [{ - build($_builder, $_state, inputs, outputs, params, inputTaps, nullptr); - }]>, - OpBuilder<(ins "mlir::ValueRange":$inputs, "mlir::ValueRange":$outputs, - "ArrayRef":$inputTaps), [{ - build($_builder, $_state, inputs, outputs, ValueRange(), inputTaps); - }]>, - - OpBuilder<(ins "mlir::ValueRange":$inputs, "mlir::ValueRange":$outputs, - "mlir::ValueRange":$params), [{ - build($_builder, $_state, inputs, outputs, params, - SmallVector(inputs.size(), 0)); - }]>, - OpBuilder<(ins "mlir::ValueRange":$inputs, "mlir::ValueRange":$outputs), - "build($_builder, $_state, inputs, outputs, ValueRange());"> - ]; - - let extraClassDeclaration = [{ - /// Get input taps. - void setInputTap(unsigned idx, unsigned tap); - unsigned getInputTap(unsigned idx); - SmallVector getInputTapsAsInt(); - - /// Get the number of inputs, outputs, and params. - unsigned getNumInputs(); - unsigned getNumOutputs(); - unsigned getNumParams(); - - /// Get the type of operand: input, output, or param. - PortKind getPortKind(OpOperand &operand); - PortKind getPortKind(unsigned operandIdx); - - /// Get the input, output, and param arguments. - iterator_range getInputArgs(); - iterator_range getOutputArgs(); - iterator_range getParamArgs(); - - bool hasHierarchy() { - return cast(this->getOperation()).hasHierarchy(); - } - - /// Update the signature of the node op recursively. 
- void updateSignatureRecursively(); - - ScheduleOp getScheduleOp(); - }]; -} - def BufferOp : HLSOp<"buffer", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "Represent a dataflow buffer"; diff --git a/include/scalehls/Dialect/HLS/Transforms/Passes.h b/include/scalehls/Dialect/HLS/Transforms/Passes.h index af690f3d..1e650029 100644 --- a/include/scalehls/Dialect/HLS/Transforms/Passes.h +++ b/include/scalehls/Dialect/HLS/Transforms/Passes.h @@ -28,8 +28,7 @@ std::unique_ptr createReduceTensorToStreamPass(); std::unique_ptr createMaterializeStreamPass(bool enablePacking = true); std::unique_ptr createScalarizeStreamPass(); -std::unique_ptr createCreateDataflowPass(); -std::unique_ptr createLowerDataflowPass(); +std::unique_ptr createScheduleDataflowPass(); std::unique_ptr createConvertDataflowToFuncPass(); std::unique_ptr createApplyTransformPatternPass(); diff --git a/include/scalehls/Dialect/HLS/Transforms/Passes.td b/include/scalehls/Dialect/HLS/Transforms/Passes.td index 9a5bb52c..093ff458 100644 --- a/include/scalehls/Dialect/HLS/Transforms/Passes.td +++ b/include/scalehls/Dialect/HLS/Transforms/Passes.td @@ -35,14 +35,9 @@ def ScalarizeStream : Pass<"scalehls-scalarize-stream", "func::FuncOp"> { let constructor = "mlir::scalehls::hls::createScalarizeStreamPass()"; } -def CreateDataflow : Pass<"scalehls-create-dataflow", "func::FuncOp"> { - let summary = "Convert linalg to functional dataflow"; - let constructor = "mlir::scalehls::hls::createCreateDataflowPass()"; -} - -def LowerDataflow : Pass<"scalehls-lower-dataflow", "func::FuncOp"> { - let summary = "Convert functional to structural dataflow"; - let constructor = "mlir::scalehls::hls::createLowerDataflowPass()"; +def ScheduleDataflow : Pass<"scalehls-schedule-dataflow", "func::FuncOp"> { + let summary = "Create a dataflow schedule"; + let constructor = "mlir::scalehls::hls::createScheduleDataflowPass()"; } def ConvertDataflowToFunc : diff --git a/lib/Bindings/Python/HLSDialect.cpp b/lib/Bindings/Python/HLSDialect.cpp index bf92ac1e..b283afd5 100644 --- a/lib/Bindings/Python/HLSDialect.cpp +++ b/lib/Bindings/Python/HLSDialect.cpp @@ -21,24 +21,34 @@ using namespace mlir::python::adaptors; //===----------------------------------------------------------------------===// void populateHLSAttributes(py::module &m) { - py::enum_(m, "PortKind", py::module_local()) - .value("input", MlirPortKind::INPUT) - .value("output", MlirPortKind::OUTPUT) - .value("param", MlirPortKind::PARAM); + py::enum_(m, "MemoryKind", py::module_local()) + .value("UNKNOWN", MlirMemoryKind::UNKNOWN) + .value("LUTRAM_1P", MlirMemoryKind::LUTRAM_1P) + .value("LUTRAM_2P", MlirMemoryKind::LUTRAM_2P) + .value("LUTRAM_S2P", MlirMemoryKind::LUTRAM_S2P) + .value("BRAM_1P", MlirMemoryKind::BRAM_1P) + .value("BRAM_2P", MlirMemoryKind::BRAM_2P) + .value("BRAM_S2P", MlirMemoryKind::BRAM_S2P) + .value("BRAM_T2P", MlirMemoryKind::BRAM_T2P) + .value("URAM_1P", MlirMemoryKind::URAM_1P) + .value("URAM_2P", MlirMemoryKind::URAM_2P) + .value("URAM_S2P", MlirMemoryKind::URAM_S2P) + .value("URAM_T2P", MlirMemoryKind::URAM_T2P) + .value("DRAM", MlirMemoryKind::DRAM); auto portKindAttr = - mlir_attribute_subclass(m, "PortKindAttr", mlirAttrIsHLSPortKindAttr); + mlir_attribute_subclass(m, "MemoryKindAttr", mlirAttrIsHLSMemoryKindAttr); portKindAttr.def_classmethod( "get", - [](py::object cls, MlirPortKind kind, MlirContext ctx) { - return cls(mlirHLSPortKindAttrGet(ctx, kind)); + [](py::object cls, MlirMemoryKind kind, MlirContext ctx) { + return 
cls(mlirHLSMemoryKindAttrGet(ctx, kind)); }, - "Get an instance of PortKindAttr in given context.", py::arg("cls"), + "Get an instance of MemoryKindAttr in given context.", py::arg("cls"), py::arg("kind"), py::arg("context") = py::none()); portKindAttr.def_property_readonly( "value", - [](MlirAttribute attr) { return mlirHLSPortKindAttrGetValue(attr); }, - "Returns the value of PortKindAttr."); + [](MlirAttribute attr) { return mlirHLSMemoryKindAttrGetValue(attr); }, + "Returns the value of MemoryKindAttr."); } //===----------------------------------------------------------------------===// diff --git a/lib/CAPI/Dialect/HLS/HLS.cpp b/lib/CAPI/Dialect/HLS/HLS.cpp index 66f994f1..7eb4f9e1 100644 --- a/lib/CAPI/Dialect/HLS/HLS.cpp +++ b/lib/CAPI/Dialect/HLS/HLS.cpp @@ -22,19 +22,14 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(HLS, hls, hls::HLSDialect) // HLS Dialect Attributes //===----------------------------------------------------------------------===// -static_assert(static_cast(MlirPortKind::INPUT) == - static_cast(PortKind::INPUT) && - static_cast(MlirPortKind::OUTPUT) == - static_cast(PortKind::OUTPUT), - "MlirPortKind (C-API) and PortKind (C++) mismatch"); - -bool mlirAttrIsHLSPortKindAttr(MlirAttribute attr) { - return unwrap(attr).isa(); +bool mlirAttrIsHLSMemoryKindAttr(MlirAttribute attr) { + return unwrap(attr).isa(); } -MlirAttribute mlirHLSPortKindAttrGet(MlirContext ctx, MlirPortKind kind) { - return wrap(hls::PortKindAttr::get(unwrap(ctx), static_cast(kind))); +MlirAttribute mlirHLSMemoryKindAttrGet(MlirContext ctx, MlirMemoryKind kind) { + return wrap( + hls::MemoryKindAttr::get(unwrap(ctx), static_cast(kind))); } -MlirPortKind mlirHLSPortKindAttrGetValue(MlirAttribute attr) { - return static_cast( - unwrap(attr).cast().getValue()); +MlirMemoryKind mlirHLSMemoryKindAttrGetValue(MlirAttribute attr) { + return static_cast( + unwrap(attr).cast().getValue()); } diff --git a/lib/Dialect/HLS/IR/HLS.cpp b/lib/Dialect/HLS/IR/HLS.cpp index 25a9ca14..317a1759 100644 --- a/lib/Dialect/HLS/IR/HLS.cpp +++ b/lib/Dialect/HLS/IR/HLS.cpp @@ -300,35 +300,35 @@ bool hls::hasRuntimeAttr(Operation *op) { // Transform Utils //===----------------------------------------------------------------------===// -/// Wrap the operations in the block with dispatch op. -DispatchOp hls::dispatchBlock(StringRef name, Block *block, +/// Wrap the operations in the block with schedule op. 
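+/// Returns nullptr if the block already contains a schedule op or if its
+/// parent is not a supported container (function or loop body).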
+ScheduleOp hls::scheduleBlock(StringRef name, Block *block, PatternRewriter &rewriter) { - if (!block->getOps().empty() || + if (!block->getOps().empty() || !isa(block->getParentOp())) return nullptr; auto loc = rewriter.getUnknownLoc(); ValueRange returnValues(block->getTerminator()->getOperands()); rewriter.setInsertionPointToStart(block); - auto dispatch = rewriter.create(loc, returnValues); + auto schedule = rewriter.create(loc, returnValues); - auto &dispatchBlock = dispatch.getBody().emplaceBlock(); - rewriter.setInsertionPointToEnd(&dispatchBlock); + auto &scheduleBlock = schedule.getBody().emplaceBlock(); + rewriter.setInsertionPointToEnd(&scheduleBlock); rewriter.create(loc, returnValues); - auto &dispatchOps = dispatchBlock.getOperations(); + auto &scheduleOps = scheduleBlock.getOperations(); auto &parentOps = block->getOperations(); - dispatchOps.splice(dispatchBlock.begin(), parentOps, + scheduleOps.splice(scheduleBlock.begin(), parentOps, std::next(parentOps.begin()), std::prev(parentOps.end())); - block->getTerminator()->setOperands(dispatch.getResults()); + block->getTerminator()->setOperands(schedule.getResults()); unsigned taskId = 0; - for (auto &op : llvm::make_early_inc_range(dispatch.getOps())) { + for (auto &op : llvm::make_early_inc_range(schedule.getOps())) { assert(!isa(op) && !isa(op) && !isa(op) && - "stream op must be materialized before being dispatched"); + "stream op must be materialized before being scheduleed"); assert(!isa(op.getDialect()) && - "tensor op must be bufferized before being dispatched"); + "tensor op must be bufferized before being scheduleed"); if (isa(op)) { auto task = fuseOpsIntoTask({&op}, rewriter); std::string taskName = name.str() + "_" + std::to_string(taskId++); @@ -336,7 +336,7 @@ DispatchOp hls::dispatchBlock(StringRef name, Block *block, task->setAttr(taskName, rewriter.getUnitAttr()); } } - return dispatch; + return schedule; } /// Fuse the given operations into a new task. The new task will be created @@ -390,64 +390,6 @@ TaskOp hls::fuseOpsIntoTask(ArrayRef ops, return task; } -/// Fuse multiple nodes into a new node. -NodeOp hls::fuseNodeOps(ArrayRef nodes, PatternRewriter &rewriter) { - assert((nodes.size() > 1) && "must fuse at least two nodes"); - - // Collect inputs, outputs, and params of the new node. - llvm::SetVector inputs; - llvm::SmallVector inputTaps; - llvm::SmallVector inputLocs; - llvm::SetVector outputs; - llvm::SmallVector outputLocs; - llvm::SetVector params; - llvm::SmallVector paramLocs; - - for (auto node : nodes) { - for (auto output : node.getOutputs()) - if (outputs.insert(output)) - outputLocs.push_back(output.getLoc()); - for (auto param : node.getParams()) - if (params.insert(param)) - paramLocs.push_back(param.getLoc()); - } - for (auto node : nodes) - for (auto input : llvm::enumerate(node.getInputs())) { - if (outputs.count(input.value())) - continue; - if (inputs.insert(input.value())) { - inputLocs.push_back(input.value().getLoc()); - inputTaps.push_back(node.getInputTap(input.index())); - } - } - - // Construct the new node after the last node. 
- rewriter.setInsertionPointAfter(nodes.back()); - auto newNode = rewriter.create( - rewriter.getUnknownLoc(), inputs.getArrayRef(), outputs.getArrayRef(), - params.getArrayRef(), inputTaps); - auto block = rewriter.createBlock(&newNode.getBody()); - block->addArguments(ValueRange(inputs.getArrayRef()), inputLocs); - block->addArguments(ValueRange(outputs.getArrayRef()), outputLocs); - block->addArguments(ValueRange(params.getArrayRef()), paramLocs); - - // Inline all nodes into the new node. - for (auto node : nodes) { - auto &nodeOps = node.getBody().front().getOperations(); - auto &newNodeOps = newNode.getBody().front().getOperations(); - newNodeOps.splice(newNode.end(), nodeOps); - for (auto t : llvm::zip(node.getBody().getArguments(), node.getOperands())) - std::get<0>(t).replaceAllUsesWith(std::get<1>(t)); - rewriter.eraseOp(node); - } - - for (auto t : llvm::zip(newNode.getOperands(), block->getArguments())) - std::get<0>(t).replaceUsesWithIf(std::get<1>(t), [&](OpOperand &use) { - return newNode->isProperAncestor(use.getOwner()); - }); - return newNode; -} - //===----------------------------------------------------------------------===// // Analysis Utils //===----------------------------------------------------------------------===// @@ -488,179 +430,27 @@ bool hls::isUnknown(MemRefType type) { return kind == MemoryKind::UNKNOWN; } -/// A helper to get all users of a buffer except the given node and with the -/// given kind (producer or consumer). -static auto getUsersExcept(Value buffer, PortKind kind, NodeOp except) { - SmallVector nodes; - for (auto &use : buffer.getUses()) - if (auto node = dyn_cast(use.getOwner())) - if (node != except && node.getPortKind(use) == kind) - nodes.push_back(node); - return nodes; -} - -/// Get the consumer/producer nodes of the given buffer expect the given op. -SmallVector hls::getConsumersExcept(Value buffer, NodeOp except) { - return getUsersExcept(buffer, PortKind::INPUT, except); -} -SmallVector hls::getProducersExcept(Value buffer, NodeOp except) { - return getUsersExcept(buffer, PortKind::OUTPUT, except); -} -SmallVector hls::getConsumers(Value buffer) { - return getConsumersExcept(buffer, NodeOp()); -} -SmallVector hls::getProducers(Value buffer) { - return getProducersExcept(buffer, NodeOp()); -} -SmallVector hls::getDependentConsumers(Value buffer, NodeOp node) { - // If the buffer is defined outside of a dependence free schedule op, we can - // ignore back dependences. - bool ignoreBackDependence = - buffer.isa() && node.getScheduleOp().isDependenceFree(); - - DominanceInfo domInfo; - SmallVector nodes; - for (auto consumer : getConsumersExcept(buffer, node)) - if (!ignoreBackDependence || domInfo.properlyDominates(node, consumer)) - nodes.push_back(consumer); - return nodes; -} - -/// A helper to get all nested users of a buffer except the given node and with -/// the given kind (producer or consumer). -static SmallVector> -getNestedUsersExcept(Value buffer, PortKind kind, NodeOp except) { - SmallVector> worklist; - - // A helper to append all node users of the given buffer. - auto appendWorklist = [&](Value buffer) { - for (auto &use : buffer.getUses()) - if (auto node = dyn_cast(use.getOwner())) - if (node != except) - worklist.push_back({node, buffer, node.getPortKind(use)}); - }; - - // Initialize the worklist. 
- appendWorklist(buffer); - - SmallVector> nestedUsers; - while (!worklist.empty()) { - auto current = worklist.pop_back_val(); - auto node = std::get<0>(current); - auto nodeBuffer = std::get<1>(current); - auto nodeKind = std::get<2>(current); - - // If the current node doesn't have hierarchy, we add it to results if the - // node kind is aligned. - if (!node.hasHierarchy()) { - if (nodeKind == kind) - nestedUsers.push_back({node, nodeBuffer}); - continue; - } - - // Otherwise, we should delve into the hierarchy and traverse all contained - // schedules. - auto index = - llvm::find(node.getOperands(), nodeBuffer) - node.operand_begin(); - assert(index != node.getNumOperands() && "invalid node or node buffer"); - auto arg = node.getBody().getArgument(index); - - for (auto &use : arg.getUses()) - if (auto schedule = dyn_cast(use.getOwner())) - appendWorklist(schedule.getBody().getArgument(use.getOperandNumber())); - } - return nestedUsers; -} - -/// Get the nested consumer/producer nodes of the given buffer expect the given -/// node. The corresponding buffer values are also returned. -SmallVector> -hls::getNestedConsumersExcept(Value buffer, NodeOp except) { - return getNestedUsersExcept(buffer, PortKind::INPUT, except); -} -SmallVector> -hls::getNestedProducersExcept(Value buffer, NodeOp except) { - return getNestedUsersExcept(buffer, PortKind::OUTPUT, except); -} -SmallVector> hls::getNestedConsumers(Value buffer) { - return getNestedConsumersExcept(buffer, NodeOp()); -} -SmallVector> hls::getNestedProducers(Value buffer) { - return getNestedProducersExcept(buffer, NodeOp()); -} - -/// Get the depth of a buffer or stream channel. Note that only if the defining -/// operation of the buffer is not a BufferOp or stream types, the returned -/// result will be 1. -unsigned hls::getBufferDepth(Value memref) { - if (auto streamType = memref.getType().dyn_cast()) { - return streamType.getDepth(); - } - return 1; -} - -/// Find buffer value or buffer op across the dataflow hierarchy. -Value hls::findBuffer(Value memref) { - if (auto arg = memref.dyn_cast()) { - if (auto node = dyn_cast(arg.getParentBlock()->getParentOp())) - return findBuffer(node->getOperand(arg.getArgNumber())); - else if (auto schedule = - dyn_cast(arg.getParentBlock()->getParentOp())) - return findBuffer(schedule->getOperand(arg.getArgNumber())); - return memref; - } else if (auto viewOp = memref.getDefiningOp()) - return findBuffer(viewOp.getViewSource()); - else if (auto buffer = memref.getDefiningOp()) - return buffer.getMemref(); - return Value(); -} -hls::BufferLikeInterface hls::findBufferOp(Value memref) { - if (auto buffer = findBuffer(memref)) - return buffer.getDefiningOp(); - return hls::BufferLikeInterface(); -} - -/// Check whether the given buffer is external. -bool hls::isExtBuffer(Value memref) { - if (auto type = memref.getType().dyn_cast()) - return isDram(type); - return false; -} - /// Check whether the given use has read/write semantics. bool hls::isRead(OpOperand &use) { - // For NodeOp and ScheduleOp, we don't rely on memory effect interface. - // Instead, we delve into its region to figure out the effect. However, for - // InstanceOp, we don't need this recursive approach any more. 
- if (auto node = dyn_cast(use.getOwner())) - return llvm::any_of( - node.getBody().getArgument(use.getOperandNumber()).getUses(), - [](OpOperand &argUse) { return isRead(argUse); }); - else if (auto schedule = dyn_cast(use.getOwner())) - return llvm::any_of( - schedule.getBody().getArgument(use.getOperandNumber()).getUses(), - [](OpOperand &argUse) { return isRead(argUse); }); - else if (auto view = dyn_cast(use.getOwner())) + if (auto view = dyn_cast(use.getOwner())) return llvm::any_of(view->getUses(), [](OpOperand &viewUse) { return isRead(viewUse); }); - return hasEffect(use.getOwner(), use.get()) || - isa(use.getOwner()); + else if (auto streamView = dyn_cast(use.getOwner())) + return llvm::any_of(streamView->getUses(), [](OpOperand &streamViewUse) { + return isRead(streamViewUse); + }); + return hasEffect(use.getOwner(), use.get()); } + bool hls::isWritten(OpOperand &use) { - // For ScheduleOp, we don't rely on memory effect interface. Instead, we delve - // into its region to figure out the effect. However, for InstanceOp and - // NodeOp, we don't need this recursive approach any more. - if (auto node = dyn_cast(use.getOwner())) - return node.getPortKind(use) == PortKind::OUTPUT; - else if (auto schedule = dyn_cast(use.getOwner())) - return llvm::any_of( - schedule.getBody().getArgument(use.getOperandNumber()).getUses(), - [](OpOperand &argUse) { return isWritten(argUse); }); - else if (auto view = dyn_cast(use.getOwner())) + if (auto view = dyn_cast(use.getOwner())) return llvm::any_of(view->getUses(), [](OpOperand &viewUse) { return isWritten(viewUse); }); - return hasEffect(use.getOwner(), use.get()) || - isa(use.getOwner()); + else if (auto streamView = dyn_cast(use.getOwner())) + return llvm::any_of(streamView->getUses(), [](OpOperand &streamViewUse) { + return isWritten(streamViewUse); + }); + return hasEffect(use.getOwner(), use.get()); } func::FuncOp hls::getTopFunc(ModuleOp module, std::string topFuncName) { @@ -718,16 +508,3 @@ int64_t hls::getPartitionFactors(MemRefType memrefType, factors->assign(memrefType.getRank(), 1); return accumFactor; } - -/// The current op or contained ops have effect on external buffers. 
-bool hls::hasEffectOnExternalBuffer(Operation *op) { - auto result = op->walk([](MemoryEffectOpInterface effectOp) { - SmallVector effects; - effectOp.getEffects(effects); - for (auto effect : effects) - if (isExtBuffer(effect.getValue())) - return WalkResult::interrupt(); - return WalkResult::advance(); - }); - return result.wasInterrupted(); -} diff --git a/lib/Dialect/HLS/IR/HLSOps.cpp b/lib/Dialect/HLS/IR/HLSOps.cpp index 63430e7c..58d597cd 100644 --- a/lib/Dialect/HLS/IR/HLSOps.cpp +++ b/lib/Dialect/HLS/IR/HLSOps.cpp @@ -62,9 +62,8 @@ LogicalResult StreamToTensorOp::verify() { LogicalResult StreamOp::verify() { unsigned numWrites = 0; - for (auto user : (*this)->getUsers()) - if (isa(user)) - numWrites++; + for (auto &use : (*this)->getUses()) + numWrites += isWritten(use); if (numWrites > 1) return emitOpError() << "stream is written more than once"; return success(); @@ -120,7 +119,8 @@ LogicalResult StreamReadOp::verify() { if (getInit()) if (getInit().getType() != getResult().getType()) return emitOpError("initial value type doesn't align with result type"); - return verifyTripCountsAndSteps(*this, getChannel()); + return success(); + // return verifyTripCountsAndSteps(*this, getChannel()); } void StreamReadOp::getEffects( @@ -137,7 +137,8 @@ void StreamReadOp::getEffects( LogicalResult StreamWriteOp::verify() { if (getChannel().getType().getElementType() != getValue().getType()) return emitOpError("value type doesn't align with channel type"); - return verifyTripCountsAndSteps(*this, getChannel()); + return success(); + // return verifyTripCountsAndSteps(*this, getChannel()); } void StreamWriteOp::getEffects( @@ -267,12 +268,12 @@ OpFoldResult StreamCastOp::fold(FoldAdaptor adaptor) { } //===----------------------------------------------------------------------===// -// DispatchOp +// ScheduleOp //===----------------------------------------------------------------------===// namespace { template -struct SimplifyDispatchOrTaskOutputs : public OpRewritePattern { +struct SimplifyScheduleOrTaskOutputs : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(OpType op, @@ -315,8 +316,8 @@ struct SimplifyDispatchOrTaskOutputs : public OpRewritePattern { namespace { template -struct InlineDispatchOrTask : public OpRewritePattern { - InlineDispatchOrTask(MLIRContext *context, +struct InlineScheduleOrTask : public OpRewritePattern { + InlineScheduleOrTask(MLIRContext *context, llvm::function_ref condition) : OpRewritePattern(context), condition(condition) {} @@ -340,7 +341,7 @@ struct InlineDispatchOrTask : public OpRewritePattern { namespace { template -struct DemoteYieldedOutput : public OpRewritePattern { +struct DemoteScheduleOrTaskOutputs : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(OpType op, @@ -369,23 +370,23 @@ struct DemoteYieldedOutput : public OpRewritePattern { }; } // namespace -void DispatchOp::getCanonicalizationPatterns(RewritePatternSet &results, +void ScheduleOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add>(context); - results.add>(context, [](DispatchOp op) { + results.add>(context); + results.add>(context, [](ScheduleOp op) { return op.getOps().empty() || llvm::hasSingleElement(op.getOps()); }); - results.add>(context); + results.add>(context); } -LogicalResult DispatchOp::verify() { +LogicalResult ScheduleOp::verify() { if (getResultTypes() != getYieldOp().getOperandTypes()) return emitOpError("yield type doesn't align 
with result type"); return success(); } /// Get the terminator yield op. -YieldOp DispatchOp::getYieldOp() { +YieldOp ScheduleOp::getYieldOp() { return cast(getBody().front().getTerminator()); } @@ -395,13 +396,13 @@ YieldOp DispatchOp::getYieldOp() { void TaskOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add>(context); - results.add>(context, [](TaskOp op) { + results.add>(context); + results.add>(context, [](TaskOp op) { return llvm::hasSingleElement( - op.getParentOp().getOps()) || + op.getParentOp().getOps()) || llvm::hasSingleElement(op.getOps()); }); - results.add>(context); + results.add>(context); } LogicalResult TaskOp::verify() { @@ -411,8 +412,8 @@ LogicalResult TaskOp::verify() { } /// Get the parent dispatch op. -DispatchOp TaskOp::getDispatchOp() { - return (*this)->getParentOfType(); +ScheduleOp TaskOp::getScheduleOp() { + return (*this)->getParentOfType(); } /// Get the terminator yield op. @@ -420,371 +421,27 @@ YieldOp TaskOp::getYieldOp() { return cast(getBody().front().getTerminator()); } -bool TaskOp::isLivein(Value value) { - auto liveins = Liveness(*this).getLiveIn(&(*this).getBody().front()); - return liveins.count(value); -} - -SmallVector TaskOp::getLiveins() { - auto liveins = Liveness(*this).getLiveIn(&(*this).getBody().front()); - return {liveins.begin(), liveins.end()}; -} - -SmallVector TaskOp::getLiveinUsers(Value livein) { - assert(isLivein(livein) && "invalid livein"); - auto users = llvm::make_filter_range(livein.getUsers(), [&](Operation *user) { - return (*this)->isAncestor(user); - }); - return {users.begin(), users.end()}; -} - -//===----------------------------------------------------------------------===// -// ScheduleOp -//===----------------------------------------------------------------------===// - -namespace { -struct SimplifyScheduleOperands : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ScheduleOp schedule, - PatternRewriter &rewriter) const override { - bool hasUnusedPort = false; - - // Identify input values that are used. - llvm::SmallDenseSet unusedArgs; - SmallVector usedOperands; - for (auto arg : schedule.getBody().getArguments()) - if (arg.use_empty()) { - hasUnusedPort = true; - unusedArgs.insert(arg); - } else { - usedOperands.push_back(schedule.getOperand(arg.getArgNumber())); - } - schedule.getBody().front().eraseArguments( - [&](BlockArgument arg) { return unusedArgs.count(arg); }); - - // Construct new schedule. 
- if (hasUnusedPort) { - rewriter.setInsertionPoint(schedule); - auto newSchedule = - rewriter.create(schedule.getLoc(), usedOperands); - rewriter.inlineRegionBefore(schedule.getBody(), newSchedule.getBody(), - newSchedule.getBody().end()); - rewriter.eraseOp(schedule); - return success(); - } - return failure(); - } -}; -} // namespace - -namespace { -template -struct InlineScheduleOrNode : public OpRewritePattern { - InlineScheduleOrNode(MLIRContext *context, - llvm::function_ref condition) - : OpRewritePattern(context), condition(condition) {} - - LogicalResult matchAndRewrite(OpType op, - PatternRewriter &rewriter) const override { - if (condition(op)) { - auto &ops = op.getBody().front().getOperations(); - auto &parentOps = op->getBlock()->getOperations(); - parentOps.splice(op->getIterator(), ops); - - for (auto t : llvm::zip(op.getBody().getArguments(), op.getOperands())) - std::get<0>(t).replaceAllUsesWith(std::get<1>(t)); - rewriter.eraseOp(op); - return success(); - } - return failure(); - } - -private: - llvm::function_ref condition; -}; -} // namespace - -void ScheduleOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); - results.add>( - context, [](ScheduleOp op) { return op.getOps().empty(); }); -} - -LogicalResult ScheduleOp::verify() { - if (getOperandTypes() != getBody().getArgumentTypes()) - return emitOpError("operand type doesn't align with argument type"); - - if (getIsLegal()) - for (auto &op : getOps()) - if (!isa(op)) { - auto diag = emitOpError("legal schedule has illegal ops:\n"); - diag.attachNote(op.getLoc()) - .append("see current op: ") - .appendOp(op, OpPrintingFlags().printGenericOpForm()); - return diag; - } - return success(); -} +// bool TaskOp::isLivein(Value value) { +// auto liveins = Liveness(*this).getLiveIn(&(*this).getBody().front()); +// return liveins.count(value); +// } -void ScheduleOp::getEffects( - SmallVectorImpl> - &effects) { - for (auto value : getOperands()) - if (value.getType().isa()) { - effects.emplace_back(MemoryEffects::Read::get(), value, - SideEffects::DefaultResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), value, - SideEffects::DefaultResource::get()); - } -} +// SmallVector TaskOp::getLiveins() { +// auto liveins = Liveness(*this).getLiveIn(&(*this).getBody().front()); +// return {liveins.begin(), liveins.end()}; +// } -/// FIXME: Check whether the schedule is dependence free. -bool ScheduleOp::isDependenceFree() { - return isa((*this)->getParentOp()); -} - -/// Update the signature of the schedule op recursively. 
-void ScheduleOp::updateSignatureRecursively() { - for (auto [operand, arg] : llvm::zip(getOperands(), getBody().getArguments())) - arg.setType(operand.getType()); - for (auto node : getOps()) - node.updateSignatureRecursively(); -} +// SmallVector TaskOp::getLiveinUsers(Value livein) { +// assert(isLivein(livein) && "invalid livein"); +// auto users = llvm::make_filter_range(livein.getUsers(), [&](Operation +// *user) { +// return (*this)->isAncestor(user); +// }); +// return {users.begin(), users.end()}; +// } //===----------------------------------------------------------------------===// -// NodeOp -//===----------------------------------------------------------------------===// - -namespace { -struct SimplifyNodeIOs : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(NodeOp node, - PatternRewriter &rewriter) const override { - bool hasUnusedPort = false; - - // Identify input values that are used. - llvm::SmallDenseSet unusedArgs; - SmallVector usedInputs; - SmallVector usedInputTaps; - SmallVector usedOutputs; - SmallVector usedParams; - for (auto arg : node.getBody().getArguments()) - if (arg.use_empty()) { - hasUnusedPort = true; - unusedArgs.insert(arg); - } else { - auto idx = arg.getArgNumber(); - if (node.getPortKind(idx) == PortKind::INPUT) { - usedInputs.push_back(node.getOperand(idx)); - usedInputTaps.push_back(node.getInputTap(idx)); - } else if (node.getPortKind(idx) == PortKind::OUTPUT) - usedOutputs.push_back(node.getOperand(idx)); - else - usedParams.push_back(node.getOperand(idx)); - } - node.getBody().front().eraseArguments( - [&](BlockArgument arg) { return unusedArgs.count(arg); }); - - // Construct new dataflow node. - if (hasUnusedPort) { - rewriter.setInsertionPoint(node); - auto newNode = rewriter.create( - node.getLoc(), usedInputs, usedOutputs, usedParams, - rewriter.getI32ArrayAttr(usedInputTaps), node.getLevelAttr()); - rewriter.inlineRegionBefore(node.getBody(), newNode.getBody(), - newNode.getBody().end()); - rewriter.eraseOp(node); - return success(); - } - return failure(); - } -}; -} // namespace - -void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); - results.add>(context, [](NodeOp op) { - return false; - // return llvm::hasSingleElement(op.getScheduleOp().getOps()); - }); -} - -LogicalResult NodeOp::verify() { - if (getOperandTypes() != getBody().getArgumentTypes()) - return emitOpError("operand type doesn't align with argument type"); - - if (llvm::any_of(getParams(), [](Value param) { - return param.getType().isa(); - })) - return emitOpError("node params should not be memref or stream typed"); - - if (getInputs().size() != getInputTaps().size()) - return emitOpError("number of node inputs and input taps are not aligned"); - for (auto t : llvm::zip(getInputs(), getInputTapsAsInt())) { - auto depth = getBufferDepth(std::get<0>(t)); - auto inputTap = (unsigned)std::get<1>(t); - if (depth <= inputTap) { - auto diag = emitOpError("node input tap is larger than buffer depth, "); - diag << "input tap: " << inputTap << ", depth: " << depth; - } - } - - for (auto inputArg : getInputArgs()) - if (llvm::any_of(inputArg.getUses(), isWritten)) { - auto diag = emitOpError("input operand "); - diag << inputArg << " is written"; - return diag; - } - - for (auto outputArg : getOutputArgs()) - if (!llvm::any_of(outputArg.getUses(), isWritten)) { - auto diag = emitOpError("output operand "); - diag << outputArg << " is not written"; - return diag; 
- } - - if (getScheduleOp().getIsLegal()) { - if (!getLevel()) - return emitOpError("node is not scheduled"); - - for (auto output : getOutputs()) { - // DRAM buffer is not considered - the dependencies associated with them - // are handled later by tokens. - if (isExtBuffer(output)) - continue; - - if (getDependentConsumers(output, *this).size() > 1 || - getProducers(output).size() > 1) { - auto diag = emitOpError( - "legal schedule violates single-consumer or single-producer, "); - diag << "see current buffer: " << output << "\n"; - for (auto user : output.getUsers()) - diag.attachNote(user->getLoc()) - .append("see current buffer user: ") - .appendOp(*user, OpPrintingFlags().printGenericOpForm()); - return diag; - } - } - } - return success(); -} - -void NodeOp::getEffects( - SmallVectorImpl> - &effects) { - for (auto value : getInputs()) - effects.emplace_back(MemoryEffects::Read::get(), value, - SideEffects::DefaultResource::get()); - for (auto value : getOutputs()) { - effects.emplace_back(MemoryEffects::Read::get(), value, - SideEffects::DefaultResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), value, - SideEffects::DefaultResource::get()); - } -} - -/// Get the parent schedule op. -ScheduleOp NodeOp::getScheduleOp() { - return (*this)->getParentOfType(); -} - -/// Get input taps. -void NodeOp::setInputTap(unsigned idx, unsigned tap) { - SmallVector newInputTaps(llvm::map_range( - getInputTapsAsInt(), [](unsigned a) { return (int32_t)a; })); - newInputTaps[idx] = tap; - Builder builder(getContext()); - setInputTapsAttr(builder.getI32ArrayAttr(newInputTaps)); -} -unsigned NodeOp::getInputTap(unsigned idx) { - return getInputTaps()[idx].cast().getInt(); -} -SmallVector NodeOp::getInputTapsAsInt() { - auto array = llvm::map_range(getInputTaps(), [](Attribute attr) { - return attr.cast().getInt(); - }); - return {array.begin(), array.end()}; -} - -/// Return the number of inputs, outputs, and params. -unsigned NodeOp::getNumInputs() { - return getODSOperandIndexAndLength(0).second; -} -unsigned NodeOp::getNumOutputs() { - return getODSOperandIndexAndLength(1).second; -} -unsigned NodeOp::getNumParams() { - return getODSOperandIndexAndLength(2).second; -} - -/// Get the type of operand: input, output, or param. -PortKind NodeOp::getPortKind(OpOperand &operand) { - assert(operand.getOwner() == *this && "invalid operand"); - return getPortKind(operand.getOperandNumber()); -} -PortKind NodeOp::getPortKind(unsigned operandIdx) { - if (operandIdx >= getODSOperandIndexAndLength(2).first) - return PortKind::PARAM; - else if (operandIdx >= getODSOperandIndexAndLength(1).first) - return PortKind::OUTPUT; - else - return PortKind::INPUT; -} - -/// Get the input, output, and param arguments. 
-iterator_range NodeOp::getInputArgs() { - auto range = getODSOperandIndexAndLength(0); - return {std::next(getBody().args_begin(), range.first), - std::next(getBody().args_begin(), range.first + range.second)}; -} -iterator_range NodeOp::getOutputArgs() { - auto range = getODSOperandIndexAndLength(1); - return {std::next(getBody().args_begin(), range.first), - std::next(getBody().args_begin(), range.first + range.second)}; -} -iterator_range NodeOp::getParamArgs() { - auto range = getODSOperandIndexAndLength(2); - return {std::next(getBody().args_begin(), range.first), - std::next(getBody().args_begin(), range.first + range.second)}; -} - -bool NodeOp::isLivein(Value value) { - return value.isa() && - value.getParentRegion() == &(*this).getBody(); -} - -SmallVector NodeOp::getLiveins() { - auto args = (*this).getBody().getArguments(); - return {args.begin(), args.end()}; -} - -SmallVector NodeOp::getLiveinUsers(Value livein) { - assert(isLivein(livein) && "invalid livein"); - auto users = livein.getUsers(); - return {users.begin(), users.end()}; -} - -/// Update the signature of the node op recursively. -void NodeOp::updateSignatureRecursively() { - llvm::SmallDenseSet schedules; - for (auto [operand, arg] : - llvm::zip(getOperands(), getBody().getArguments())) { - arg.setType(operand.getType()); - for (auto user : arg.getUsers()) - if (auto schedule = dyn_cast(user)) - schedules.insert(schedule); - } - // TODO: How to traverse all schedule ops? - for (auto schedule : schedules) - schedule.updateSignatureRecursively(); -} - -//===----------------------------------------------------------------------===// -// BufferOp and ConstBufferOp +// BufferOp //===----------------------------------------------------------------------===// namespace { @@ -809,124 +466,15 @@ struct FlattenReadOnlyBuffer : public OpRewritePattern { }; } // namespace -static NodeOp sinkBufferIntoNode(NodeOp node, BufferOp buffer, - PatternRewriter &rewriter) { - assert(node->getParentRegion() == buffer->getParentRegion() && - "node and buffer is not at the same region"); - SmallVector inputs; - SmallVector inputTaps; - SmallVector outputs; - llvm::BitVector eraseIndices; - - for (auto input : llvm::enumerate(node.getInputs())) { - if (input.value() != buffer) { - inputs.push_back(input.value()); - inputTaps.push_back(node.getInputTap(input.index())); - eraseIndices.push_back(false); - } else { - auto arg = node.getBody().getArgument(input.index()); - arg.replaceAllUsesWith(buffer); - eraseIndices.push_back(true); - } - } - for (auto output : llvm::enumerate(node.getOutputs())) { - if (output.value() != buffer) { - outputs.push_back(output.value()); - eraseIndices.push_back(false); - } else { - auto arg = - node.getBody().getArgument(node.getNumInputs() + output.index()); - arg.replaceAllUsesWith(buffer); - eraseIndices.push_back(true); - } - } - for (unsigned i = 0; i < node.getNumParams(); ++i) - eraseIndices.push_back(false); - - auto &nodeBlock = node.getBody().front(); - buffer->moveBefore(&nodeBlock.front()); - nodeBlock.eraseArguments(eraseIndices); - - rewriter.setInsertionPointAfter(node); - auto newNode = - rewriter.create(node.getLoc(), inputs, outputs, node.getParams(), - inputTaps, node.getLevelAttr()); - rewriter.inlineRegionBefore(node.getBody(), newNode.getBody(), - newNode.getBody().begin()); - rewriter.eraseOp(node); - return newNode; -} - -static ScheduleOp sinkBufferIntoSchedule(ScheduleOp schedule, BufferOp buffer, - PatternRewriter &rewriter) { - assert(schedule->getParentRegion() == 
buffer->getParentRegion() && - "node and buffer is not at the same region"); - SmallVector operands; - llvm::BitVector eraseIndices; - - for (auto operand : llvm::enumerate(schedule.getOperands())) { - if (operand.value() != buffer) { - operands.push_back(operand.value()); - eraseIndices.push_back(false); - } else - eraseIndices.push_back(true); - } - - auto &scheduleBlock = schedule.getBody().front(); - buffer->moveBefore(&scheduleBlock.front()); - scheduleBlock.eraseArguments(eraseIndices); - - rewriter.setInsertionPointAfter(schedule); - auto newSchedule = rewriter.create(schedule.getLoc(), operands, - schedule.getIsLegalAttr()); - rewriter.inlineRegionBefore(schedule.getBody(), newSchedule.getBody(), - newSchedule.getBody().begin()); - rewriter.eraseOp(schedule); - return newSchedule; -} - -namespace { -struct SinkInternalBuffer : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(BufferOp buffer, - PatternRewriter &rewriter) const override { - if (!isExtBuffer(buffer) && llvm::hasSingleElement(buffer->getUsers())) { - auto user = *buffer->getUsers().begin(); - - // Sink the buffer into the node or schedule user. - if (user->getParentRegion() == buffer->getParentRegion() && - isa(user)) { - if (auto node = dyn_cast(user)) - sinkBufferIntoNode(node, buffer, rewriter); - else if (auto schedule = dyn_cast(user)) - sinkBufferIntoSchedule(schedule, buffer, rewriter); - return success(); - } - } - return failure(); - } -}; -} // namespace - void BufferOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); - results.add(context); } LogicalResult BufferOp::verify() { if (auto initValue = getInitValue()) if (initValue.value().getType() != getType().getElementType()) return emitOpError("initial value's type doesn't align with memref type"); - - if (isExtBuffer(*this)) { - if (auto node = dyn_cast((*this)->getParentOp())) - return emitOpError("external buffer should not be placed in node"); - if (auto schedule = dyn_cast((*this)->getParentOp())) - if (!isa(schedule->getParentOp())) - return emitOpError("external buffer must be placed in top schedule"); - } return success(); } @@ -941,6 +489,10 @@ void BufferOp::getEffects( SideEffects::DefaultResource::get()); } +//===----------------------------------------------------------------------===// +// ConstBufferOp +//===----------------------------------------------------------------------===// + std::optional ConstBufferOp::getBufferInitValue() { return std::optional(); } diff --git a/lib/Dialect/HLS/Transforms/BufferizableOpInterfaceImpl.cpp b/lib/Dialect/HLS/Transforms/BufferizableOpInterfaceImpl.cpp index 15eb46f3..f0877306 100644 --- a/lib/Dialect/HLS/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/lib/Dialect/HLS/Transforms/BufferizableOpInterfaceImpl.cpp @@ -16,20 +16,20 @@ using namespace bufferization; using namespace scalehls; using namespace hls; -/// Bufferization of dispatch/task operation. Replace with a new dispatch/task +/// Bufferization of schedule/task operation. Replace with a new schedule/task /// that yields memrefs. template -struct DispatchOrTaskOpInterface +struct ScheduleOrTaskOpInterface : public BufferizableOpInterface::ExternalModel< - DispatchOrTaskOpInterface, OpType> { - /// Dispatch/task do not have tensor OpOperands. Thus, no OpOperand will be + ScheduleOrTaskOpInterface, OpType> { + /// Schedule/task do not have tensor OpOperands. 
Thus, no OpOperand will be /// bufferized to memory read/write or be aliased to any returned values. AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } - // Dispatch/task do not have tensor OpOperands. The yielded value can be any + // Schedule/task do not have tensor OpOperands. The yielded value can be any // SSA value that is in scope. To allow for use-def chain traversal in the // analysis, the yielded value is aliasing with the result. AliasingOpOperandList @@ -58,13 +58,13 @@ struct DispatchOrTaskOpInterface newTypes.push_back(*bufferType); } - // Create new dispatch/task op. + // Create new schedule/task op. rewriter.setInsertionPoint(concreteOp); auto newOp = rewriter.create(concreteOp.getLoc(), newTypes); rewriter.inlineRegionBefore(concreteOp.getBody(), newOp.getBody(), newOp.getBody().end()); - // Replace dispatch/task op results. + // Replace schedule/task op results. replaceOpWithBufferizedValues(rewriter, concreteOp, newOp->getResults()); return success(); } @@ -104,7 +104,7 @@ struct YieldOpInterface AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - if (isa(op->getParentOp())) + if (isa(op->getParentOp())) return {{op->getParentOp()->getResult(opOperand.getOperandNumber()), BufferRelation::Equivalent}}; return {}; @@ -238,8 +238,8 @@ struct TensorInitOpInterface void mlir::scalehls::hls::registerBufferizableOpInterfaceExternalModels( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, HLSDialect *dialect) { - DispatchOp::attachInterface>(*ctx); - TaskOp::attachInterface>(*ctx); + ScheduleOp::attachInterface>(*ctx); + TaskOp::attachInterface>(*ctx); YieldOp::attachInterface(*ctx); hls::TensorInitOp::attachInterface(*ctx); }); diff --git a/lib/Dialect/HLS/Transforms/CMakeLists.txt b/lib/Dialect/HLS/Transforms/CMakeLists.txt index bf31e56e..3f74d502 100644 --- a/lib/Dialect/HLS/Transforms/CMakeLists.txt +++ b/lib/Dialect/HLS/Transforms/CMakeLists.txt @@ -4,8 +4,7 @@ add_mlir_dialect_library(MLIRScaleHLSHLSTransforms ReduceTensorToStream.cpp MaterializeStream.cpp ScalarizeStream.cpp - CreateDataflow.cpp - LowerDataflow.cpp + ScheduleDataflow.cpp ConvertDataflowToFunc.cpp ApplyTransformPattern.cpp ComprehensiveBufferize.cpp diff --git a/lib/Dialect/HLS/Transforms/ConvertDataflowToFunc.cpp b/lib/Dialect/HLS/Transforms/ConvertDataflowToFunc.cpp index b7c3fce1..903a67d2 100644 --- a/lib/Dialect/HLS/Transforms/ConvertDataflowToFunc.cpp +++ b/lib/Dialect/HLS/Transforms/ConvertDataflowToFunc.cpp @@ -4,6 +4,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Analysis/Liveness.h" #include "mlir/Dialect/Affine/LoopUtils.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -23,27 +24,22 @@ struct InlineSchedule : public OpRewritePattern { PatternRewriter &rewriter) const override { auto &scheduleOps = schedule.getBody().front().getOperations(); auto &parentOps = schedule->getBlock()->getOperations(); - parentOps.splice(schedule->getIterator(), scheduleOps); - - for (auto t : - llvm::zip(schedule.getBody().getArguments(), schedule.getOperands())) - std::get<0>(t).replaceAllUsesWith(std::get<1>(t)); - - if (schedule.getIsLegal()) { - if (auto func = dyn_cast(schedule->getParentOp())) - setFuncDirective(func, /*pipeline=*/false, /*targetInterval=*/1, - /*dataflow=*/true); - else if (auto loop = dyn_cast(schedule->getParentOp())) { - // If 
-        // them into a flattened loop.
-        AffineLoopBand band;
-        getLoopBandFromInnermost(loop, band);
-        auto dataflowLoop = loop;
-        if (isPerfectlyNested(band) && succeeded(coalesceLoops(band)))
-          dataflowLoop = band.front();
-        setLoopDirective(dataflowLoop, /*pipeline=*/false, /*targetII=*/1,
-                         /*dataflow=*/true, /*flattern=*/false);
-      }
+    parentOps.splice(schedule->getIterator(), scheduleOps, scheduleOps.begin(),
+                     std::prev(scheduleOps.end()));
+
+    if (auto func = dyn_cast<func::FuncOp>(schedule->getParentOp()))
+      setFuncDirective(func, /*pipeline=*/false, /*targetInterval=*/1,
+                       /*dataflow=*/true);
+    else if (auto loop = dyn_cast<AffineForOp>(schedule->getParentOp())) {
+      // If the schedule is located inside of a loop nest, try to coalesce
+      // them into a flattened loop.
+      AffineLoopBand band;
+      getLoopBandFromInnermost(loop, band);
+      auto dataflowLoop = loop;
+      if (isPerfectlyNested(band) && succeeded(coalesceLoops(band)))
+        dataflowLoop = band.front();
+      setLoopDirective(dataflowLoop, /*pipeline=*/false, /*targetII=*/1,
+                       /*dataflow=*/true, /*flattern=*/false);
     }
     rewriter.eraseOp(schedule);
     return success();
@@ -52,39 +48,63 @@ struct InlineSchedule : public OpRewritePattern<ScheduleOp> {
 } // namespace
 
 namespace {
-struct ConvertNodeToFunc : public OpRewritePattern<NodeOp> {
-  ConvertNodeToFunc(MLIRContext *context, StringRef prefix, unsigned &nodeIdx)
-      : OpRewritePattern(context), prefix(prefix), nodeIdx(nodeIdx) {}
+struct ConvertTaskToFunc : public OpRewritePattern<TaskOp> {
+  ConvertTaskToFunc(MLIRContext *context, StringRef prefix, unsigned &taskIdx)
+      : OpRewritePattern(context), prefix(prefix), taskIdx(taskIdx) {}
 
-  LogicalResult matchAndRewrite(NodeOp node,
+  LogicalResult matchAndRewrite(TaskOp task,
                                 PatternRewriter &rewriter) const override {
+    if (task.getNumResults())
+      return task.emitOpError("should not yield any results");
+
+    // Collect all live-ins of the task.
+    SmallVector<Value> operands;
+    SmallVector<Location> operandLocs;
+    auto liveins = Liveness(task).getLiveIn(&task.getBody().front());
+    for (auto livein : liveins) {
+      if (task.getBody().isAncestor(livein.getParentRegion()))
+        continue;
+      operands.push_back(livein);
+      operandLocs.push_back(livein.getLoc());
+    }
+
     // Create a new sub-function.
-    rewriter.setInsertionPoint(node->getParentOfType<func::FuncOp>());
+    rewriter.setInsertionPoint(task->getParentOfType<func::FuncOp>());
     auto subFunc = rewriter.create<func::FuncOp>(
-        node.getLoc(), prefix.str() + "_node" + std::to_string(nodeIdx++),
-        rewriter.getFunctionType(node.getOperandTypes(), TypeRange()));
+        task.getLoc(), prefix.str() + "_task" + std::to_string(taskIdx++),
+        rewriter.getFunctionType(TypeRange(operands), TypeRange()));
+    subFunc->setAttrs(task->getAttrs());
 
     // FIXME: A better method to judge whether to inline the node.
-    if (!node.hasHierarchy() &&
-        llvm::hasSingleElement(node.getOps()))
+    if (!task.hasHierarchy() &&
+        llvm::hasSingleElement(task.getOps()))
       subFunc->setAttr("inline", rewriter.getUnitAttr());
 
-    // Inline the contents of the dataflow node.
-    rewriter.inlineRegionBefore(node.getBodyRegion(), subFunc.getBody(),
-                                subFunc.end());
-    rewriter.setInsertionPointToEnd(&subFunc.front());
-    rewriter.create<func::ReturnOp>(rewriter.getUnknownLoc());
+    // Construct the body and arguments of the sub-function.
+    auto subFuncBlock = rewriter.createBlock(&subFunc.getBody());
+    auto args = subFuncBlock->addArguments(TypeRange(operands), operandLocs);
+    for (auto [operand, arg] : llvm::zip(operands, args))
+      operand.replaceUsesWithIf(arg, [&](OpOperand &use) {
+        return task->isAncestor(use.getOwner());
+      });
+
+    // Inline the task body into the sub-function.
+    auto &subFuncOps = subFuncBlock->getOperations();
+    auto &taskOps = task.getBody().front().getOperations();
+    subFuncOps.splice(subFuncOps.begin(), taskOps, taskOps.begin(),
+                      std::prev(taskOps.end()));
+    rewriter.setInsertionPointToEnd(subFuncBlock);
+    rewriter.create<func::ReturnOp>(task.getYieldOp().getLoc());
 
     // Replace original with a function call.
-    rewriter.setInsertionPoint(node);
-    rewriter.replaceOpWithNewOp<func::CallOp>(node, subFunc,
-                                              node.getOperands());
+    rewriter.setInsertionPoint(task);
+    rewriter.replaceOpWithNewOp<func::CallOp>(task, subFunc, operands);
     return success();
   }
 
 private:
   StringRef prefix;
-  unsigned &nodeIdx;
+  unsigned &taskIdx;
 };
 } // namespace
 
@@ -94,13 +114,26 @@ struct ConvertDataflowToFunc
   void runOnOperation() override {
     auto module = getOperation();
     auto context = module.getContext();
+    auto builder = OpBuilder(context);
+
+    // Collect all constants in the function and localize them to uses.
+    SmallVector<Operation *> constants;
+    module.walk([&](arith::ConstantOp op) { constants.push_back(op); });
+    for (auto constant : constants) {
+      for (auto &use : llvm::make_early_inc_range(constant->getUses())) {
+        builder.setInsertionPoint(use.getOwner());
+        auto cloneConstant = cast<arith::ConstantOp>(builder.clone(*constant));
+        use.set(cloneConstant.getResult());
+      }
+      constant->erase();
+    }
 
     for (auto func : llvm::make_early_inc_range(module.getOps<func::FuncOp>())) {
-      unsigned nodeIdx = 0;
+      unsigned taskIdx = 0;
       mlir::RewritePatternSet patterns(context);
       patterns.add<InlineSchedule>(context);
-      patterns.add<ConvertNodeToFunc>(context, func.getName(), nodeIdx);
+      patterns.add<ConvertTaskToFunc>(context, func.getName(), taskIdx);
       (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
     }
   }
diff --git a/lib/Dialect/HLS/Transforms/LowerDataflow.cpp b/lib/Dialect/HLS/Transforms/LowerDataflow.cpp
deleted file mode 100644
index 55687170..00000000
--- a/lib/Dialect/HLS/Transforms/LowerDataflow.cpp
+++ /dev/null
@@ -1,184 +0,0 @@
-//===----------------------------------------------------------------------===//
-//
-// Copyright 2020-2021 The ScaleHLS Authors.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Analysis/Liveness.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "scalehls/Dialect/HLS/Transforms/Passes.h"
-#include "scalehls/Utils/Utils.h"
-
-using namespace mlir;
-using namespace scalehls;
-using namespace hls;
-
-namespace {
-struct ConvertDispatchToSchedule : public OpRewritePattern<DispatchOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DispatchOp dispatch,
-                                PatternRewriter &rewriter) const override {
-    if (dispatch.getNumResults())
-      return dispatch.emitOpError("should not yield any results");
-
-    auto isInDispatch = [&](OpOperand &use) {
-      return dispatch->isAncestor(use.getOwner());
-    };
-
-    SmallVector<Value> inputs;
-    SmallVector<Location> inputLocs;
-
-    auto liveins = Liveness(dispatch).getLiveIn(&dispatch.getBody().front());
-    for (auto livein : liveins) {
-      if (dispatch.getBody().isAncestor(livein.getParentRegion()))
-        continue;
-      inputs.push_back(livein);
-      inputLocs.push_back(livein.getLoc());
-    }
-
-    rewriter.setInsertionPoint(dispatch);
-    auto schedule =
-        rewriter.create<ScheduleOp>(rewriter.getUnknownLoc(), inputs);
-    auto scheduleBlock = rewriter.createBlock(&schedule.getBody());
-
-    auto inputArgs = scheduleBlock->addArguments(ValueRange(inputs), inputLocs);
-    for (auto t : llvm::zip(inputs, inputArgs))
-      std::get<0>(t).replaceUsesWithIf(std::get<1>(t), isInDispatch);
-
-    auto &scheduleOps = scheduleBlock->getOperations();
-    auto &dispatchOps = dispatch.getBody().front().getOperations();
-    scheduleOps.splice(scheduleOps.begin(), dispatchOps, dispatchOps.begin(),
-                       std::prev(dispatchOps.end()));
-
-    rewriter.eraseOp(dispatch);
-    return success();
-  }
-};
-} // namespace
-
-namespace {
-struct ConvertTaskToNode : public OpRewritePattern<TaskOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TaskOp task,
-                                PatternRewriter &rewriter) const override {
-    if (task.getNumResults())
-      return task.emitOpError("should not yield any results");
-
-    auto isInTask = [&](OpOperand &use) {
-      return task->isAncestor(use.getOwner());
-    };
-
-    SmallVector<Value> inputs;
-    SmallVector<Location> inputLocs;
-    SmallVector<Value> outputs;
-    SmallVector<Location> outputLocs;
-    SmallVector<Value> params;
-    SmallVector<Location> paramLocs;
-
-    auto liveins = Liveness(task).getLiveIn(&task.getBody().front());
-    for (auto livein : liveins) {
-      if (task.getBody().isAncestor(livein.getParentRegion()))
-        continue;
-
-      if (livein.getType().isa<MemRefType, StreamType>()) {
-        auto uses = llvm::make_filter_range(livein.getUses(), isInTask);
-        if (llvm::any_of(uses, [](OpOperand &use) { return isWritten(use); })) {
-          outputs.push_back(livein);
-          outputLocs.push_back(livein.getLoc());
-        } else {
-          inputs.push_back(livein);
-          inputLocs.push_back(livein.getLoc());
-        }
-      } else {
-        params.push_back(livein);
-        paramLocs.push_back(livein.getLoc());
-      }
-    }
-
-    rewriter.setInsertionPoint(task);
-    auto node = rewriter.create<NodeOp>(rewriter.getUnknownLoc(), inputs,
-                                        outputs, params);
-    node->setAttrs(task->getAttrs());
-    auto nodeBlock = rewriter.createBlock(&node.getBody());
-
-    auto inputArgs = nodeBlock->addArguments(ValueRange(inputs), inputLocs);
-    for (auto t : llvm::zip(inputs, inputArgs))
-      std::get<0>(t).replaceUsesWithIf(std::get<1>(t), isInTask);
-
-    auto outputArgs =
-        node.getBody().addArguments(ValueRange(outputs), outputLocs);
-    for (auto t : llvm::zip(outputs, outputArgs))
-      std::get<0>(t).replaceUsesWithIf(std::get<1>(t), isInTask);
-
-    auto paramArgs =
-        nodeBlock->addArguments(ValueRange(params), paramLocs);
-    for (auto t : llvm::zip(params, paramArgs))
-      std::get<0>(t).replaceUsesWithIf(std::get<1>(t), isInTask);
-
-    auto &nodeOps = nodeBlock->getOperations();
-    auto &taskOps = task.getBody().front().getOperations();
-    nodeOps.splice(nodeOps.begin(), taskOps, taskOps.begin(),
-                   std::prev(taskOps.end()));
-
-    rewriter.eraseOp(task);
-    return success();
-  }
-};
-} // namespace
-
-namespace {
-struct ConvertConstantToConstBuffer
-    : public OpRewritePattern<bufferization::ToMemrefOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(bufferization::ToMemrefOp op,
-                                PatternRewriter &rewriter) const override {
-    if (auto constant = op.getTensor().getDefiningOp<arith::ConstantOp>()) {
-      rewriter.replaceOpWithNewOp<ConstBufferOp>(
-          op, op.getType(), constant.getValue().cast<ElementsAttr>());
-      return success();
-    }
-    return failure();
-  }
-};
-} // namespace
-
-namespace {
-struct LowerDataflow : public LowerDataflowBase<LowerDataflow> {
-  void runOnOperation() override {
-    auto func = getOperation();
-    auto context = func.getContext();
-    auto builder = OpBuilder(context);
-
-    // Collect all constants in the function and localize them to uses.
-    SmallVector<Operation *> constants;
-    func.walk([&](arith::ConstantOp op) { constants.push_back(op); });
-    for (auto constant : constants) {
-      for (auto &use : llvm::make_early_inc_range(constant->getUses())) {
-        builder.setInsertionPoint(use.getOwner());
-        auto cloneConstant = cast<arith::ConstantOp>(builder.clone(*constant));
-        use.set(cloneConstant.getResult());
-      }
-      constant->erase();
-    }
-
-    // Convert dispatch, task, and to_memref operations.
-    ConversionTarget target(*context);
-    target.addIllegalOp<DispatchOp, TaskOp>();
-    target.addLegalOp<ScheduleOp, NodeOp, ConstBufferOp>();
-
-    mlir::RewritePatternSet patterns(context);
-    patterns.add<ConvertDispatchToSchedule>(context);
-    patterns.add<ConvertTaskToNode>(context);
-    // patterns.add<ConvertConstantToConstBuffer>(context);
-    if (failed(applyPartialConversion(func, target, std::move(patterns))))
-      return signalPassFailure();
-  }
-};
-} // namespace
-
-std::unique_ptr<Pass> scalehls::hls::createLowerDataflowPass() {
-  return std::make_unique<LowerDataflow>();
-}
diff --git a/lib/Dialect/HLS/Transforms/ScalarizeStream.cpp b/lib/Dialect/HLS/Transforms/ScalarizeStream.cpp
index 1c688cc0..b815d90e 100644
--- a/lib/Dialect/HLS/Transforms/ScalarizeStream.cpp
+++ b/lib/Dialect/HLS/Transforms/ScalarizeStream.cpp
@@ -176,7 +176,7 @@ struct ScalarizeStreamReassociateOp
 namespace {
 template <typename OpTy>
-struct ScalarizeDispatchOrTaskOp : public OpRewritePattern<OpTy> {
+struct ScalarizeScheduleOrTaskOp : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(OpTy op,
@@ -222,8 +222,8 @@ struct ScalarizeStream : public ScalarizeStreamBase<ScalarizeStream> {
     patterns.add(context);
     patterns.add(context);
     patterns.add(context);
-    patterns.add<ScalarizeDispatchOrTaskOp<DispatchOp>>(context);
-    patterns.add<ScalarizeDispatchOrTaskOp<TaskOp>>(context);
+    patterns.add<ScalarizeScheduleOrTaskOp<ScheduleOp>>(context);
+    patterns.add<ScalarizeScheduleOrTaskOp<TaskOp>>(context);
     (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
   }
 };
diff --git a/lib/Dialect/HLS/Transforms/CreateDataflow.cpp b/lib/Dialect/HLS/Transforms/ScheduleDataflow.cpp
similarity index 72%
rename from lib/Dialect/HLS/Transforms/CreateDataflow.cpp
rename to lib/Dialect/HLS/Transforms/ScheduleDataflow.cpp
index 2a8e5396..ec57a2cc 100644
--- a/lib/Dialect/HLS/Transforms/CreateDataflow.cpp
+++ b/lib/Dialect/HLS/Transforms/ScheduleDataflow.cpp
@@ -15,13 +15,13 @@ using namespace scalehls;
 using namespace hls;
 
 namespace {
-struct DispatchFuncOp : public OpRewritePattern<func::FuncOp> {
+struct ScheduleFuncOp : public OpRewritePattern<func::FuncOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(func::FuncOp func,
                                 PatternRewriter &rewriter) const override {
-    auto funcDispatch = dispatchBlock(func.getName(), &func.front(), rewriter);
-    if (!funcDispatch)
+    auto funcSchedule = scheduleBlock(func.getName(), &func.front(), rewriter);
+    if (!funcSchedule)
       return failure();
 
     unsigned loopId;
@@ -30,7 +30,7 @@ struct DispatchFuncOp : public OpRewritePattern<func::FuncOp> {
         std::string name =
             func.getName().str() + "_loop" + std::to_string(loopId++);
         auto loopBody = &op->getRegion(0).getBlocks().front();
-        dispatchBlock(name, loopBody, rewriter);
+        scheduleBlock(name, loopBody, rewriter);
       }
     });
     return success();
@@ -39,19 +39,19 @@ struct DispatchFuncOp : public OpRewritePattern<func::FuncOp> {
 } // namespace
 
 namespace {
-struct CreateDataflow : public CreateDataflowBase<CreateDataflow> {
+struct ScheduleDataflow : public ScheduleDataflowBase<ScheduleDataflow> {
   void runOnOperation() override {
     auto func = getOperation();
     auto context = func.getContext();
 
-    // Dispatch the current function to create the dataflow hierarchy.
+    // Schedule the current function to create the dataflow hierarchy.
     mlir::RewritePatternSet patterns(context);
-    patterns.add<DispatchFuncOp>(context);
+    patterns.add<ScheduleFuncOp>(context);
     (void)applyOpPatternsAndFold({func}, std::move(patterns));
   }
 };
 } // namespace
 
-std::unique_ptr<Pass> scalehls::hls::createCreateDataflowPass() {
-  return std::make_unique<CreateDataflow>();
+std::unique_ptr<Pass> scalehls::hls::createScheduleDataflowPass() {
+  return std::make_unique<ScheduleDataflow>();
 }
diff --git a/lib/Pipelines/Pipelines.cpp b/lib/Pipelines/Pipelines.cpp
index e53c4c2e..7f086854 100644
--- a/lib/Pipelines/Pipelines.cpp
+++ b/lib/Pipelines/Pipelines.cpp
@@ -73,9 +73,7 @@ void scalehls::registerScaleHLSPyTorchPipeline() {
       [](OpPassManager &pm, const ScaleHLSPyTorchPipelineOptions &opts) {
         addLinalgTransformPasses(pm);
         addComprehensiveBufferizePasses(pm);
-        pm.addNestedPass<func::FuncOp>(hls::createCreateDataflowPass());
-        pm.addPass(mlir::createCanonicalizerPass());
-        pm.addNestedPass<func::FuncOp>(hls::createLowerDataflowPass());
+        pm.addNestedPass<func::FuncOp>(hls::createScheduleDataflowPass());
         pm.addPass(mlir::createCanonicalizerPass());
         addConvertDataflowToFuncPasses(pm);
       });