diff --git a/include/scalehls-c/Dialect/HLS/HLS.h b/include/scalehls-c/Dialect/HLS/HLS.h index 888ad7c9..70c9a39a 100644 --- a/include/scalehls-c/Dialect/HLS/HLS.h +++ b/include/scalehls-c/Dialect/HLS/HLS.h @@ -28,6 +28,10 @@ mlirSemanticsInitializeBlockArguments(MlirOperation semantics, // HLS Dialect Types //===----------------------------------------------------------------------===// +MLIR_CAPI_EXPORTED bool mlirTypeIsHLSStructType(MlirType type); +MLIR_CAPI_EXPORTED MlirType mlirHLSStructTypeGet(MlirStringRef name, + MlirContext ctx); + MLIR_CAPI_EXPORTED bool mlirTypeIsHLSTypeType(MlirType type); MLIR_CAPI_EXPORTED MlirType mlirHLSTypeTypeGet(MlirContext ctx); diff --git a/include/scalehls/Dialect/HLS/IR/HLS.h b/include/scalehls/Dialect/HLS/IR/HLS.h index 8be33798..655431f8 100644 --- a/include/scalehls/Dialect/HLS/IR/HLS.h +++ b/include/scalehls/Dialect/HLS/IR/HLS.h @@ -11,6 +11,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" @@ -38,6 +39,37 @@ struct Push : public MemoryEffects::Effect::Base {}; struct Pop : public MemoryEffects::Effect::Base {}; } // namespace StreamEffects +/// Printer hook for custom directive in assemblyFormat. +/// +/// custom($templates, $staticTemplates) +/// +/// where `template` is of ODS type `Variadic` and `staticTemplates` +/// is of ODS type `ArrayAttr`. Prints a list with either (1) the static +/// attribute value in `staticTemplates` is `dynVal` or (2) the next value +/// otherwise. This allows idiomatic printing of mixed value and attributes in a +/// list. E.g. `<%arg0, 7, f32, %arg42>`. +void printDynamicTemplateList(OpAsmPrinter &printer, Operation *op, + OperandRange templates, + ArrayAttr staticTemplates); + +/// Pasrer hook for custom directive in assemblyFormat. 
+/// +/// custom($templates, $staticTemplates) +/// +/// where `templates` is of ODS type `Variadic` and `staticTemplates` +/// is of ODS type `ArrayAttr`. Parse a mixed list with either (1) static +/// templates or (2) SSA templates. Fill `staticTemplates` with the ArrayAttr, +/// where `dynVal` encodes the position of SSA templates. Add the parsed SSA +/// templates to `templates` in-order. +// +/// E.g. after parsing "<%arg0, 7, f32, %arg42>": +/// 1. `result` is filled with the ArrayAttr "[`dynVal`, 7, f32, `dynVal`]" +/// 2. `ssa` is filled with "[%arg0, %arg42]". +ParseResult parseDynamicTemplateList( + OpAsmParser &parser, + SmallVectorImpl &templates, + ArrayAttr &staticTemplates); + } // namespace hls } // namespace scalehls } // namespace mlir diff --git a/include/scalehls/Dialect/HLS/IR/HLSAttributes.td b/include/scalehls/Dialect/HLS/IR/HLSAttributes.td index 41887afc..6124d100 100644 --- a/include/scalehls/Dialect/HLS/IR/HLSAttributes.td +++ b/include/scalehls/Dialect/HLS/IR/HLSAttributes.td @@ -15,6 +15,10 @@ include "mlir/IR/EnumAttr.td" class HLSAttr traits = []> : AttrDef; +def IndexArrayAttr : TypedArrayAttrBase { + let constBuilderCall = "$_builder.getIndexArrayAttr($0)"; +} + //===----------------------------------------------------------------------===// // DSE Attributes //===----------------------------------------------------------------------===// @@ -45,7 +49,7 @@ def TaskImplAttr : HLSAttr<"TaskImpl"> { let parameters = (ins IPLibraryParam:$library, IPNameParam:$name); let mnemonic = "impl"; - let assemblyFormat = "`<` $library `_` $name `>`"; + let assemblyFormat = "`<` $library `:` $name `>`"; let extraClassDeclaration = [{ bool isDefaultImpl() { diff --git a/include/scalehls/Dialect/HLS/IR/HLSInterfaces.td b/include/scalehls/Dialect/HLS/IR/HLSInterfaces.td index c9e9d5ad..2a9be351 100644 --- a/include/scalehls/Dialect/HLS/IR/HLSInterfaces.td +++ b/include/scalehls/Dialect/HLS/IR/HLSInterfaces.td @@ -9,6 +9,25 @@ include 
"mlir/IR/OpBase.td" +def TemplatedOpInterface : OpInterface<"TemplatedOpInterface"> { + let methods = [ + InterfaceMethod<"Return the composed templates", + "mlir::SmallVector", "getComposedTemplates", (ins), [{ + SmallVector composedTemplates; + unsigned dynIdx = 0; + for (auto attr : $_op.getStaticTemplates()) { + if (auto intAttr = attr.template dyn_cast()) + if (intAttr.getInt() == ShapedType::kDynamic) { + composedTemplates.push_back($_op.getTemplates()[dynIdx++]); + continue; + } + composedTemplates.push_back(attr); + } + return composedTemplates; + }]> + ]; +} + def ParamLikeInterface : OpInterface<"ParamLikeInterface"> { let methods = [ InterfaceMethod<"Return the value of the parameter if it exists", diff --git a/include/scalehls/Dialect/HLS/IR/HLSTypes.td b/include/scalehls/Dialect/HLS/IR/HLSTypes.td index b66e6c3a..dc25114a 100644 --- a/include/scalehls/Dialect/HLS/IR/HLSTypes.td +++ b/include/scalehls/Dialect/HLS/IR/HLSTypes.td @@ -12,11 +12,40 @@ include "scalehls/Dialect/HLS/IR/HLSAttributes.td" class HLSType traits = []> : TypeDef; +def StreamType : HLSType<"Stream"> { + let summary = "An HLS stream type"; + let description = [{ + Represents a stream of any type that can be transfered between HLS modules. + This type is equal to the hls::stream<> type in Xilinx Vivado HLS. 
+ }]; + let mnemonic = "stream"; + + let parameters = (ins "mlir::Type":$elementType, "unsigned":$depth); + let assemblyFormat = "`<` qualified($elementType) `,` $depth `>`"; + + let extraClassDeclaration = [{ + static StreamType get(mlir::Type elementType, unsigned depth) { + return get(elementType.getContext(), elementType, depth); + } + static StreamType get(mlir::Type elementType) { + return get(elementType, 1); + } + }]; +} + def SpaceType : HLSType<"Space"> { let summary = "Represent a design space containing multiple parameters"; let mnemonic = "space"; } +def StructType : HLSType<"Struct"> { + let summary = "Represent a struct type"; + let mnemonic = "struct"; + + let parameters = (ins StringRefParameter<>:$name); + let assemblyFormat = "`<` $name `>`"; +} + def TypeType : HLSType<"Type"> { let summary = "Used to represent a type"; let mnemonic = "type"; @@ -37,16 +66,4 @@ def MemoryKindType : HLSType<"MemoryKind"> { let mnemonic = "memory"; } -def StreamType : HLSType<"Stream"> { - let summary = "An HLS stream type"; - let description = [{ - Represents a stream of any type that can be transfered between HLS modules. - This type is equal to the hls::stream<> type in Xilinx Vivado HLS. 
- }]; - let mnemonic = "stream"; - - let parameters = (ins "mlir::Type":$elementType, "unsigned":$depth); - let assemblyFormat = "`<` qualified($elementType) `,` $depth `>`"; -} - #endif // SCALEHLS_DIALECT_HLS_HLSTYPES_TD diff --git a/include/scalehls/Dialect/HLS/IR/HLSUBUFOps.td b/include/scalehls/Dialect/HLS/IR/HLSUBUFOps.td index 3c75d036..65c44df5 100644 --- a/include/scalehls/Dialect/HLS/IR/HLSUBUFOps.td +++ b/include/scalehls/Dialect/HLS/IR/HLSUBUFOps.td @@ -11,6 +11,36 @@ // Unified Buffer (UBUF) Operations //===----------------------------------------------------------------------===// +def TensorToStreamOp : HLSOp<"ubuf.tensor_to_stream",[ + AttrSizedOperandSegments, NoMemoryEffect]> { + let summary = "Convert a tensor to a stream channel"; + + let arguments = (ins AnyTensor:$tensor, Variadic:$dims, + Variadic:$symbols, MemRefLayoutAttrInterface:$stream_layout, + MemRefLayoutAttrInterface:$memory_layout); + let results = (outs StreamOf<[AnyType]>:$stream); + let assemblyFormat = [{ + $tensor (`[` $dims^ `]`)? (`(` $symbols^ `)`)? `stream_layout` + $stream_layout `memory_layout` $memory_layout attr-dict `:` + functional-type($tensor, $stream) + }]; +} + +def StreamToTensorOp : HLSOp<"ubuf.stream_to_tensor", [ + AttrSizedOperandSegments, NoMemoryEffect]> { + let summary = "Convert a stream channel to a tensor"; + + let arguments = (ins StreamOf<[AnyType]>:$stream, Variadic:$dims, + Variadic:$symbols, MemRefLayoutAttrInterface:$stream_layout, + MemRefLayoutAttrInterface:$memory_layout); + let results = (outs AnyTensor:$tensor); + let assemblyFormat = [{ + $stream (`[` $dims^ `]`)? (`(` $symbols^ `)`)? 
`stream_layout` + $stream_layout `memory_layout` $memory_layout attr-dict `:` + functional-type($stream, $tensor) + }]; +} + def BufferOp : HLSOp<"ubuf.buffer", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { diff --git a/include/scalehls/Dialect/HLS/IR/HLSUIPOps.td b/include/scalehls/Dialect/HLS/IR/HLSUIPOps.td index e11f1fec..0c7e585c 100644 --- a/include/scalehls/Dialect/HLS/IR/HLSUIPOps.td +++ b/include/scalehls/Dialect/HLS/IR/HLSUIPOps.td @@ -39,19 +39,27 @@ def DeclareOp : HLSOp<"uip.declare", [IsolatedFromAbove, SymbolTable, Symbol, }]; } -def InstanceOp : HLSOp<"uip.instance", [ - DeclareOpInterfaceMethods]> { +def InstanceOp : HLSOp<"uip.instance", [AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let summary = "Instantiate an IP declared by DeclareOp"; - let arguments = (ins Variadic:$ports, ArrayAttr:$templates, - SymbolRefAttr:$name); + let arguments = (ins Variadic:$ports, Variadic:$templates, + ArrayAttr:$static_templates, SymbolRefAttr:$name); let results = (outs Variadic:$results); let assemblyFormat = [{ - $name `<` $templates `>` `(` $ports `)` attr-dict `:` + $name custom($templates, $static_templates) `(` $ports + `)` attr-dict `:` (`<` type($templates)^ `>`)? functional-type($ports, $results) }]; + let builders = [ + OpBuilder<(ins "mlir::TypeRange":$results, "mlir::ValueRange":$ports, + "mlir::ArrayRef":$composedTemplates, + "mlir::SymbolRefAttr":$name)> + ]; + let extraClassDeclaration = [{ /// Get the type of operand: input, output, or param. 
PortKind getPortKind(OpOperand &operand); @@ -65,17 +73,27 @@ def InstanceOp : HLSOp<"uip.instance", [ }]; } -def PortOp : HLSOp<"uip.port", [Symbol, HasParent<"DeclareOp">]> { +def PortOp : HLSOp<"uip.port", [Symbol, AttrSizedOperandSegments, + HasParent<"DeclareOp">]> { let summary = "Declare a port of an IP"; - let arguments = (ins TypeType:$type, Variadic:$sizes, - MemRefLayoutAttrInterface:$layout, PortKindAttr:$kind, + let arguments = (ins TypeType:$type, Variadic:$dims, + Variadic:$symbols, PortKindAttr:$kind, + OptionalAttr:$stream_layout, + MemRefLayoutAttrInterface:$memory_layout, OptionalAttr:$value, SymbolNameAttr:$sym_name); let results = (outs PortType:$result); let assemblyFormat = [{ - $sym_name $kind `type` $type `sizes` `(` $sizes `)` $layout attr-dict - `:` functional-type($sizes, $result) + $sym_name $kind `type` $type (`[` $dims^ `]`)? (`(` $symbols^ `)`)? + (`stream_layout` $stream_layout^)? `memory_layout` $memory_layout attr-dict + `:` (`[` type($dims)^ `]`)? functional-type($symbols, $result) + }]; + + let extraClassDeclaration = [{ + bool isStream() { + return getStreamLayout().has_value(); + } }]; } @@ -86,8 +104,40 @@ def IncludeOp : HLSOp<"uip.include", [HasParent<"DeclareOp">]> { let assemblyFormat = "$paths attr-dict"; } -def IndexArrayAttr : TypedArrayAttrBase { - let constBuilderCall = "$_builder.getIndexArrayAttr($0)"; +def StructOp : HLSOp<"uip.struct", [Symbol, AttrSizedOperandSegments, + HasParent<"SpaceOp, DeclareOp">]> { + let summary = "Declare a struct containing multiple parameters"; + + let arguments = (ins Variadic:$params, Variadic:$templates, + SymbolNameAttr:$sym_name); + let results = (outs StructType:$result); + + let assemblyFormat = [{ + $sym_name (`<` $templates^ `>`)? `(` $params `)` attr-dict `:` + (`<` type($templates)^ `>`)? 
functional-type($params, $result) +  }]; +} + +def StructInstanceOp : HLSOp<"uip.struct_instance", [AttrSizedOperandSegments, +    DeclareOpInterfaceMethods, +    DeclareOpInterfaceMethods]> { +  let summary = "Instantiate a struct declared by StructOp"; + +  let arguments = (ins Variadic:$params, Variadic:$templates, +    ArrayAttr:$static_templates, SymbolRefAttr:$name); +  let results = (outs StructType:$result); + +  let assemblyFormat = [{ +    $name custom($templates, $static_templates) `(` $params +    `)` attr-dict `:` (`<` type($templates)^ `>`)? +    functional-type($params, $result) +  }]; + +  let builders = [ +    OpBuilder<(ins "mlir::Type":$result, "mlir::ValueRange":$params, +      "mlir::ArrayRef":$composedTemplates, +      "mlir::SymbolRefAttr":$name)> +  ]; } def SemanticsOp : HLSOp<"uip.semantics", [Terminator, IsolatedFromAbove, @@ -120,6 +170,11 @@ def SemanticsOp : HLSOp<"uip.semantics", [Terminator, IsolatedFromAbove, PortKind getPortKind(OpOperand &operand); + PortKind getPortKind(unsigned operandIdx); + +    /// The template of an IP could be recursively a struct type. This method +    /// can recursively peel off all the structs and return the real templates, +    /// which are guaranteed to be ParamOp. 
+ SmallVector getStructPeeledTemplates(); }]; } diff --git a/include/scalehls/Utils/Matchers.h b/include/scalehls/Dialect/HLS/Utils/Matchers.h similarity index 96% rename from include/scalehls/Utils/Matchers.h rename to include/scalehls/Dialect/HLS/Utils/Matchers.h index a85f4a6f..85f373c8 100644 --- a/include/scalehls/Utils/Matchers.h +++ b/include/scalehls/Dialect/HLS/Utils/Matchers.h @@ -4,8 +4,8 @@ // //===----------------------------------------------------------------------===// -#ifndef SCALEHLS_UTILS_MATCHERS_H -#define SCALEHLS_UTILS_MATCHERS_H +#ifndef SCALEHLS_DIALECT_HLS_UTILS_MATCHERS_H +#define SCALEHLS_DIALECT_HLS_UTILS_MATCHERS_H #include "scalehls/Dialect/HLS/IR/HLS.h" #include "scalehls/Utils/Utils.h" @@ -15,8 +15,7 @@ namespace mlir { namespace scalehls { - -using namespace hls; +namespace hls { //===----------------------------------------------------------------------===// // BlockMatcher @@ -257,7 +256,7 @@ struct Port : OpFoldResult { struct IPMatchingResult { const SmallVector instPorts; - const SmallVector instTemplates; + const SmallVector instStructPeeledTemplates; const LinalgMatchingResult result; unsigned mapIpResIndexToPayload(unsigned ipResIndex) const { @@ -268,9 +267,10 @@ struct IPMatchingResult { } IPMatchingResult(const SmallVector &instPorts, - const SmallVector &instTemplates, + const SmallVector &instStructPeeledTemplates, const LinalgMatchingResult &result) - : instPorts(instPorts), instTemplates(instTemplates), result(result) {} + : instPorts(instPorts), + instStructPeeledTemplates(instStructPeeledTemplates), result(result) {} }; struct IPMatchingStatus { @@ -369,7 +369,7 @@ struct IPMatcher { unsigned maxIterations = 3) : payload(payload), declare(declare), maxIterations(maxIterations), status(declare.getSemanticsOp().getPorts(), - declare.getSemanticsOp().getTemplates()) {} + declare.getSemanticsOp().getStructPeeledTemplates()) {} FailureOr match(); @@ -380,7 +380,8 @@ struct IPMatcher { IPMatchingStatus status; }; +} // 
namespace hls } // namespace scalehls } // namespace mlir -#endif // SCALEHLS_UTILS_MATCHERS_H +#endif // SCALEHLS_DIALECT_HLS_UTILS_MATCHERS_H diff --git a/include/scalehls/Utils/Visitor.h b/include/scalehls/Utils/Visitor.h index 3750ef1d..a2d3a305 100644 --- a/include/scalehls/Utils/Visitor.h +++ b/include/scalehls/Utils/Visitor.h @@ -22,11 +22,9 @@ class HLSVisitorBase { return TypeSwitch(op) .template Case< - // HLS Library Ip operation. - hls::InstanceOp, - // HLS dialect operations. - hls::BufferOp, hls::ConstBufferOp, hls::StreamOp, hls::StreamReadOp, + hls::InstanceOp, hls::StructInstanceOp, hls::BufferOp, + hls::ConstBufferOp, hls::StreamOp, hls::StreamReadOp, hls::StreamWriteOp, hls::AffineSelectOp, // Function operations. @@ -98,9 +96,10 @@ class HLSVisitorBase { ResultType visitOp(OPTYPE op, ExtraArgs... args) { \ return static_cast(this)->visitUnhandledOp(op, args...); \ } - // HLS Library Ip operation. - HANDLE(hls::InstanceOp); + // HLS dialect operations. + HANDLE(hls::InstanceOp); + HANDLE(hls::StructInstanceOp); HANDLE(hls::BufferOp); HANDLE(hls::ConstBufferOp); HANDLE(hls::StreamOp); diff --git a/lib/Bindings/Python/HLSDialect.cpp b/lib/Bindings/Python/HLSDialect.cpp index 07091554..50716f95 100644 --- a/lib/Bindings/Python/HLSDialect.cpp +++ b/lib/Bindings/Python/HLSDialect.cpp @@ -67,6 +67,17 @@ void populateHLSAttributes(py::module &m) { //===----------------------------------------------------------------------===// void populateHLSTypes(py::module &m) { + auto StructType = + mlir_type_subclass(m, "StructType", mlirTypeIsHLSStructType); + StructType.def_classmethod( + "get", + [](py::object cls, std::string name, MlirContext ctx) { + return cls(mlirHLSStructTypeGet( + mlirStringRefCreateFromCString(name.c_str()), ctx)); + }, + "Get an instance of StructType in given context.", py::arg("cls"), + py::arg("name"), py::arg("context") = py::none()); + auto typeType = mlir_type_subclass(m, "TypeType", mlirTypeIsHLSTypeType); 
typeType.def_classmethod( "get", diff --git a/lib/CAPI/Dialect/HLS/HLS.cpp b/lib/CAPI/Dialect/HLS/HLS.cpp index 23f41f56..0ffdf7dc 100644 --- a/lib/CAPI/Dialect/HLS/HLS.cpp +++ b/lib/CAPI/Dialect/HLS/HLS.cpp @@ -31,6 +31,13 @@ void mlirSemanticsInitializeBlockArguments( // HLS Dialect Types //===----------------------------------------------------------------------===// +bool mlirTypeIsHLSStructType(MlirType type) { + return unwrap(type).isa(); +} +MlirType mlirHLSStructTypeGet(MlirStringRef name, MlirContext ctx) { + return wrap(hls::StructType::get(unwrap(ctx), unwrap(name))); +} + bool mlirTypeIsHLSTypeType(MlirType type) { return unwrap(type).isa(); } diff --git a/lib/Dialect/HLS/IR/HLS.cpp b/lib/Dialect/HLS/IR/HLS.cpp index d5f17386..90178ceb 100644 --- a/lib/Dialect/HLS/IR/HLS.cpp +++ b/lib/Dialect/HLS/IR/HLS.cpp @@ -36,6 +36,83 @@ void HLSDialect::initialize() { >(); } +//===----------------------------------------------------------------------===// +// Custom Printer/Parser Hooks +//===----------------------------------------------------------------------===// + +/// Printer hook for custom directive in assemblyFormat. +/// +/// custom($templates, $staticTemplates) +/// +/// where `template` is of ODS type `Variadic` and `staticTemplates` +/// is of ODS type `ArrayAttr`. Prints a list with either (1) the static +/// attribute value in `staticTemplates` is `dynVal` or (2) the next value +/// otherwise. This allows idiomatic printing of mixed value and attributes in a +/// list. E.g. `<%arg0, 7, f32, %arg42>`. 
+void hls::printDynamicTemplateList(OpAsmPrinter &printer, Operation *op, + OperandRange templates, + ArrayAttr staticTemplates) { + char leftDelimiter = '<'; + char rightDelimiter = '>'; + printer << leftDelimiter; + if (staticTemplates.empty()) { + printer << rightDelimiter; + return; + } + unsigned idx = 0; + llvm::interleaveComma(staticTemplates, printer, [&](Attribute attr) { + if (auto integerAttr = attr.dyn_cast()) + if (ShapedType::isDynamic(integerAttr.getInt())) { + printer << templates[idx++]; + return; + } + printer << attr; + }); + printer << rightDelimiter; +} + +/// Parser hook for custom directive in assemblyFormat. +/// +/// custom($templates, $staticTemplates) +/// +/// where `templates` is of ODS type `Variadic` and `staticTemplates` +/// is of ODS type `ArrayAttr`. Parse a mixed list with either (1) static +/// templates or (2) SSA templates. Fill `staticTemplates` with the ArrayAttr, +/// where `dynVal` encodes the position of SSA templates. Add the parsed SSA +/// templates to `templates` in-order. +// +/// E.g. after parsing "<%arg0, 7, f32, %arg42>": +/// 1. `result` is filled with the ArrayAttr "[`dynVal`, 7, f32, `dynVal`]" +/// 2. `ssa` is filled with "[%arg0, %arg42]". 
+ParseResult hls::parseDynamicTemplateList( + OpAsmParser &parser, + SmallVectorImpl &templates, + ArrayAttr &staticTemplates) { + auto builder = parser.getBuilder(); + SmallVector templateVals; + auto parseIntegerOrValue = [&]() { + OpAsmParser::UnresolvedOperand operand; + auto res = parser.parseOptionalOperand(operand); + if (res.has_value() && succeeded(res.value())) { + templates.push_back(operand); + templateVals.push_back(builder.getI64IntegerAttr(ShapedType::kDynamic)); + } else { + Attribute staticTemplate; + if (failed(parser.parseAttribute(staticTemplate))) + return failure(); + templateVals.push_back(staticTemplate); + } + return success(); + }; + if (parser.parseCommaSeparatedList(AsmParser::Delimiter::LessGreater, + parseIntegerOrValue, + " in dynamic template list")) + return parser.emitError(parser.getNameLoc()) + << "expected SSA value or integer"; + staticTemplates = builder.getArrayAttr(templateVals); + return success(); +} + //===----------------------------------------------------------------------===// // Affine SelectOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/HLS/IR/HLSUIPOps.cpp b/lib/Dialect/HLS/IR/HLSUIPOps.cpp index e3466c64..999a9e0f 100644 --- a/lib/Dialect/HLS/IR/HLSUIPOps.cpp +++ b/lib/Dialect/HLS/IR/HLSUIPOps.cpp @@ -25,6 +25,31 @@ SemanticsOp DeclareOp::getSemanticsOp() { // InstanceOp //===----------------------------------------------------------------------===// +static void dispatchTemplateOpFoldResults( + ArrayRef composedTemplates, SmallVectorImpl &templates, + SmallVectorImpl &staticTemplates, OpBuilder &builder) { + for (OpFoldResult ofr : composedTemplates) { + if (ofr.is()) { + staticTemplates.push_back(ofr.get()); + continue; + } + templates.push_back(ofr.get()); + staticTemplates.push_back(builder.getI64IntegerAttr(ShapedType::kDynamic)); + } +} + +void InstanceOp::build(OpBuilder &builder, OperationState &state, + TypeRange results, ValueRange ports, + 
ArrayRef composedTemplates, + SymbolRefAttr name) { + SmallVector templates; + SmallVector staticTemplates; + dispatchTemplateOpFoldResults(composedTemplates, templates, staticTemplates, + builder); + build(builder, state, results, ports, templates, + builder.getArrayAttr(staticTemplates), name); +} + LogicalResult InstanceOp::verifySymbolUses(mlir::SymbolTableCollection &table) { if (!getDeclareOp()) return (*this)->emitOpError("unknown IP name ") << getNameAttr(); @@ -57,6 +82,30 @@ DeclareOp InstanceOp::getDeclareOp() { (*this)->getParentOfType(), getNameAttr()); } +//===----------------------------------------------------------------------===// +// StructInstanceOp +//===----------------------------------------------------------------------===// + +void StructInstanceOp::build(OpBuilder &builder, OperationState &state, + Type result, ValueRange params, + ArrayRef composedTemplates, + SymbolRefAttr name) { + SmallVector templates; + SmallVector staticTemplates; + dispatchTemplateOpFoldResults(composedTemplates, templates, staticTemplates, + builder); + build(builder, state, result, params, templates, + builder.getArrayAttr(staticTemplates), name); +} + +LogicalResult +StructInstanceOp::verifySymbolUses(mlir::SymbolTableCollection &table) { + if (!table.lookupNearestSymbolFrom( + (*this)->getParentOfType(), getNameAttr())) + return (*this)->emitOpError("unknown struct name ") << getNameAttr(); + return success(); +} + //===----------------------------------------------------------------------===// // SemanticsOp //===----------------------------------------------------------------------===// @@ -82,11 +131,11 @@ void SemanticsOp::initializeBlockArguments( // TODO: Handle constant sized port. 
auto port = value.getDefiningOp(); assert(port && port.getKind() != PortKind::PARAM && "invalid port"); - if (port.getSizes().empty()) + if (port.getDims().empty()) argTypes.push_back(/*port.getType().getType()*/ builder.getF32Type()); else argTypes.push_back(RankedTensorType::get( - SmallVector(port.getSizes().size(), ShapedType::kDynamic), + SmallVector(port.getDims().size(), ShapedType::kDynamic), /*port.getType().getType()*/ builder.getF32Type(), nullptr)); argLocs.push_back(port.getLoc()); } @@ -136,6 +185,31 @@ SemanticsOutputOp SemanticsOp::getSemanticsOutputOp() { return cast(getBody().front().getTerminator()); } +/// The template of an IP could be recursively a struct type. This method +/// can recursively peel off all the structs and return the real templates, +/// which are gauranteed to be ParamOp. +SmallVector SemanticsOp::getStructPeeledTemplates() { + SmallVector temps; + SmallVector worklist; + for (auto temp : getTemplates()) { + if (auto structOp = temp.getDefiningOp()) + worklist.push_back(structOp); + else + temps.push_back(temp); + } + + while (!worklist.empty()) { + auto structOp = worklist.pop_back_val(); + for (auto temp : structOp.getParams()) { + if (auto structOp = temp.getDefiningOp()) + worklist.push_back(structOp); + else + temps.push_back(temp); + } + } + return temps; +} + //===----------------------------------------------------------------------===// // SemanticsOutputOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/HLS/Transforms/GenerateTaskDesignSpace.cpp b/lib/Dialect/HLS/Transforms/GenerateTaskDesignSpace.cpp index b1db9083..91b27ce7 100644 --- a/lib/Dialect/HLS/Transforms/GenerateTaskDesignSpace.cpp +++ b/lib/Dialect/HLS/Transforms/GenerateTaskDesignSpace.cpp @@ -9,8 +9,8 @@ #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "scalehls/Dialect/HLS/Transforms/Passes.h" +#include 
"scalehls/Dialect/HLS/Utils/Matchers.h" #include "scalehls/Dialect/HLS/Utils/Utils.h" -#include "scalehls/Utils/Matchers.h" using namespace mlir; using namespace scalehls; @@ -95,7 +95,7 @@ struct GenerateTaskDesignSpacePattern : public OpRewritePattern { continue; // Otherwise, match the two linalg ops with LinalgMatcher. - if (succeeded(LinalgMatcher(linalgOp, ipLinalgOp).match())) { + if (succeeded(IPMatcher(linalgOp, declare).match())) { implCandidates.push_back( TaskImplAttr::get(rewriter.getContext(), libraryName, ipName)); diff --git a/lib/Dialect/HLS/Transforms/ImplementTaskDesignSpace.cpp b/lib/Dialect/HLS/Transforms/ImplementTaskDesignSpace.cpp index dacd588f..6909d5b0 100644 --- a/lib/Dialect/HLS/Transforms/ImplementTaskDesignSpace.cpp +++ b/lib/Dialect/HLS/Transforms/ImplementTaskDesignSpace.cpp @@ -9,13 +9,240 @@ #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "scalehls/Dialect/HLS/Transforms/Passes.h" +#include "scalehls/Dialect/HLS/Utils/Matchers.h" #include "scalehls/Dialect/HLS/Utils/Utils.h" -#include "scalehls/Utils/Matchers.h" using namespace mlir; using namespace scalehls; using namespace hls; +static FailureOr getParamAttr(Value param, ParamKind paramKind) { + // If the param is not defined by ParamLikeInterface or doesn't align with the + // given paramKind, return nullptr. + auto paramOp = param.getDefiningOp(); + if (!paramOp || paramOp.getKind() != paramKind) + return failure(); + + // If the param hasn't been assigned a value, return nullptr. + if (!paramOp.getValue().has_value()) + return failure(); + return *paramOp.getValue(); +} + +static FailureOr getImplSpaceOp(SpaceOp parentSpace, + SymbolRefAttr &implSymbolRef) { + // For now, the implementation design space must be defined by SpaceSelectOp. 
+ auto implSpace = parentSpace.getSpacePackOp().getOperands().back(); + auto implSelect = implSpace.getDefiningOp(); + assert(implSelect && "invalid task implementation design space"); + + // Return failure if we cannot get the task implementation parameter. + auto implParamAttr = getParamAttr(implSelect.getArg(), ParamKind::TASK_IMPL); + if (failed(implParamAttr)) + return failure(); + + // Find the space corresponding to the selected task implementation. + for (auto [candidate, space] : + llvm::zip(implSelect.getConditions(), implSelect.getSpaces())) + if (candidate == *implParamAttr) { + implSpace = space; + break; + } + auto implSpaceOp = implSpace.getDefiningOp(); + assert(implSpaceOp && "invalid task implementation candidates"); + + // Get the symbol name of the IP declaration if applicable. + implSymbolRef = implParamAttr->cast().getSymbolRef(); + return implSpaceOp; +} + +static FailureOr +tileLinalgOp(linalg::LinalgOp linalgOp, ValueRange tileParams, + ParamKind paramKind, PatternRewriter &rewriter) { + // Extract the tile sizes from the given tile params. + SmallVector tileSizes; + for (auto tileParam : tileParams) { + auto tileParamAttr = getParamAttr(tileParam, paramKind); + if (failed(tileParamAttr)) + return failure(); + tileSizes.push_back(tileParamAttr->cast().getInt()); + } + + // Tile the linalg op with the collected tile sizes. + linalg::LinalgTilingOptions options; + options.setTileSizes(tileSizes); + auto tiledLinalgOp = linalg::tileLinalgOp(rewriter, linalgOp, options); + if (failed(tiledLinalgOp)) + return failure(); + + // Replace the original linalg op with the tiled one. + rewriter.replaceOp(linalgOp, tiledLinalgOp->tensorResults); + return tiledLinalgOp; +} + +// If the given pointer union is a value, return it. Otherwise, create a +// constant op and return it. 
+static Value getValueOrCreateConstant(OpFoldResult valueOrAttr, + PatternRewriter &rewriter) { + if (valueOrAttr.is()) + return valueOrAttr.get(); + auto attr = valueOrAttr.get(); + return rewriter.create(rewriter.getUnknownLoc(), + TypedAttr(attr)); +} + +static FailureOr +replaceLinalgOpWithInstanceOp(SpaceOp implSpaceOp, linalg::LinalgOp linalgOp, + SymbolRefAttr symbol, PatternRewriter &rewriter) { + auto ipDeclare = SymbolTable::lookupNearestSymbolFrom( + linalgOp->getParentOfType(), symbol); + assert(ipDeclare && "invalid IP declaration"); + + auto matchingResult = IPMatcher(linalgOp, ipDeclare).match(); + if (failed(matchingResult)) + return failure(); + + // Mapping from a value (port/template) in an IP declaration to a payload IR. + llvm::SmallDenseMap ipToInstMap; + + // Collect the instance ports and the result types of the IP instance. + rewriter.setInsertionPoint(linalgOp); + SmallVector instOutputTypes; + SmallVector instPorts; + unsigned outputIdx = 0; + for (auto [ipPort, instPortOrAttr] : llvm::zip( + ipDeclare.getSemanticsOp().getPorts(), matchingResult->instPorts)) { + Value instPort = getValueOrCreateConstant(instPortOrAttr, rewriter); + instPorts.push_back(instPort); + ipToInstMap[ipPort] = instPort; + + // If the port is an output, collect its type for later use. + auto portOp = ipPort.getDefiningOp(); + if (portOp.getKind() == PortKind::OUTPUT) { + auto linalgIdx = matchingResult->mapIpResIndexToPayload(outputIdx++); + auto outputType = + linalgOp->getResult(linalgIdx).getType().cast(); + assert(instPort.getType() == outputType && "invalid result type"); + + if (portOp.isStream()) + instOutputTypes.push_back(StreamType::get(outputType.getElementType())); + else + instOutputTypes.push_back(outputType); + } + } + + // Collect the instance template values. 
+ for (auto [ipTemplate, instTemplateAttr] : + llvm::zip(ipDeclare.getSemanticsOp().getStructPeeledTemplates(), + matchingResult->instStructPeeledTemplates)) { + OpFoldResult instTemplate = instTemplateAttr; + if (!instTemplateAttr) { + // Template value may have been inferred from the IP matching, e.g., the + // array shapes. However, if the template value cannot be inferred and has + // been explored by DSE, we need to extract its result here. + auto templateParam = implSpaceOp.getSpacePackOp().findOperand( + ipTemplate.getDefiningOp().getNameAttr()); + assert(templateParam && "invalid template parameter"); + + auto templateParamAttr = + getParamAttr(templateParam, ParamKind::IP_TEMPLATE); + if (failed(templateParamAttr)) + return failure(); + instTemplate = *templateParamAttr; + } + if (!ipTemplate.getType().isa()) + instTemplate = getValueOrCreateConstant(instTemplate, rewriter); + ipToInstMap[ipTemplate] = instTemplate; + } + + // Instantiate all structs of the IP in a topological order, which implicitly + // ensured the dependencies between them. + for (auto structOp : ipDeclare.getOps()) { + SmallVector structInstParams; + for (auto param : structOp.getParams()) { + assert(ipToInstMap.count(param) && "invalid struct parameter"); + structInstParams.push_back(ipToInstMap.lookup(param).get()); + } + + SmallVector structNestedSymbols( + symbol.getNestedReferences()); + structNestedSymbols.push_back( + FlatSymbolRefAttr::get(structOp.getNameAttr())); + + auto structSymbolRef = + SymbolRefAttr::get(symbol.getRootReference(), structNestedSymbols); + auto structInst = rewriter.create( + structOp.getLoc(), structOp.getType(), structInstParams, + SmallVector(), structSymbolRef); + ipToInstMap[structOp] = structInst.getResult(); + } + + // Finally, we can construct the list of templates that is going to be passed + // to the IP instance. 
+  SmallVector instTemplates;
+  for (auto ipTemplate : ipDeclare.getSemanticsOp().getTemplates()) {
+    assert(ipToInstMap.count(ipTemplate) && "invalid struct template");
+    instTemplates.push_back(ipToInstMap.lookup(ipTemplate));
+  }
+
+  // Now, we can create the instance op.
+  auto instance = rewriter.create(
+      linalgOp.getLoc(), instOutputTypes, instPorts, instTemplates, symbol);
+
+  // If the IP expects stream interface, we will convert tensor to stream for
+  // inputs and vice versa for outputs.
+  outputIdx = 0;
+  for (auto [ipPort, instPort] :
+       llvm::zip(ipDeclare.getSemanticsOp().getPorts(), instPorts)) {
+    auto portOp = ipPort.getDefiningOp();
+
+    // We use the ipToInstMap to get the corresponding dim/symbol values
+    // in the payload IR.
+    SmallVector instDims;
+    for (auto ipDim : portOp.getDims()) {
+      assert(ipToInstMap.count(ipDim) && "invalid dim");
+      instDims.push_back(ipToInstMap.lookup(ipDim).get());
+    }
+    SmallVector instSymbols;
+    for (auto ipSymbol : portOp.getSymbols()) {
+      assert(ipToInstMap.count(ipSymbol) && "invalid symbol");
+      instSymbols.push_back(ipToInstMap.lookup(ipSymbol).get());
+    }
+
+    // Replace the original input port with a stream port if applicable.
+    if (portOp.isStream()) {
+      rewriter.setInsertionPoint(instance);
+      auto streamType = StreamType::get(
+          instPort.getType().cast().getElementType());
+      auto stream = rewriter.create(
+          linalgOp.getLoc(), streamType, instPort, instDims, instSymbols,
+          portOp.getStreamLayoutAttr(), portOp.getMemoryLayoutAttr());
+      instPort.replaceUsesWithIf(stream.getResult(), [&](OpOperand &use) {
+        return use.getOwner() == instance;
+      });
+    }
+
+    // Convert stream output to tensors and replace all uses of the original
+    // linalg op results.
+ if (portOp.getKind() == PortKind::OUTPUT) { + auto tensor = instance.getResult(outputIdx); + + if (portOp.isStream()) { + rewriter.setInsertionPointAfter(instance); + tensor = rewriter.create( + linalgOp.getLoc(), instPort.getType(), tensor, instDims, + instSymbols, portOp.getStreamLayoutAttr(), + portOp.getMemoryLayoutAttr()); + } + auto linalgIdx = matchingResult->mapIpResIndexToPayload(outputIdx++); + linalgOp->getResult(linalgIdx).replaceAllUsesWith(tensor); + } + } + + rewriter.eraseOp(linalgOp); + return instance; +} + namespace { struct ImplementTaskDesignSpacePattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -33,153 +260,36 @@ struct ImplementTaskDesignSpacePattern : public OpRewritePattern { if (!space || !linalgOp) return op.removeSpaceAttr(), failure(); - // Collect tile sizes. All operands of the terminator SpacePackOp except the - // last one are the tile size parameters. - SmallVector tileSizes; - for (auto param : llvm::drop_end(space.getSpacePackOp().getOperands())) { - // The tile size parameter must be TILE_SIZE kind and have an index type. - auto paramOp = param.getDefiningOp(); - assert(paramOp && paramOp.getKind() == ParamKind::TILE_SIZE && - "invalid tile parameter"); - - // If the param hasn't been assigned a value, return failure. This pass is - // expected to run after design space exploration of each param. - if (!paramOp.getValue().has_value()) - return op.removeSpaceAttr(), failure(); - - // Get the tile size value store as an attribute of the ParamOp. - tileSizes.push_back(paramOp.getValue()->cast().getInt()); - } - - // Tile the linalg op with the collected tile sizes. - linalg::LinalgTilingOptions options; - options.setTileSizes(tileSizes); - auto tiledLinalgOp = linalg::tileLinalgOp(rewriter, linalgOp, options); + // First, tile the linalg op with the tile parameters. All operands of the + // terminator SpacePackOp except the last one should be tile parameters. 
+ auto tiledLinalgOp = tileLinalgOp( + linalgOp, llvm::drop_end(space.getSpacePackOp().getOperands()), + ParamKind::TILE_SIZE, rewriter); if (failed(tiledLinalgOp)) return op.removeSpaceAttr(), failure(); - - // Replace the original linalg op with the tiled one. - rewriter.replaceOp(linalgOp, tiledLinalgOp->tensorResults); linalgOp = tiledLinalgOp->op; - // Then, we start to handle the implementation design space, which must be - // defined by a SpaceSelectOp for now. - auto implSpace = space.getSpacePackOp().getOperands().back(); - auto implSelect = implSpace.getDefiningOp(); - assert(implSelect && "invalid task implementation design space"); - - // Similarly, the task implementation parameter must be TASK_IMPL kind and - // has a TaskImplType. - auto implParamOp = - implSelect.getArg().getDefiningOp(); - assert(implParamOp && implParamOp.getKind() == ParamKind::TASK_IMPL && - "invalid task implementation parameter"); - - // Again, if the param hasn't been assigned a value, return failure. - if (!implParamOp.getValue().has_value()) + // Then, we find the implementation design space of the task to handle the + // IP substitution or parallelization of the task. + SymbolRefAttr symbol; + auto implSpaceOp = getImplSpaceOp(space, symbol); + if (failed(implSpaceOp)) return op.removeSpaceAttr(), failure(); - // Find the space corresponding to the selected task implementation. - for (auto [candidate, space] : - llvm::zip(implSelect.getConditionsAttr(), implSelect.getSpaces())) - if (candidate == implParamOp.getValue()) { - implSpace = space; - break; - } - auto implSpaceOp = implSpace.getDefiningOp(); - assert(implSpaceOp && "invalid task implementation candidates"); - - if (auto symbol = - implParamOp.getValue()->cast().getSymbolRef()) { + if (symbol) { // If the task will be implemented with an IP, we substitute the original // linalg operation with an IP instance. 
- auto ipDeclare = SymbolTable::lookupNearestSymbolFrom( - op->getParentOfType(), symbol); - assert(ipDeclare && "invalid IP declaration"); - - auto matchingResult = IPMatcher(linalgOp, ipDeclare).match(); - if (failed(matchingResult)) + if (failed(replaceLinalgOpWithInstanceOp(*implSpaceOp, linalgOp, symbol, + rewriter))) return op.removeSpaceAttr(), failure(); - - // Collect the result types of the IP instance. - SmallVector instResultTypes; - for (unsigned i = 0; i < linalgOp.getNumDpsInits(); ++i) { - auto resIndex = matchingResult->mapIpResIndexToPayload(i); - instResultTypes.push_back(linalgOp->getResult(resIndex).getType()); - } - - // Collect the instance ports. - rewriter.setInsertionPoint(linalgOp); - SmallVector instPorts; - for (auto port : matchingResult->instPorts) { - if (port.is()) - instPorts.push_back(port.get()); - else if (auto portAttr = port.get()) - instPorts.push_back(rewriter.create( - linalgOp.getLoc(), TypedAttr(portAttr))); - } - - // Collect the instance templates. - SmallVector instTemplates; - for (auto [tempOperand, tempAttr] : - llvm::zip(ipDeclare.getSemanticsOp().getTemplates(), - matchingResult->instTemplates)) { - if (tempAttr) { - instTemplates.push_back(tempAttr); - continue; - } - auto tempParam = implSpaceOp.getSpacePackOp().findOperand( - tempOperand.getDefiningOp().getNameAttr()); - assert(tempParam && "invalid template parameter"); - - auto tempParamOp = tempParam.getDefiningOp(); - assert(tempParamOp && tempParamOp.getKind() == ParamKind::IP_TEMPLATE && - "invalid template parameter"); - - // Again, if the param hasn't been assigned a value, return failure. - if (!tempParamOp.getValue().has_value()) - return op.removeSpaceAttr(), failure(); - instTemplates.push_back(*tempParamOp.getValue()); - } - - // Finally, we can create the instance op. 
- auto instance = rewriter.create( - linalgOp.getLoc(), instResultTypes, instPorts, - rewriter.getArrayAttr(instTemplates), symbol); - - // Replace the original linalg op results with the instance op. - for (unsigned i = 0; i < instance.getNumResults(); ++i) { - auto resIndex = matchingResult->mapIpResIndexToPayload(i); - linalgOp->getResult(resIndex).replaceAllUsesWith(instance.getResult(i)); - } - rewriter.eraseOp(linalgOp); } else { - // Parallelize and use the default method. - SmallVector parallelParam; - for (auto param : implSpaceOp.getSpacePackOp().getArgs()) { - // The tile size parameter must be PARALLEL_SIZE kind and have an index - // type. - auto paramOp = param.getDefiningOp(); - - // Check if the params are valid - assert(paramOp.getKind() == ParamKind::PARALLEL_SIZE && - "invalid parallel parameter"); - if (!paramOp.getValue().has_value()) - return op.removeSpaceAttr(), failure(); - - // Get the parallel size value store as an attribute of the ParamOp. - parallelParam.push_back( - paramOp.getValue()->cast().getInt()); - } - - linalg::LinalgTilingOptions options; - options.setTileSizes(parallelParam); - auto parallelLinalgOp = linalg::tileLinalgOp(rewriter, linalgOp, options); + // Otherwise, use the default implementation that parallelize the linalg + // operation with the parallel parameters. + auto parallelLinalgOp = + tileLinalgOp(linalgOp, implSpaceOp->getSpacePackOp().getArgs(), + ParamKind::PARALLEL_SIZE, rewriter); if (failed(parallelLinalgOp)) return op.removeSpaceAttr(), failure(); - - // Replace the original linalg op with the parallel one. 
- rewriter.replaceOp(linalgOp, parallelLinalgOp->tensorResults); linalgOp = parallelLinalgOp->op; } diff --git a/lib/Dialect/HLS/Transforms/SimplifyDesignSpace.cpp b/lib/Dialect/HLS/Transforms/SimplifyDesignSpace.cpp index 04aeb6af..a2682850 100644 --- a/lib/Dialect/HLS/Transforms/SimplifyDesignSpace.cpp +++ b/lib/Dialect/HLS/Transforms/SimplifyDesignSpace.cpp @@ -4,13 +4,9 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/SCF/Transforms/Transforms.h" -#include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "scalehls/Dialect/HLS/Transforms/Passes.h" #include "scalehls/Dialect/HLS/Utils/Utils.h" -#include "scalehls/Utils/Matchers.h" using namespace mlir; using namespace scalehls; diff --git a/lib/Dialect/HLS/Utils/CMakeLists.txt b/lib/Dialect/HLS/Utils/CMakeLists.txt index 956affeb..06ede4c8 100644 --- a/lib/Dialect/HLS/Utils/CMakeLists.txt +++ b/lib/Dialect/HLS/Utils/CMakeLists.txt @@ -1,3 +1,4 @@ add_mlir_dialect_library(MLIRScaleHLSHLSUtils + Matchers.cpp Utils.cpp ) diff --git a/lib/Utils/Matchers.cpp b/lib/Dialect/HLS/Utils/Matchers.cpp similarity index 97% rename from lib/Utils/Matchers.cpp rename to lib/Dialect/HLS/Utils/Matchers.cpp index e1f91bcf..4dcb5cc3 100644 --- a/lib/Utils/Matchers.cpp +++ b/lib/Dialect/HLS/Utils/Matchers.cpp @@ -4,13 +4,14 @@ // //===----------------------------------------------------------------------===// -#include "scalehls/Utils/Matchers.h" +#include "scalehls/Dialect/HLS/Utils/Matchers.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "scalehls-matchers" using namespace mlir; using namespace scalehls; +using namespace hls; //===----------------------------------------------------------------------===// // BlockMatcher @@ -297,7 +298,7 @@ FailureOr IPMatcher::match() { // Infer the port sizes. 
for (auto [portSizeOperand, size] : - llvm::zip(portOp.getSizes(), tensorType.getShape())) { + llvm::zip(portOp.getDims(), tensorType.getShape())) { auto sizeAttr = builder.getIndexAttr(size); if (portSizeOperand.getType().isa()) { if (!status.updateMatchedPort( @@ -324,8 +325,8 @@ FailureOr IPMatcher::match() { // 2) semantics block argument -> semantics linalg/output operand // 3) semantics linalg/output operand -> payload linalg operand // - // Then, the paylod linalg operand should be used by the new InstanceOp to - // replace the payload linalg op. Note that from step 2), we need to + // Then, the payload linalg operand should be used by the new InstanceOp + // to replace the payload linalg op. Note that from step 2), we need to // separate the handling of input, initiation, and output port. // We start from step 1) mapping. If failed, it indicates the current port @@ -407,13 +408,13 @@ FailureOr IPMatcher::match() { } // After the iteration, we will try to fill those unmatched ports and template - // parameters with default value it has. + // parameters with default value if it has. for (auto port : ipSemantics.getPorts()) { auto portOp = port.getDefiningOp(); if (auto defaultValue = portOp.getValue()) status.updateMatchedPort(port, defaultValue.value()); } - for (auto temp : ipSemantics.getTemplates()) { + for (auto temp : ipSemantics.getStructPeeledTemplates()) { auto tempOp = temp.getDefiningOp(); if (auto defaultValue = tempOp.getValue()) status.updateMatchedTemplate(temp, defaultValue.value()); diff --git a/lib/Translation/EmitHLSCpp.cpp b/lib/Translation/EmitHLSCpp.cpp index 7e5443d8..8d2ab94f 100644 --- a/lib/Translation/EmitHLSCpp.cpp +++ b/lib/Translation/EmitHLSCpp.cpp @@ -46,7 +46,9 @@ static std::string getDataTypeName(Type type) { return vectorName; } - // Handle scalar types, including float and integer. + // Handle scalar types, including float and integer, or struct type. 
+ if (auto structType = type.dyn_cast()) + return structType.getName().str(); if (type.isa()) return "float"; else if (type.isa()) @@ -290,6 +292,7 @@ class ModuleEmitter : public ScaleHLSEmitterBase { /// Lib Ip operation emitter. void emitInstanceOp(InstanceOp op); + void emitStructInstanceOp(StructInstanceOp op); /// HLS dialect operation emitters. void emitConstBuffer(ConstBufferOp op); @@ -357,6 +360,7 @@ class ModuleEmitter : public ScaleHLSEmitterBase { unsigned emitNestedLoopHeader(Value val); void emitNestedLoopFooter(unsigned rank); void emitInfoAndNewLine(Operation *op); + void emitTemplateList(hls::TemplatedOpInterface op); /// MLIR component and HLS C++ pragma emitters. void emitBlock(Block &block); @@ -460,10 +464,11 @@ class StmtVisitor : public HLSVisitorBase { StmtVisitor(ModuleEmitter &emitter) : emitter(emitter) {} using HLSVisitorBase::visitOp; - // Test registered ip expression. - bool visitOp(InstanceOp op) { return emitter.emitInstanceOp(op), true; } - /// HLS dialect operations. + bool visitOp(InstanceOp op) { return emitter.emitInstanceOp(op), true; } + bool visitOp(StructInstanceOp op) { + return emitter.emitStructInstanceOp(op), true; + } bool visitOp(BufferOp op) { if (op.getDepth() == 1) return emitter.emitAlloc(op), true; @@ -701,38 +706,31 @@ void ModuleEmitter::emitConstBuffer(ConstBufferOp op) { /// Library Ip emitter. void ModuleEmitter::emitInstanceOp(InstanceOp op) { indent(); - - // Get Ip name, print Ip name. - auto ipName = op.getNameAttr(); - os << ipName.getNestedReferences()[0].getValue().str(); - - // Emit template. 
- os << "<"; - for (auto [i, curTemplate] : llvm::enumerate(op.getTemplates())) { - if (auto curAttr = curTemplate.dyn_cast()) { - os << getDataTypeName(curAttr.getValue()); - } else if (auto curAttr = curTemplate.dyn_cast()) { - os << curAttr.getInt(); - } else { - llvm_unreachable("Invalid template parameter"); - } - if (i != op.getTemplates().size() - 1) { + os << op.getName().getLeafReference().getValue().str(); + emitTemplateList(op); + + // Emit IP ports. + os << "("; + for (auto port : op.getPorts()) { + emitValue(port); + if (port != op.getPorts().back()) os << ", "; - } } - os << ">("; + os << ");"; + emitInfoAndNewLine(op); +} - // Emit Variables. - for (auto [i, curVar] : llvm::enumerate(op.getOperands())) { - emitValue(curVar); - if (i != op.getOperands().size() - 1) { +// TODO: Emit template list. +void ModuleEmitter::emitStructInstanceOp(StructInstanceOp op) { + indent(); + emitValue(op.getResult()); + os << " = {"; + for (auto param : op.getParams()) { + emitValue(param); + if (param != op.getParams().back()) os << ", "; - } } - os << ")"; - - // Emit ends. - os << ";"; + os << "};"; emitInfoAndNewLine(op); } @@ -1693,6 +1691,22 @@ void ModuleEmitter::emitInfoAndNewLine(Operation *op) { os << "\n"; } +void ModuleEmitter::emitTemplateList(hls::TemplatedOpInterface op) { + os << "<"; + auto composedTemplates = op.getComposedTemplates(); + for (auto temp : composedTemplates) { + if (temp.is()) + emitValue(temp.get()); + else if (auto typeAttr = temp.get().dyn_cast()) { + assert(typeAttr && "should be a type attribute."); + os << getDataTypeName(typeAttr.getValue()); + } + if (temp != composedTemplates.back()) + os << ", "; + } + os << ">"; +} + /// MLIR component and HLS C++ pragma emitters. 
void ModuleEmitter::emitBlock(Block &block) { for (auto &op : block) { @@ -1936,6 +1950,7 @@ void ModuleEmitter::emitModule(ModuleOp module) { } }); }); + for (const Attribute &curPath : emittedIncludeAttrs) { os << "#include "; os << curPath; diff --git a/lib/Utils/CMakeLists.txt b/lib/Utils/CMakeLists.txt index d51c621c..044f6916 100644 --- a/lib/Utils/CMakeLists.txt +++ b/lib/Utils/CMakeLists.txt @@ -1,5 +1,4 @@ add_mlir_library(MLIRScaleHLSUtils - Matchers.cpp Utils.cpp LINK_LIBS PUBLIC diff --git a/python/scalehls/_mlir_libs/_hls_dialect.pyi b/python/scalehls/_mlir_libs/_hls_dialect.pyi index 174a2267..7eeb231e 100644 --- a/python/scalehls/_mlir_libs/_hls_dialect.pyi +++ b/python/scalehls/_mlir_libs/_hls_dialect.pyi @@ -6,6 +6,7 @@ from importlib._bootstrap import MemoryKindType from importlib._bootstrap import ParamKindAttr from importlib._bootstrap import PortKindAttr from importlib._bootstrap import PortType +from importlib._bootstrap import StructType from importlib._bootstrap import TaskImplType from importlib._bootstrap import TypeType @@ -16,6 +17,7 @@ __all__ = [ "PortKind", "PortKindAttr", "PortType", + "StructType", "TaskImplType", "TypeType", "semantics_init_args" diff --git a/test/EmitHLSCpp/test-instanceOp.mlir b/test/EmitHLSCpp/test-instance.mlir similarity index 66% rename from test/EmitHLSCpp/test-instanceOp.mlir rename to test/EmitHLSCpp/test-instance.mlir index 2945d67b..74ea911c 100644 --- a/test/EmitHLSCpp/test-instanceOp.mlir +++ b/test/EmitHLSCpp/test-instance.mlir @@ -11,11 +11,11 @@ module attributes { torch.debug_module_name = "MLP" } { %1 = hls.dse.param @template1