diff --git a/docs/Dialects/FIRRTL/FIRRTLAnnotations.md b/docs/Dialects/FIRRTL/FIRRTLAnnotations.md index d2778715af0f..482c8b96e4d2 100644 --- a/docs/Dialects/FIRRTL/FIRRTLAnnotations.md +++ b/docs/Dialects/FIRRTL/FIRRTLAnnotations.md @@ -323,11 +323,12 @@ Example: ### Convention -| Property | Type | Description | -| ---------- | ------ | --------------------------------------- | -| class | string | `circt.ConventionAnnotation` | -| convention | string | `scalarized` | -| target | string | Reference target | +| Property | Type | Description | +| ---------------- | ------ | ---------------------------------------------------- | +| class | string | `circt.ConventionAnnotation` | +| convention | string | `scalarized` | +| target | string | Reference target | +| includeHierarchy | bool | Apply the convention to all modules in the hierarchy | Specify the port convention for a module. The port convention controls how a module's ports are transformed, and how that module can be instantiated, in the @@ -337,11 +338,41 @@ The options are: - `scalarized`: Convert aggregate ports (i.e. vector or bundles) into multiple ground-typed ports. +`includeHierarchy` is optional and defaults to `false`, meaning that the +convention is applied only to the specified module. If `includeHierarchy` is +`true`, the convention is applied to all modules in the hierarchy. If there are +multiple annotation instances that specify conventions, the `scalarized` convention +takes precedence over the `internal` convention. + ```json { "class": "circt.ConventionAnnotation", "convention": "scalarized", - "target": "~Foo|Bar/d:Baz" + "target": "~Foo|Bar", + "includeHierarchy": true +} +``` + +### BodyTypeLoweringAnnotation + +| Property | Type | Description | +| ---------------- | ------ | ---------------------------------- | +| class | string | `circt.BodyTypeLoweringAnnotation` | +| convention | string | See `Convention` annotation | +| target | string | See `Convention` annotation | +| includeHierarchy | bool | See `Convention` annotation | + +Specify the type lowering option for module internal signals. +This is similar to the `Convention` annotation, but for internal signals +rather than module ports. Refer to the `Convention` annotation for each +property description. + +```json +{ + "class": "circt.BodyTypeLoweringAnnotation", + "convention": "scalarized", + "target": "~Foo|Bar", + "includeHierarchy": true } ``` diff --git a/include/circt/Dialect/FIRRTL/AnnotationDetails.h b/include/circt/Dialect/FIRRTL/AnnotationDetails.h index 92c5ac7a34f6..13078bd55b7a 100644 --- a/include/circt/Dialect/FIRRTL/AnnotationDetails.h +++ b/include/circt/Dialect/FIRRTL/AnnotationDetails.h @@ -29,6 +29,8 @@ constexpr const char *rawAnnotations = "rawAnnotations"; //===----------------------------------------------------------------------===// constexpr const char *conventionAnnoClass = "circt.ConventionAnnotation"; +constexpr const char *typeLoweringAnnoClass = + "circt.BodyTypeLoweringAnnotation"; constexpr const char *dontTouchAnnoClass = "firrtl.transforms.DontTouchAnnotation"; constexpr const char *enumComponentAnnoClass = diff --git a/lib/Dialect/FIRRTL/Transforms/LowerAnnotations.cpp b/lib/Dialect/FIRRTL/Transforms/LowerAnnotations.cpp index 5e942e5185b6..d976b97e3ec5 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerAnnotations.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerAnnotations.cpp @@ -275,14 +275,17 @@ static std::optional parseConvention(llvm::StringRef str) { .Default(std::nullopt); } -static LogicalResult applyConventionAnno(const AnnoPathValue &target, - DictionaryAttr anno, - ApplyState &state) { +template +static LogicalResult +applyConventionOrTypeLoweringAnno(const AnnoPathValue &target, + DictionaryAttr anno, ApplyState &state) { auto *op = target.ref.getOp(); auto loc = op->getLoc(); auto error = [&]() { auto diag = mlir::emitError(loc); - diag << "circuit.ConventionAnnotation "; + diag << (IsConventionAnno ? "circuit.ConventionAnnotation " + : "circuit.TypeLoweringAnnotation ") + << " "; return diag; }; @@ -305,13 +308,41 @@ static LogicalResult applyConventionAnno(const AnnoPathValue &target, auto convention = *conventionOpt; + if (convention == Convention::Internal) + // Convention is internal by default so there is nothing to change + return success(); + + auto includeHierarchy = anno.getAs("includeHierarchy"); + auto conventionAttr = ConventionAttr::get(op->getContext(), convention); + auto setConvention = [&](Operation *moduleOp) { + TypeSwitch(moduleOp) + .Case([&](auto moduleOp) { + if (IsConventionAnno) + moduleOp.setConventionAttr(conventionAttr); + else + moduleOp->setDiscardableAttr("body_type_lowering", conventionAttr); + }) + .Default([](auto) {}); + }; + if (auto moduleOp = dyn_cast(op)) { - moduleOp.setConvention(convention); + if (includeHierarchy && includeHierarchy.getValue()) { + // If includeHierarchy is true, update the convention for all modules in + // the hierarchy. + for (auto *node : + llvm::post_order(state.instancePathCache.instanceGraph[moduleOp])) { + if (node && isa(*node->getModule())) + setConvention(node->getModule()); + } + } else { + // Update the convention. + setConvention(moduleOp); + } return success(); } if (auto extModuleOp = dyn_cast(op)) { - extModuleOp.setConvention(convention); + setConvention(extModuleOp); return success(); } @@ -563,7 +594,10 @@ static llvm::StringMap annotationRecords{{ {omirTrackerAnnoClass, {stdResolve, applyWithoutTarget}}, {omirFileAnnoClass, NoTargetAnnotation}, // Miscellaneous Annotations - {conventionAnnoClass, {stdResolve, applyConventionAnno}}, + {conventionAnnoClass, + {stdResolve, applyConventionOrTypeLoweringAnno}}, + {typeLoweringAnnoClass, + {stdResolve, applyConventionOrTypeLoweringAnno}}, {dontTouchAnnoClass, {stdResolve, applyWithoutTarget { TypeLoweringVisitor( MLIRContext *context, PreserveAggregate::PreserveMode preserveAggregate, + Convention bodyConvention, PreserveAggregate::PreserveMode memoryPreservationMode, SymbolTable &symTbl, const AttrCache &cache, const llvm::DenseMap &conventionTable) - : context(context), aggregatePreservationMode(preserveAggregate), + : context(context), defaultAggregatePreservationMode(preserveAggregate), memoryPreservationMode(memoryPreservationMode), symTbl(symTbl), - cache(cache), conventionTable(conventionTable) {} + cache(cache), conventionTable(conventionTable) { + bodyAggregatePreservationMode = bodyConvention == Convention::Scalarized + ? PreserveAggregate::None + : defaultAggregatePreservationMode; + } using FIRRTLVisitor::visitDecl; using FIRRTLVisitor::visitExpr; using FIRRTLVisitor::visitStmt; @@ -429,7 +434,7 @@ struct TypeLoweringVisitor : public FIRRTLVisitor { Location errorLoc); PreserveAggregate::PreserveMode - getPreservationModeForModule(FModuleLike moduleLike); + getPreservationModeForPorts(FModuleLike moduleLike); Value getSubWhatever(Value val, size_t index); size_t uniqueIdx = 0; @@ -441,7 +446,8 @@ struct TypeLoweringVisitor : public FIRRTLVisitor { MLIRContext *context; /// Aggregate preservation mode. - PreserveAggregate::PreserveMode aggregatePreservationMode; + PreserveAggregate::PreserveMode defaultAggregatePreservationMode; + PreserveAggregate::PreserveMode bodyAggregatePreservationMode; PreserveAggregate::PreserveMode memoryPreservationMode; /// The builder is set and maintained in the main loop. @@ -460,21 +466,21 @@ struct TypeLoweringVisitor : public FIRRTLVisitor { }; } // namespace -/// Return aggregate preservation mode for the module. If the module has a +/// Return aggregate preservation mode for the module ports. If the module has a /// scalarized linkage, then we may not preserve it's aggregate ports. PreserveAggregate::PreserveMode -TypeLoweringVisitor::getPreservationModeForModule(FModuleLike module) { +TypeLoweringVisitor::getPreservationModeForPorts(FModuleLike module) { auto lookup = conventionTable.find(module); if (lookup == conventionTable.end()) - return aggregatePreservationMode; + return defaultAggregatePreservationMode; switch (lookup->second) { case Convention::Scalarized: return PreserveAggregate::None; case Convention::Internal: - return aggregatePreservationMode; + return defaultAggregatePreservationMode; } llvm_unreachable("Unknown convention"); - return aggregatePreservationMode; + return defaultAggregatePreservationMode; } Value TypeLoweringVisitor::getSubWhatever(Value val, size_t index) { @@ -643,7 +649,7 @@ bool TypeLoweringVisitor::lowerProducer( return false; SmallVector fieldTypes; - if (!peelType(srcFType, fieldTypes, aggregatePreservationMode)) + if (!peelType(srcFType, fieldTypes, bodyAggregatePreservationMode)) return false; SmallVector lowered; @@ -809,7 +815,7 @@ bool TypeLoweringVisitor::lowerArg(FModuleLike module, size_t argIndex, // Flatten any bundle types. SmallVector fieldTypes; auto srcType = type_cast(newArgs[argIndex].pi.type); - if (!peelType(srcType, fieldTypes, getPreservationModeForModule(module))) + if (!peelType(srcType, fieldTypes, getPreservationModeForPorts(module))) return false; // Ports with internalPath set cannot be lowered. @@ -929,7 +935,7 @@ bool TypeLoweringVisitor::visitStmt(RefDefineOp op) { // Attempt to get the bundle types. SmallVector fields; - if (!peelType(op.getDest().getType(), fields, aggregatePreservationMode)) + if (!peelType(op.getDest().getType(), fields, bodyAggregatePreservationMode)) return false; // Loop over the leaf aggregates. @@ -1454,7 +1460,7 @@ bool TypeLoweringVisitor::visitDecl(InstanceOp op) { SmallVector newDirs; SmallVector newNames; SmallVector newPortAnno; - PreserveAggregate::PreserveMode mode = getPreservationModeForModule( + PreserveAggregate::PreserveMode mode = getPreservationModeForPorts( cast(op.getReferencedOperation(symTbl))); endFields.push_back(0); @@ -1667,9 +1673,15 @@ void LowerTypesPass::runOnOperation() { // This lambda, executes in parallel for each Op within the circt. auto lowerModules = [&](FModuleLike op) -> LogicalResult { + // Use body type lowering attribute if it exists, otherwise use internal. + Convention convention = Convention::Internal; + if (auto conventionAttr = dyn_cast_or_null( + op->getDiscardableAttr("body_type_lowering"))) + convention = conventionAttr.getValue(); + auto tl = - TypeLoweringVisitor(&getContext(), preserveAggregate, preserveMemories, - symTbl, cache, conventionTable); + TypeLoweringVisitor(&getContext(), preserveAggregate, convention, + preserveMemories, symTbl, cache, conventionTable); tl.lowerModule(op); return LogicalResult::failure(tl.isFailed()); diff --git a/test/Dialect/FIRRTL/annotations.mlir b/test/Dialect/FIRRTL/annotations.mlir index 3568fc4b498b..470d05833863 100644 --- a/test/Dialect/FIRRTL/annotations.mlir +++ b/test/Dialect/FIRRTL/annotations.mlir @@ -734,14 +734,33 @@ firrtl.circuit "Test" attributes {rawAnnotations = [ // ----- firrtl.circuit "Test" attributes {rawAnnotations =[ - {class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized"} + {class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized"}, + {class = "circt.BodyTypeLoweringAnnotation", target = "~Test|Test", convention = "scalarized"} ]} { - // CHECK: attributes {convention = #firrtl} + // CHECK: attributes {body_type_lowering = #firrtl, convention = #firrtl} firrtl.module @Test() attributes {convention = #firrtl} {} } // ----- +firrtl.circuit "Test" attributes {rawAnnotations = [ + {class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized", includeHierarchy = true}, + {class = "circt.BodyTypeLoweringAnnotation", target = "~Test|Test", convention = "scalarized", includeHierarchy = true} + ]} { + // CHECK: @Test() attributes {body_type_lowering = #firrtl, convention = #firrtl} + firrtl.module @Test() attributes {convention = #firrtl} { + firrtl.instance child @Child() + } + + // CHECK: @Child() attributes {body_type_lowering = #firrtl, convention = #firrtl} + firrtl.module @Child() attributes {convention = #firrtl} {} + + // CHECK: @Child2() { + firrtl.module @Child2() attributes {convention = #firrtl} {} +} + +// ----- + firrtl.circuit "Test" attributes {rawAnnotations =[ {class = "chisel3.ModulePrefixAnnotation", target = "~Test|Test>comb", prefix = "Prefix_"}, {class = "chisel3.ModulePrefixAnnotation", target = "~Test|Test>seq", prefix = "Prefix_"}, diff --git a/test/Dialect/FIRRTL/lower-types.mlir b/test/Dialect/FIRRTL/lower-types.mlir index ac7f2d10f49f..4366b2b9a581 100644 --- a/test/Dialect/FIRRTL/lower-types.mlir +++ b/test/Dialect/FIRRTL/lower-types.mlir @@ -1404,3 +1404,42 @@ firrtl.circuit "UnrealizedConversion" { firrtl.matchingconnect %w, %b : !firrtl.bundle, tag: uint<1>> } } + +firrtl.circuit "Conventions1" { + // COMMON-LABEL: @Conventions1 + // AGGREGATE-SAME: %input_0 + // AGGREGATE-NEXT: firrtl.reg + // AGGREGATE-SAME: !firrtl.vector, 1> + firrtl.module public @Conventions1(in %input: !firrtl.vector, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector, 1>) attributes {convention = #firrtl, body_type_lowering = #firrtl}{ + %r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector, 1> + firrtl.matchingconnect %r, %input : !firrtl.vector, 1> + firrtl.matchingconnect %port, %r : !firrtl.vector, 1> + } + // COMMON-LABEL: @Conventions2 + // AGGREGATE-SAME: %input_0: !firrtl.uint<8> + // AGGREGATE-NEXT: firrtl.reg + // AGGREGATE-SAME: !firrtl.uint<8> + firrtl.module private @Conventions2(in %input: !firrtl.vector, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector, 1>) attributes {convention = #firrtl, body_type_lowering = #firrtl}{ + %r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector, 1> + firrtl.matchingconnect %r, %input : !firrtl.vector, 1> + firrtl.matchingconnect %port, %r : !firrtl.vector, 1> + } + // COMMON-LABEL: @Conventions3 + // AGGREGATE-SAME: %input: !firrtl.vector, 1> + // AGGREGATE-NEXT: firrtl.reg + // AGGREGATE-SAME: !firrtl.vector, 1> + firrtl.module private @Conventions3(in %input: !firrtl.vector, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector, 1>) attributes {convention = #firrtl, body_type_lowering = #firrtl}{ + %r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector, 1> + firrtl.matchingconnect %r, %input : !firrtl.vector, 1> + firrtl.matchingconnect %port, %r : !firrtl.vector, 1> + } + // COMMON-LABEL: @Conventions4 + // AGGREGATE-SAME: %input: !firrtl.vector, 1> + // AGGREGATE-NEXT: firrtl.reg + // AGGREGATE-SAME: !firrtl.uint<8> + firrtl.module private @Conventions4(in %input: !firrtl.vector, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector, 1>) attributes {convention = #firrtl, body_type_lowering = #firrtl}{ + %r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector, 1> + firrtl.matchingconnect %r, %input : !firrtl.vector, 1> + firrtl.matchingconnect %port, %r : !firrtl.vector, 1> + } +}