From 35a376630537f7be15c04eb59a27dd11c4c8cc5d Mon Sep 17 00:00:00 2001 From: Hideto Ueno Date: Thu, 14 Nov 2024 11:58:10 +0900 Subject: [PATCH] [FIRRTL] Add a new FIRRTL annotation to specify type lowering behavior of module body This allows more fine-grained control over how types are lowered in different contexts. This also adds an "includeHierarchy" option to Convention annotations that allows applying the convention to all modules in the hierarchy below the annotated module. --- docs/Dialects/FIRRTL/FIRRTLAnnotations.md | 37 +++++++++++--- .../circt/Dialect/FIRRTL/AnnotationDetails.h | 2 + .../FIRRTL/Transforms/LowerAnnotations.cpp | 48 ++++++++++++++++--- lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp | 42 ++++++++++------ test/Dialect/FIRRTL/annotations.mlir | 23 ++++++++- test/Dialect/FIRRTL/lower-types.mlir | 39 +++++++++++++++ 6 files changed, 161 insertions(+), 30 deletions(-) diff --git a/docs/Dialects/FIRRTL/FIRRTLAnnotations.md b/docs/Dialects/FIRRTL/FIRRTLAnnotations.md index d2778715af0f..5a98b4c3a5c0 100644 --- a/docs/Dialects/FIRRTL/FIRRTLAnnotations.md +++ b/docs/Dialects/FIRRTL/FIRRTLAnnotations.md @@ -323,11 +323,12 @@ Example: ### Convention -| Property | Type | Description | -| ---------- | ------ | --------------------------------------- | -| class | string | `circt.ConventionAnnotation` | -| convention | string | `scalarized` | -| target | string | Reference target | +| Property | Type | Description | +| ---------------- | ------ | ---------------------------------------------------- | +| class | string | `circt.ConventionAnnotation` | +| convention | string | `scalarized` | +| target | string | Reference target | +| includeHierarchy | bool | Apply the convention to all modules in the hierarchy | Specify the port convention for a module. The port convention controls how a module's ports are transformed, and how that module can be instantiated, in the @@ -341,7 +342,31 @@ The options are: { "class": "circt.ConventionAnnotation", "convention": "scalarized", - "target": "~Foo|Bar/d:Baz" + "target": "~Foo|Bar", + "includeHierarchy": true +} +``` + +### BodyTypeLoweringAnnotation + +| Property | Type | Description | +| ---------------- | ------ | ---------------------------------- | +| class | string | `circt.BodyTypeLoweringAnnotation` | +| convention | string | See `Convention` annotation | +| target | string | See `Convention` annotation | +| includeHierarchy | bool | See `Convention` annotation | + +Specify the type lowering option for module internal signals. +This is similar to the `Convention` annotation, but for internal signals +rather than module ports. Refer to the `Convention` annotation for each +property description. + +```json +{ + "class": "circt.BodyTypeLoweringAnnotation", + "convention": "scalarized", + "target": "~Foo|Bar", + "includeHierarchy": true } ``` diff --git a/include/circt/Dialect/FIRRTL/AnnotationDetails.h b/include/circt/Dialect/FIRRTL/AnnotationDetails.h index 92c5ac7a34f6..13078bd55b7a 100644 --- a/include/circt/Dialect/FIRRTL/AnnotationDetails.h +++ b/include/circt/Dialect/FIRRTL/AnnotationDetails.h @@ -29,6 +29,8 @@ constexpr const char *rawAnnotations = "rawAnnotations"; //===----------------------------------------------------------------------===// constexpr const char *conventionAnnoClass = "circt.ConventionAnnotation"; +constexpr const char *typeLoweringAnnoClass = + "circt.BodyTypeLoweringAnnotation"; constexpr const char *dontTouchAnnoClass = "firrtl.transforms.DontTouchAnnotation"; constexpr const char *enumComponentAnnoClass = diff --git a/lib/Dialect/FIRRTL/Transforms/LowerAnnotations.cpp b/lib/Dialect/FIRRTL/Transforms/LowerAnnotations.cpp index 5e942e5185b6..d976b97e3ec5 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerAnnotations.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerAnnotations.cpp @@ -275,14 +275,17 @@ static std::optional parseConvention(llvm::StringRef str) { .Default(std::nullopt); } -static LogicalResult applyConventionAnno(const AnnoPathValue &target, - DictionaryAttr anno, - ApplyState &state) { +template +static LogicalResult +applyConventionOrTypeLoweringAnno(const AnnoPathValue &target, + DictionaryAttr anno, ApplyState &state) { auto *op = target.ref.getOp(); auto loc = op->getLoc(); auto error = [&]() { auto diag = mlir::emitError(loc); - diag << "circuit.ConventionAnnotation "; + diag << (IsConventionAnno ? "circuit.ConventionAnnotation " + : "circuit.TypeLoweringAnnotation ") + << " "; return diag; }; @@ -305,13 +308,41 @@ static LogicalResult applyConventionAnno(const AnnoPathValue &target, auto convention = *conventionOpt; + if (convention == Convention::Internal) + // Convention is internal by default so there is nothing to change + return success(); + + auto includeHierarchy = anno.getAs("includeHierarchy"); + auto conventionAttr = ConventionAttr::get(op->getContext(), convention); + auto setConvention = [&](Operation *moduleOp) { + TypeSwitch(moduleOp) + .Case([&](auto moduleOp) { + if (IsConventionAnno) + moduleOp.setConventionAttr(conventionAttr); + else + moduleOp->setDiscardableAttr("body_type_lowering", conventionAttr); + }) + .Default([](auto) {}); + }; + if (auto moduleOp = dyn_cast(op)) { - moduleOp.setConvention(convention); + if (includeHierarchy && includeHierarchy.getValue()) { + // If includeHierarchy is true, update the convention for all modules in + // the hierarchy. + for (auto *node : + llvm::post_order(state.instancePathCache.instanceGraph[moduleOp])) { + if (node && isa(*node->getModule())) + setConvention(node->getModule()); + } + } else { + // Update the convention. + setConvention(moduleOp); + } return success(); } if (auto extModuleOp = dyn_cast(op)) { - extModuleOp.setConvention(convention); + setConvention(extModuleOp); return success(); } @@ -563,7 +594,10 @@ static llvm::StringMap annotationRecords{{ {omirTrackerAnnoClass, {stdResolve, applyWithoutTarget}}, {omirFileAnnoClass, NoTargetAnnotation}, // Miscellaneous Annotations - {conventionAnnoClass, {stdResolve, applyConventionAnno}}, + {conventionAnnoClass, + {stdResolve, applyConventionOrTypeLoweringAnno}}, + {typeLoweringAnnoClass, + {stdResolve, applyConventionOrTypeLoweringAnno}}, {dontTouchAnnoClass, {stdResolve, applyWithoutTarget { TypeLoweringVisitor( MLIRContext *context, PreserveAggregate::PreserveMode preserveAggregate, + Convention bodyConvention, PreserveAggregate::PreserveMode memoryPreservationMode, SymbolTable &symTbl, const AttrCache &cache, const llvm::DenseMap &conventionTable) - : context(context), aggregatePreservationMode(preserveAggregate), + : context(context), defaultAggregatePreservationMode(preserveAggregate), memoryPreservationMode(memoryPreservationMode), symTbl(symTbl), - cache(cache), conventionTable(conventionTable) {} + cache(cache), conventionTable(conventionTable) { + bodyAggregatePreservationMode = bodyConvention == Convention::Scalarized + ? PreserveAggregate::None + : defaultAggregatePreservationMode; + } using FIRRTLVisitor::visitDecl; using FIRRTLVisitor::visitExpr; using FIRRTLVisitor::visitStmt; @@ -429,7 +434,7 @@ struct TypeLoweringVisitor : public FIRRTLVisitor { Location errorLoc); PreserveAggregate::PreserveMode - getPreservationModeForModule(FModuleLike moduleLike); + getPreservationModeForPorts(FModuleLike moduleLike); Value getSubWhatever(Value val, size_t index); size_t uniqueIdx = 0; @@ -441,7 +446,8 @@ struct TypeLoweringVisitor : public FIRRTLVisitor { MLIRContext *context; /// Aggregate preservation mode. - PreserveAggregate::PreserveMode aggregatePreservationMode; + PreserveAggregate::PreserveMode defaultAggregatePreservationMode; + PreserveAggregate::PreserveMode bodyAggregatePreservationMode; PreserveAggregate::PreserveMode memoryPreservationMode; /// The builder is set and maintained in the main loop. @@ -460,21 +466,21 @@ struct TypeLoweringVisitor : public FIRRTLVisitor { }; } // namespace -/// Return aggregate preservation mode for the module. If the module has a +/// Return aggregate preservation mode for the module ports. If the module has a /// scalarized linkage, then we may not preserve it's aggregate ports. PreserveAggregate::PreserveMode -TypeLoweringVisitor::getPreservationModeForModule(FModuleLike module) { +TypeLoweringVisitor::getPreservationModeForPorts(FModuleLike module) { auto lookup = conventionTable.find(module); if (lookup == conventionTable.end()) - return aggregatePreservationMode; + return defaultAggregatePreservationMode; switch (lookup->second) { case Convention::Scalarized: return PreserveAggregate::None; case Convention::Internal: - return aggregatePreservationMode; + return defaultAggregatePreservationMode; } llvm_unreachable("Unknown convention"); - return aggregatePreservationMode; + return defaultAggregatePreservationMode; } Value TypeLoweringVisitor::getSubWhatever(Value val, size_t index) { @@ -643,7 +649,7 @@ bool TypeLoweringVisitor::lowerProducer( return false; SmallVector fieldTypes; - if (!peelType(srcFType, fieldTypes, aggregatePreservationMode)) + if (!peelType(srcFType, fieldTypes, bodyAggregatePreservationMode)) return false; SmallVector lowered; @@ -809,7 +815,7 @@ bool TypeLoweringVisitor::lowerArg(FModuleLike module, size_t argIndex, // Flatten any bundle types. SmallVector fieldTypes; auto srcType = type_cast(newArgs[argIndex].pi.type); - if (!peelType(srcType, fieldTypes, getPreservationModeForModule(module))) + if (!peelType(srcType, fieldTypes, getPreservationModeForPorts(module))) return false; // Ports with internalPath set cannot be lowered. @@ -929,7 +935,7 @@ bool TypeLoweringVisitor::visitStmt(RefDefineOp op) { // Attempt to get the bundle types. SmallVector fields; - if (!peelType(op.getDest().getType(), fields, aggregatePreservationMode)) + if (!peelType(op.getDest().getType(), fields, bodyAggregatePreservationMode)) return false; // Loop over the leaf aggregates. @@ -1454,7 +1460,7 @@ bool TypeLoweringVisitor::visitDecl(InstanceOp op) { SmallVector newDirs; SmallVector newNames; SmallVector newPortAnno; - PreserveAggregate::PreserveMode mode = getPreservationModeForModule( + PreserveAggregate::PreserveMode mode = getPreservationModeForPorts( cast(op.getReferencedOperation(symTbl))); endFields.push_back(0); @@ -1667,9 +1673,15 @@ void LowerTypesPass::runOnOperation() { // This lambda, executes in parallel for each Op within the circt. auto lowerModules = [&](FModuleLike op) -> LogicalResult { + // Use body type lowering attribute if it exists, otherwise use internal. + Convention convention = Convention::Internal; + if (auto conventionAttr = dyn_cast_or_null( + op->getDiscardableAttr("body_type_lowering"))) + convention = conventionAttr.getValue(); + auto tl = - TypeLoweringVisitor(&getContext(), preserveAggregate, preserveMemories, - symTbl, cache, conventionTable); + TypeLoweringVisitor(&getContext(), preserveAggregate, convention, + preserveMemories, symTbl, cache, conventionTable); tl.lowerModule(op); return LogicalResult::failure(tl.isFailed()); diff --git a/test/Dialect/FIRRTL/annotations.mlir b/test/Dialect/FIRRTL/annotations.mlir index 3568fc4b498b..470d05833863 100644 --- a/test/Dialect/FIRRTL/annotations.mlir +++ b/test/Dialect/FIRRTL/annotations.mlir @@ -734,14 +734,33 @@ firrtl.circuit "Test" attributes {rawAnnotations = [ // ----- firrtl.circuit "Test" attributes {rawAnnotations =[ - {class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized"} + {class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized"}, + {class = "circt.BodyTypeLoweringAnnotation", target = "~Test|Test", convention = "scalarized"} ]} { - // CHECK: attributes {convention = #firrtl} + // CHECK: attributes {body_type_lowering = #firrtl, convention = #firrtl} firrtl.module @Test() attributes {convention = #firrtl} {} } // ----- +firrtl.circuit "Test" attributes {rawAnnotations = [ + {class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized", includeHierarchy = true}, + {class = "circt.BodyTypeLoweringAnnotation", target = "~Test|Test", convention = "scalarized", includeHierarchy = true} + ]} { + // CHECK: @Test() attributes {body_type_lowering = #firrtl, convention = #firrtl} + firrtl.module @Test() attributes {convention = #firrtl} { + firrtl.instance child @Child() + } + + // CHECK: @Child() attributes {body_type_lowering = #firrtl, convention = #firrtl} + firrtl.module @Child() attributes {convention = #firrtl} {} + + // CHECK: @Child2() { + firrtl.module @Child2() attributes {convention = #firrtl} {} +} + +// ----- + firrtl.circuit "Test" attributes {rawAnnotations =[ {class = "chisel3.ModulePrefixAnnotation", target = "~Test|Test>comb", prefix = "Prefix_"}, {class = "chisel3.ModulePrefixAnnotation", target = "~Test|Test>seq", prefix = "Prefix_"}, diff --git a/test/Dialect/FIRRTL/lower-types.mlir b/test/Dialect/FIRRTL/lower-types.mlir index ac7f2d10f49f..4366b2b9a581 100644 --- a/test/Dialect/FIRRTL/lower-types.mlir +++ b/test/Dialect/FIRRTL/lower-types.mlir @@ -1404,3 +1404,42 @@ firrtl.circuit "UnrealizedConversion" { firrtl.matchingconnect %w, %b : !firrtl.bundle, tag: uint<1>> } } + +firrtl.circuit "Conventions1" { + // COMMON-LABEL: @Conventions1 + // AGGREGATE-SAME: %input_0 + // AGGREGATE-NEXT: firrtl.reg + // AGGREGATE-SAME: !firrtl.vector, 1> + firrtl.module public @Conventions1(in %input: !firrtl.vector, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector, 1>) attributes {convention = #firrtl, body_type_lowering = #firrtl}{ + %r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector, 1> + firrtl.matchingconnect %r, %input : !firrtl.vector, 1> + firrtl.matchingconnect %port, %r : !firrtl.vector, 1> + } + // COMMON-LABEL: @Conventions2 + // AGGREGATE-SAME: %input_0: !firrtl.uint<8> + // AGGREGATE-NEXT: firrtl.reg + // AGGREGATE-SAME: !firrtl.uint<8> + firrtl.module private @Conventions2(in %input: !firrtl.vector, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector, 1>) attributes {convention = #firrtl, body_type_lowering = #firrtl}{ + %r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector, 1> + firrtl.matchingconnect %r, %input : !firrtl.vector, 1> + firrtl.matchingconnect %port, %r : !firrtl.vector, 1> + } + // COMMON-LABEL: @Conventions3 + // AGGREGATE-SAME: %input: !firrtl.vector, 1> + // AGGREGATE-NEXT: firrtl.reg + // AGGREGATE-SAME: !firrtl.vector, 1> + firrtl.module private @Conventions3(in %input: !firrtl.vector, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector, 1>) attributes {convention = #firrtl, body_type_lowering = #firrtl}{ + %r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector, 1> + firrtl.matchingconnect %r, %input : !firrtl.vector, 1> + firrtl.matchingconnect %port, %r : !firrtl.vector, 1> + } + // COMMON-LABEL: @Conventions4 + // AGGREGATE-SAME: %input: !firrtl.vector, 1> + // AGGREGATE-NEXT: firrtl.reg + // AGGREGATE-SAME: !firrtl.uint<8> + firrtl.module private @Conventions4(in %input: !firrtl.vector, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector, 1>) attributes {convention = #firrtl, body_type_lowering = #firrtl}{ + %r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector, 1> + firrtl.matchingconnect %r, %input : !firrtl.vector, 1> + firrtl.matchingconnect %port, %r : !firrtl.vector, 1> + } +}