Skip to content

Commit 0c1465d

Browse files
authored
[FIRRTL] Add a new FIRRTL annotation to specify type lowering behavior of module body (#7751)
Add a new annotation to control type lowering behavior for internal signals within a module, separate from the port convention. This allows more fine-grained control over how aggregate types are handled inside modules. The new annotation works similarly to ConventionAnnotation but applies to internal signals rather than module ports. It supports the same conventions and includes an 'includeHierarchy' option to apply the setting to all modules in the hierarchy.
1 parent baccf51 commit 0c1465d

File tree

6 files changed

+184
-17
lines changed

6 files changed

+184
-17
lines changed

docs/Dialects/FIRRTL/FIRRTLAnnotations.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,34 @@ The options are:
345345
}
346346
```
347347

348+
### BodyTypeLoweringAnnotation
349+
350+
| Property | Type | Description |
351+
| ------------------- | ------ | ---------------------------------------------------- |
352+
| class | string | `circt.BodyTypeLoweringAnnotation` |
353+
| convention | string | See `Convention` annotation |
354+
| target | string | See `Convention` annotation |
355+
| includeHierarchy | bool | Apply the convention to all modules in the hierarchy |
356+
357+
Specify the type lowering option for module internal signals.
358+
This is similar to the `Convention` annotation, but for internal signals
359+
rather than module ports. Refer to the `Convention` annotation for each
360+
property description.
361+
362+
When `includeHierarchy` is `false`, it indicates the convention is applied only to
363+
the specified module. If `includeHierarchy` is `true`, the convention is applied to
364+
all modules in the hierarchy. If there are multiple annotation instances that specify
365+
conventions, the `scalarized` convention takes precedence over the `internal` convention.
366+
367+
```json
368+
{
369+
"class": "circt.BodyTypeLoweringAnnotation",
370+
"convention": "scalarized",
371+
"target": "~Foo|Bar",
372+
"includeHierarchy": true
373+
}
374+
```
375+
348376
### ElaborationArtefactsDirectory
349377

350378
| Property | Type | Description |

include/circt/Dialect/FIRRTL/AnnotationDetails.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ constexpr const char *rawAnnotations = "rawAnnotations";
2929
//===----------------------------------------------------------------------===//
3030

3131
constexpr const char *conventionAnnoClass = "circt.ConventionAnnotation";
32+
constexpr const char *typeLoweringAnnoClass =
33+
"circt.BodyTypeLoweringAnnotation";
3234
constexpr const char *dontTouchAnnoClass =
3335
"firrtl.transforms.DontTouchAnnotation";
3436
constexpr const char *enumComponentAnnoClass =

lib/Dialect/FIRRTL/Transforms/LowerAnnotations.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,72 @@ static LogicalResult applyConventionAnno(const AnnoPathValue &target,
311311
return error() << "can only target to a module or extmodule";
312312
}
313313

314+
static LogicalResult applyBodyTypeLoweringAnno(const AnnoPathValue &target,
315+
DictionaryAttr anno,
316+
ApplyState &state) {
317+
auto *op = target.ref.getOp();
318+
auto loc = op->getLoc();
319+
auto error = [&]() {
320+
auto diag = mlir::emitError(loc);
321+
diag << typeLoweringAnnoClass;
322+
return diag;
323+
};
324+
325+
auto opTarget = dyn_cast<OpAnnoTarget>(target.ref);
326+
if (!opTarget)
327+
return error() << "must target a module object";
328+
329+
if (!target.isLocal())
330+
return error() << "must be local";
331+
332+
auto moduleOp = dyn_cast<FModuleOp>(op);
333+
334+
if (!moduleOp)
335+
return error() << "can only target to a module";
336+
337+
auto conventionStrAttr =
338+
tryGetAs<StringAttr>(anno, anno, "convention", loc, conventionAnnoClass);
339+
340+
if (!conventionStrAttr)
341+
return failure();
342+
343+
auto conventionStr = conventionStrAttr.getValue();
344+
auto conventionOpt = parseConvention(conventionStr);
345+
if (!conventionOpt)
346+
return error() << "unknown convention " << conventionStr;
347+
348+
auto convention = *conventionOpt;
349+
350+
if (convention == Convention::Internal)
351+
// Convention is internal by default so there is nothing to change
352+
return success();
353+
354+
auto conventionAttr = ConventionAttr::get(op->getContext(), convention);
355+
356+
// `includeHierarchy` only valid in BodyTypeLowering.
357+
bool includeHierarchy = false;
358+
if (auto includeHierarchyAttr = tryGetAs<BoolAttr>(
359+
anno, anno, "includeHierarchy", loc, conventionAnnoClass))
360+
includeHierarchy = includeHierarchyAttr.getValue();
361+
362+
if (includeHierarchy) {
363+
// If includeHierarchy is true, update the convention for all modules in
364+
// the hierarchy.
365+
for (auto *node :
366+
llvm::post_order(state.instancePathCache.instanceGraph[moduleOp])) {
367+
if (!node)
368+
continue;
369+
if (auto fmodule = dyn_cast<FModuleOp>(*node->getModule()))
370+
fmodule->setAttr("body_type_lowering", conventionAttr);
371+
}
372+
} else {
373+
// Update the convention.
374+
moduleOp->setAttr("body_type_lowering", conventionAttr);
375+
}
376+
377+
return success();
378+
}
379+
314380
static LogicalResult applyModulePrefixAnno(const AnnoPathValue &target,
315381
DictionaryAttr anno,
316382
ApplyState &state) {
@@ -553,6 +619,7 @@ static llvm::StringMap<AnnoRecord> annotationRecords{{
553619
{memTapBlackboxClass, {stdResolve, applyWithoutTarget<true>}},
554620
// Miscellaneous Annotations
555621
{conventionAnnoClass, {stdResolve, applyConventionAnno}},
622+
{typeLoweringAnnoClass, {stdResolve, applyBodyTypeLoweringAnno}},
556623
{dontTouchAnnoClass,
557624
{stdResolve, applyWithoutTarget<true, true, WireOp, NodeOp, RegOp,
558625
RegResetOp, InstanceOp, MemOp, CombMemOp,

lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -332,12 +332,17 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
332332

333333
TypeLoweringVisitor(
334334
MLIRContext *context, PreserveAggregate::PreserveMode preserveAggregate,
335+
Convention bodyConvention,
335336
PreserveAggregate::PreserveMode memoryPreservationMode,
336337
SymbolTable &symTbl, const AttrCache &cache,
337338
const llvm::DenseMap<FModuleLike, Convention> &conventionTable)
338-
: context(context), aggregatePreservationMode(preserveAggregate),
339+
: context(context), defaultAggregatePreservationMode(preserveAggregate),
339340
memoryPreservationMode(memoryPreservationMode), symTbl(symTbl),
340-
cache(cache), conventionTable(conventionTable) {}
341+
cache(cache), conventionTable(conventionTable) {
342+
bodyAggregatePreservationMode = bodyConvention == Convention::Scalarized
343+
? PreserveAggregate::None
344+
: defaultAggregatePreservationMode;
345+
}
341346
using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitDecl;
342347
using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitExpr;
343348
using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitStmt;
@@ -422,7 +427,7 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
422427
Location errorLoc);
423428

424429
PreserveAggregate::PreserveMode
425-
getPreservationModeForModule(FModuleLike moduleLike);
430+
getPreservationModeForPorts(FModuleLike moduleLike);
426431
Value getSubWhatever(Value val, size_t index);
427432

428433
size_t uniqueIdx = 0;
@@ -434,7 +439,8 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
434439
MLIRContext *context;
435440

436441
/// Aggregate preservation mode.
437-
PreserveAggregate::PreserveMode aggregatePreservationMode;
442+
PreserveAggregate::PreserveMode defaultAggregatePreservationMode;
443+
PreserveAggregate::PreserveMode bodyAggregatePreservationMode;
438444
PreserveAggregate::PreserveMode memoryPreservationMode;
439445

440446
/// The builder is set and maintained in the main loop.
@@ -453,21 +459,21 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
453459
};
454460
} // namespace
455461

456-
/// Return aggregate preservation mode for the module. If the module has a
462+
/// Return aggregate preservation mode for the module ports. If the module has a
457463
/// scalarized linkage, then we may not preserve it's aggregate ports.
458464
PreserveAggregate::PreserveMode
459-
TypeLoweringVisitor::getPreservationModeForModule(FModuleLike module) {
465+
TypeLoweringVisitor::getPreservationModeForPorts(FModuleLike module) {
460466
auto lookup = conventionTable.find(module);
461467
if (lookup == conventionTable.end())
462-
return aggregatePreservationMode;
468+
return defaultAggregatePreservationMode;
463469
switch (lookup->second) {
464470
case Convention::Scalarized:
465471
return PreserveAggregate::None;
466472
case Convention::Internal:
467-
return aggregatePreservationMode;
473+
return defaultAggregatePreservationMode;
468474
}
469475
llvm_unreachable("Unknown convention");
470-
return aggregatePreservationMode;
476+
return defaultAggregatePreservationMode;
471477
}
472478

473479
Value TypeLoweringVisitor::getSubWhatever(Value val, size_t index) {
@@ -636,7 +642,7 @@ bool TypeLoweringVisitor::lowerProducer(
636642
return false;
637643
SmallVector<FlatBundleFieldEntry, 8> fieldTypes;
638644

639-
if (!peelType(srcFType, fieldTypes, aggregatePreservationMode))
645+
if (!peelType(srcFType, fieldTypes, bodyAggregatePreservationMode))
640646
return false;
641647

642648
SmallVector<Value> lowered;
@@ -805,7 +811,7 @@ bool TypeLoweringVisitor::lowerArg(FModuleLike module, size_t argIndex,
805811
// Flatten any bundle types.
806812
SmallVector<FlatBundleFieldEntry> fieldTypes;
807813
auto srcType = type_cast<FIRRTLType>(newArgs[argIndex].pi.type);
808-
if (!peelType(srcType, fieldTypes, getPreservationModeForModule(module)))
814+
if (!peelType(srcType, fieldTypes, getPreservationModeForPorts(module)))
809815
return false;
810816

811817
// Ports with internalPath set cannot be lowered.
@@ -925,7 +931,7 @@ bool TypeLoweringVisitor::visitStmt(RefDefineOp op) {
925931
// Attempt to get the bundle types.
926932
SmallVector<FlatBundleFieldEntry> fields;
927933

928-
if (!peelType(op.getDest().getType(), fields, aggregatePreservationMode))
934+
if (!peelType(op.getDest().getType(), fields, bodyAggregatePreservationMode))
929935
return false;
930936

931937
// Loop over the leaf aggregates.
@@ -1458,7 +1464,7 @@ bool TypeLoweringVisitor::visitDecl(InstanceOp op) {
14581464
SmallVector<Direction> newDirs;
14591465
SmallVector<Attribute> newNames;
14601466
SmallVector<Attribute> newPortAnno;
1461-
PreserveAggregate::PreserveMode mode = getPreservationModeForModule(
1467+
PreserveAggregate::PreserveMode mode = getPreservationModeForPorts(
14621468
cast<FModuleLike>(op.getReferencedOperation(symTbl)));
14631469

14641470
endFields.push_back(0);
@@ -1662,9 +1668,15 @@ void LowerTypesPass::runOnOperation() {
16621668

16631669
// This lambda, executes in parallel for each Op within the circt.
16641670
auto lowerModules = [&](FModuleLike op) -> LogicalResult {
1671+
// Use body type lowering attribute if it exists, otherwise use internal.
1672+
Convention convention = Convention::Internal;
1673+
if (auto conventionAttr = dyn_cast_or_null<ConventionAttr>(
1674+
op->getDiscardableAttr("body_type_lowering")))
1675+
convention = conventionAttr.getValue();
1676+
16651677
auto tl =
1666-
TypeLoweringVisitor(&getContext(), preserveAggregate, preserveMemories,
1667-
symTbl, cache, conventionTable);
1678+
TypeLoweringVisitor(&getContext(), preserveAggregate, convention,
1679+
preserveMemories, symTbl, cache, conventionTable);
16681680
tl.lowerModule(op);
16691681

16701682
return LogicalResult::failure(tl.isFailed());

test/Dialect/FIRRTL/annotations.mlir

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -734,14 +734,33 @@ firrtl.circuit "Test" attributes {rawAnnotations = [
734734
// -----
735735

736736
firrtl.circuit "Test" attributes {rawAnnotations =[
737-
{class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized"}
737+
{class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized"},
738+
{class = "circt.BodyTypeLoweringAnnotation", target = "~Test|Test", convention = "scalarized", includeHierarchy = false}
738739
]} {
739-
// CHECK: attributes {convention = #firrtl<convention scalarized>}
740+
// CHECK: attributes {body_type_lowering = #firrtl<convention scalarized>, convention = #firrtl<convention scalarized>}
740741
firrtl.module @Test() attributes {convention = #firrtl<convention internal>} {}
741742
}
742743

743744
// -----
744745

746+
firrtl.circuit "Test" attributes {rawAnnotations = [
747+
{class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized"},
748+
{class = "circt.BodyTypeLoweringAnnotation", target = "~Test|Test", convention = "scalarized", includeHierarchy = true}
749+
]} {
750+
// CHECK: @Test() attributes {body_type_lowering = #firrtl<convention scalarized>, convention = #firrtl<convention scalarized>}
751+
firrtl.module @Test() attributes {convention = #firrtl<convention internal>} {
752+
firrtl.instance child @Child()
753+
}
754+
755+
// CHECK: @Child() attributes {body_type_lowering = #firrtl<convention scalarized>}
756+
firrtl.module @Child() attributes {convention = #firrtl<convention internal>} {}
757+
758+
// CHECK: @Child2() {
759+
firrtl.module @Child2() attributes {convention = #firrtl<convention internal>} {}
760+
}
761+
762+
// -----
763+
745764
firrtl.circuit "Test" attributes {rawAnnotations =[
746765
{class = "chisel3.ModulePrefixAnnotation", target = "~Test|Test>comb", prefix = "Prefix_"},
747766
{class = "chisel3.ModulePrefixAnnotation", target = "~Test|Test>seq", prefix = "Prefix_"},

test/Dialect/FIRRTL/lower-types.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1405,6 +1405,45 @@ firrtl.circuit "UnrealizedConversion" {
14051405
}
14061406
}
14071407

1408+
firrtl.circuit "Conventions1" {
1409+
// COMMON-LABEL: @Conventions1
1410+
// AGGREGATE-SAME: %input_0
1411+
// AGGREGATE-NEXT: firrtl.reg
1412+
// AGGREGATE-SAME: !firrtl.vector<uint<8>, 1>
1413+
firrtl.module public @Conventions1(in %input: !firrtl.vector<uint<8>, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector<uint<8>, 1>) attributes {convention = #firrtl<convention scalarized>, body_type_lowering = #firrtl<convention internal>}{
1414+
%r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector<uint<8>, 1>
1415+
firrtl.matchingconnect %r, %input : !firrtl.vector<uint<8>, 1>
1416+
firrtl.matchingconnect %port, %r : !firrtl.vector<uint<8>, 1>
1417+
}
1418+
// COMMON-LABEL: @Conventions2
1419+
// AGGREGATE-SAME: %input_0: !firrtl.uint<8>
1420+
// AGGREGATE-NEXT: firrtl.reg
1421+
// AGGREGATE-SAME: !firrtl.uint<8>
1422+
firrtl.module private @Conventions2(in %input: !firrtl.vector<uint<8>, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector<uint<8>, 1>) attributes {convention = #firrtl<convention scalarized>, body_type_lowering = #firrtl<convention scalarized>}{
1423+
%r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector<uint<8>, 1>
1424+
firrtl.matchingconnect %r, %input : !firrtl.vector<uint<8>, 1>
1425+
firrtl.matchingconnect %port, %r : !firrtl.vector<uint<8>, 1>
1426+
}
1427+
// COMMON-LABEL: @Conventions3
1428+
// AGGREGATE-SAME: %input: !firrtl.vector<uint<8>, 1>
1429+
// AGGREGATE-NEXT: firrtl.reg
1430+
// AGGREGATE-SAME: !firrtl.vector<uint<8>, 1>
1431+
firrtl.module private @Conventions3(in %input: !firrtl.vector<uint<8>, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector<uint<8>, 1>) attributes {convention = #firrtl<convention internal>, body_type_lowering = #firrtl<convention internal>}{
1432+
%r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector<uint<8>, 1>
1433+
firrtl.matchingconnect %r, %input : !firrtl.vector<uint<8>, 1>
1434+
firrtl.matchingconnect %port, %r : !firrtl.vector<uint<8>, 1>
1435+
}
1436+
// COMMON-LABEL: @Conventions4
1437+
// AGGREGATE-SAME: %input: !firrtl.vector<uint<8>, 1>
1438+
// AGGREGATE-NEXT: firrtl.reg
1439+
// AGGREGATE-SAME: !firrtl.uint<8>
1440+
firrtl.module private @Conventions4(in %input: !firrtl.vector<uint<8>, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector<uint<8>, 1>) attributes {convention = #firrtl<convention internal>, body_type_lowering = #firrtl<convention scalarized>}{
1441+
%r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector<uint<8>, 1>
1442+
firrtl.matchingconnect %r, %input : !firrtl.vector<uint<8>, 1>
1443+
firrtl.matchingconnect %port, %r : !firrtl.vector<uint<8>, 1>
1444+
}
1445+
}
1446+
14081447
// Test that memories have their prefixes copied when lowering.
14091448
// See: https://github.com/llvm/circt/issues/7835
14101449
firrtl.circuit "MemoryPrefixCopying" {

0 commit comments

Comments
 (0)