Skip to content

Commit 96b6abd

Browse files
committed
[FIRRTL] Add a new FIRRTL annotation to specify type lowering behavior of module body
This allows more fine-grained control over how types are lowered in different contexts. This also adds an "includeHierarchy" option to Convention annotations that allows applying the convention to all modules in the hierarchy below the annotated module.
1 parent 823c948 commit 96b6abd

File tree

7 files changed

+161
-31
lines changed

7 files changed

+161
-31
lines changed

docs/Dialects/FIRRTL/FIRRTLAnnotations.md

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -323,11 +323,12 @@ Example:
323323

324324
### Convention
325325

326-
| Property | Type | Description |
327-
| ---------- | ------ | --------------------------------------- |
328-
| class | string | `circt.ConventionAnnotation` |
329-
| convention | string | `scalarized` |
330-
| target | string | Reference target |
326+
| Property | Type | Description |
327+
| ---------------- | ------ | ---------------------------------------------------- |
328+
| class | string | `circt.ConventionAnnotation` |
329+
| convention | string | `scalarized` |
330+
| target | string | Reference target |
331+
| includeHierarchy | bool | Apply the convention to all modules in the hierarchy |
331332

332333
Specify the port convention for a module. The port convention controls how a
333334
module's ports are transformed, and how that module can be instantiated, in the
@@ -341,7 +342,31 @@ The options are:
341342
{
342343
"class": "circt.ConventionAnnotation",
343344
"convention": "scalarized",
344-
"target": "~Foo|Bar/d:Baz"
345+
"target": "~Foo|Bar",
346+
"includeHierarchy": true
347+
}
348+
```
349+
350+
### BodyTypeLoweringAnnotation
351+
352+
| Property | Type | Description |
353+
| ---------------- | ------ | ---------------------------------- |
354+
| class | string | `circt.BodyTypeLoweringAnnotation` |
355+
| convention | string | See `Convention` annotation |
356+
| target | string | See `Convention` annotation |
357+
| includeHierarchy | bool | See `Convention` annotation |
358+
359+
Specify the type lowering option for module internal signals.
360+
This is similar to the `Convention` annotation, but for internal signals
361+
rather than module ports. Refer to the `Convention` annotation for each
362+
property description.
363+
364+
```json
365+
{
366+
"class": "circt.BodyTypeLoweringAnnotation",
367+
"convention": "scalarized",
368+
"target": "~Foo|Bar",
369+
"includeHierarchy": true
345370
}
346371
```
347372

include/circt/Dialect/FIRRTL/AnnotationDetails.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ constexpr const char *rawAnnotations = "rawAnnotations";
2929
//===----------------------------------------------------------------------===//
3030

3131
constexpr const char *conventionAnnoClass = "circt.ConventionAnnotation";
32+
constexpr const char *typeLoweringAnnoClass = "circt.BodyTypeLoweringAnnotation";
3233
constexpr const char *dontTouchAnnoClass =
3334
"firrtl.transforms.DontTouchAnnotation";
3435
constexpr const char *enumComponentAnnoClass =

lib/Dialect/FIRRTL/Transforms/LowerAnnotations.cpp

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -275,14 +275,17 @@ static std::optional<Convention> parseConvention(llvm::StringRef str) {
275275
.Default(std::nullopt);
276276
}
277277

278-
static LogicalResult applyConventionAnno(const AnnoPathValue &target,
279-
DictionaryAttr anno,
280-
ApplyState &state) {
278+
template <bool IsConventionAnno>
279+
static LogicalResult
280+
applyConventionOrTypeLoweringAnno(const AnnoPathValue &target,
281+
DictionaryAttr anno, ApplyState &state) {
281282
auto *op = target.ref.getOp();
282283
auto loc = op->getLoc();
283284
auto error = [&]() {
284285
auto diag = mlir::emitError(loc);
285-
diag << "circuit.ConventionAnnotation ";
286+
diag << (IsConventionAnno ? "circuit.ConventionAnnotation "
287+
: "circuit.TypeLoweringAnnotation ")
288+
<< " ";
286289
return diag;
287290
};
288291

@@ -305,13 +308,41 @@ static LogicalResult applyConventionAnno(const AnnoPathValue &target,
305308

306309
auto convention = *conventionOpt;
307310

311+
if (convention == Convention::Internal)
312+
// Convention is internal by default so there is nothing to change
313+
return success();
314+
315+
auto includeHierarchy = anno.getAs<BoolAttr>("includeHierarchy");
316+
auto conventionAttr = ConventionAttr::get(op->getContext(), convention);
317+
auto setConvention = [&](Operation *moduleOp) {
318+
TypeSwitch<Operation *>(moduleOp)
319+
.Case<FModuleOp, FExtModuleOp>([&](auto moduleOp) {
320+
if (IsConventionAnno)
321+
moduleOp.setConventionAttr(conventionAttr);
322+
else
323+
moduleOp->setDiscardableAttr("body_type_lowering", conventionAttr);
324+
})
325+
.Default([](auto) {});
326+
};
327+
308328
if (auto moduleOp = dyn_cast<FModuleOp>(op)) {
309-
moduleOp.setConvention(convention);
329+
if (includeHierarchy && includeHierarchy.getValue()) {
330+
// If includeHierarchy is true, update the convention for all modules in
331+
// the hierarchy.
332+
for (auto *node :
333+
llvm::post_order(state.instancePathCache.instanceGraph[moduleOp])) {
334+
if (node && isa<FModuleOp, FExtModuleOp>(*node->getModule()))
335+
setConvention(node->getModule());
336+
}
337+
} else {
338+
// Update the convention.
339+
setConvention(moduleOp);
340+
}
310341
return success();
311342
}
312343

313344
if (auto extModuleOp = dyn_cast<FExtModuleOp>(op)) {
314-
extModuleOp.setConvention(convention);
345+
setConvention(extModuleOp);
315346
return success();
316347
}
317348

@@ -563,7 +594,10 @@ static llvm::StringMap<AnnoRecord> annotationRecords{{
563594
{omirTrackerAnnoClass, {stdResolve, applyWithoutTarget<true>}},
564595
{omirFileAnnoClass, NoTargetAnnotation},
565596
// Miscellaneous Annotations
566-
{conventionAnnoClass, {stdResolve, applyConventionAnno}},
597+
{conventionAnnoClass,
598+
{stdResolve, applyConventionOrTypeLoweringAnno<true>}},
599+
{typeLoweringAnnoClass,
600+
{stdResolve, applyConventionOrTypeLoweringAnno<false>}},
567601
{dontTouchAnnoClass,
568602
{stdResolve, applyWithoutTarget<true, true, WireOp, NodeOp, RegOp,
569603
RegResetOp, InstanceOp, MemOp, CombMemOp,

lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -339,12 +339,17 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
339339

340340
TypeLoweringVisitor(
341341
MLIRContext *context, PreserveAggregate::PreserveMode preserveAggregate,
342+
Convention bodyConvention,
342343
PreserveAggregate::PreserveMode memoryPreservationMode,
343344
SymbolTable &symTbl, const AttrCache &cache,
344345
const llvm::DenseMap<FModuleLike, Convention> &conventionTable)
345-
: context(context), aggregatePreservationMode(preserveAggregate),
346+
: context(context), defaultAggregatePreservationMode(preserveAggregate),
346347
memoryPreservationMode(memoryPreservationMode), symTbl(symTbl),
347-
cache(cache), conventionTable(conventionTable) {}
348+
cache(cache), conventionTable(conventionTable) {
349+
bodyAggregatePreservationMode = bodyConvention == Convention::Scalarized
350+
? PreserveAggregate::None
351+
: defaultAggregatePreservationMode;
352+
}
348353
using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitDecl;
349354
using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitExpr;
350355
using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitStmt;
@@ -429,7 +434,7 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
429434
Location errorLoc);
430435

431436
PreserveAggregate::PreserveMode
432-
getPreservationModeForModule(FModuleLike moduleLike);
437+
getPreservationModeForPorts(FModuleLike moduleLike);
433438
Value getSubWhatever(Value val, size_t index);
434439

435440
size_t uniqueIdx = 0;
@@ -441,7 +446,8 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
441446
MLIRContext *context;
442447

443448
/// Aggregate preservation mode.
444-
PreserveAggregate::PreserveMode aggregatePreservationMode;
449+
PreserveAggregate::PreserveMode defaultAggregatePreservationMode;
450+
PreserveAggregate::PreserveMode bodyAggregatePreservationMode;
445451
PreserveAggregate::PreserveMode memoryPreservationMode;
446452

447453
/// The builder is set and maintained in the main loop.
@@ -460,21 +466,21 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
460466
};
461467
} // namespace
462468

463-
/// Return aggregate preservation mode for the module. If the module has a
469+
/// Return aggregate preservation mode for the module ports. If the module has a
464470
/// scalarized linkage, then we may not preserve it's aggregate ports.
465471
PreserveAggregate::PreserveMode
466-
TypeLoweringVisitor::getPreservationModeForModule(FModuleLike module) {
472+
TypeLoweringVisitor::getPreservationModeForPorts(FModuleLike module) {
467473
auto lookup = conventionTable.find(module);
468474
if (lookup == conventionTable.end())
469-
return aggregatePreservationMode;
475+
return defaultAggregatePreservationMode;
470476
switch (lookup->second) {
471477
case Convention::Scalarized:
472478
return PreserveAggregate::None;
473479
case Convention::Internal:
474-
return aggregatePreservationMode;
480+
return defaultAggregatePreservationMode;
475481
}
476482
llvm_unreachable("Unknown convention");
477-
return aggregatePreservationMode;
483+
return defaultAggregatePreservationMode;
478484
}
479485

480486
Value TypeLoweringVisitor::getSubWhatever(Value val, size_t index) {
@@ -643,7 +649,7 @@ bool TypeLoweringVisitor::lowerProducer(
643649
return false;
644650
SmallVector<FlatBundleFieldEntry, 8> fieldTypes;
645651

646-
if (!peelType(srcFType, fieldTypes, aggregatePreservationMode))
652+
if (!peelType(srcFType, fieldTypes, bodyAggregatePreservationMode))
647653
return false;
648654

649655
SmallVector<Value> lowered;
@@ -809,7 +815,7 @@ bool TypeLoweringVisitor::lowerArg(FModuleLike module, size_t argIndex,
809815
// Flatten any bundle types.
810816
SmallVector<FlatBundleFieldEntry> fieldTypes;
811817
auto srcType = type_cast<FIRRTLType>(newArgs[argIndex].pi.type);
812-
if (!peelType(srcType, fieldTypes, getPreservationModeForModule(module)))
818+
if (!peelType(srcType, fieldTypes, getPreservationModeForPorts(module)))
813819
return false;
814820

815821
// Ports with internalPath set cannot be lowered.
@@ -929,7 +935,7 @@ bool TypeLoweringVisitor::visitStmt(RefDefineOp op) {
929935
// Attempt to get the bundle types.
930936
SmallVector<FlatBundleFieldEntry> fields;
931937

932-
if (!peelType(op.getDest().getType(), fields, aggregatePreservationMode))
938+
if (!peelType(op.getDest().getType(), fields, bodyAggregatePreservationMode))
933939
return false;
934940

935941
// Loop over the leaf aggregates.
@@ -1454,7 +1460,7 @@ bool TypeLoweringVisitor::visitDecl(InstanceOp op) {
14541460
SmallVector<Direction> newDirs;
14551461
SmallVector<Attribute> newNames;
14561462
SmallVector<Attribute> newPortAnno;
1457-
PreserveAggregate::PreserveMode mode = getPreservationModeForModule(
1463+
PreserveAggregate::PreserveMode mode = getPreservationModeForPorts(
14581464
cast<FModuleLike>(op.getReferencedOperation(symTbl)));
14591465

14601466
endFields.push_back(0);
@@ -1667,9 +1673,15 @@ void LowerTypesPass::runOnOperation() {
16671673

16681674
// This lambda, executes in parallel for each Op within the circt.
16691675
auto lowerModules = [&](FModuleLike op) -> LogicalResult {
1676+
// Use body type lowering attribute if it exists, otherwise use internal.
1677+
Convention convention = Convention::Internal;
1678+
if (auto conventionAttr = dyn_cast_or_null<ConventionAttr>(
1679+
op->getDiscardableAttr("body_type_lowering")))
1680+
convention = conventionAttr.getValue();
1681+
16701682
auto tl =
1671-
TypeLoweringVisitor(&getContext(), preserveAggregate, preserveMemories,
1672-
symTbl, cache, conventionTable);
1683+
TypeLoweringVisitor(&getContext(), preserveAggregate, convention,
1684+
preserveMemories, symTbl, cache, conventionTable);
16731685
tl.lowerModule(op);
16741686

16751687
return LogicalResult::failure(tl.isFailed());

llvm

Submodule llvm updated 5164 files

test/Dialect/FIRRTL/annotations.mlir

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -734,14 +734,33 @@ firrtl.circuit "Test" attributes {rawAnnotations = [
734734
// -----
735735

736736
firrtl.circuit "Test" attributes {rawAnnotations =[
737-
{class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized"}
737+
{class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized"},
738+
{class = "circt.BodyTypeLoweringAnnotation", target = "~Test|Test", convention = "scalarized"}
738739
]} {
739-
// CHECK: attributes {convention = #firrtl<convention scalarized>}
740+
// CHECK: attributes {body_type_lowering = #firrtl<convention scalarized>, convention = #firrtl<convention scalarized>}
740741
firrtl.module @Test() attributes {convention = #firrtl<convention internal>} {}
741742
}
742743

743744
// -----
744745

746+
firrtl.circuit "Test" attributes {rawAnnotations = [
747+
{class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized", includeHierarchy = true},
748+
{class = "circt.BodyTypeLoweringAnnotation", target = "~Test|Test", convention = "scalarized", includeHierarchy = true}
749+
]} {
750+
// CHECK: @Test() attributes {body_type_lowering = #firrtl<convention scalarized>, convention = #firrtl<convention scalarized>}
751+
firrtl.module @Test() attributes {convention = #firrtl<convention internal>} {
752+
firrtl.instance child @Child()
753+
}
754+
755+
// CHECK: @Child() attributes {body_type_lowering = #firrtl<convention scalarized>, convention = #firrtl<convention scalarized>}
756+
firrtl.module @Child() attributes {convention = #firrtl<convention internal>} {}
757+
758+
// CHECK: @Child2() {
759+
firrtl.module @Child2() attributes {convention = #firrtl<convention internal>} {}
760+
}
761+
762+
// -----
763+
745764
firrtl.circuit "Test" attributes {rawAnnotations =[
746765
{class = "chisel3.ModulePrefixAnnotation", target = "~Test|Test>comb", prefix = "Prefix_"},
747766
{class = "chisel3.ModulePrefixAnnotation", target = "~Test|Test>seq", prefix = "Prefix_"},

test/Dialect/FIRRTL/lower-types.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1404,3 +1404,42 @@ firrtl.circuit "UnrealizedConversion" {
14041404
firrtl.matchingconnect %w, %b : !firrtl.bundle<data: uint<64>, tag: uint<1>>
14051405
}
14061406
}
1407+
1408+
firrtl.circuit "Conventions1" {
1409+
// COMMON-LABEL: @Conventions1
1410+
// AGGREGATE-SAME: %input_0
1411+
// AGGREGATE-NEXT: firrtl.reg
1412+
// AGGREGATE-SAME: !firrtl.vector<uint<8>, 1>
1413+
firrtl.module public @Conventions1(in %input: !firrtl.vector<uint<8>, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector<uint<8>, 1>) attributes {convention = #firrtl<convention scalarized>, body_type_lowering = #firrtl<convention internal>}{
1414+
%r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector<uint<8>, 1>
1415+
firrtl.matchingconnect %r, %input : !firrtl.vector<uint<8>, 1>
1416+
firrtl.matchingconnect %port, %r : !firrtl.vector<uint<8>, 1>
1417+
}
1418+
// COMMON-LABEL: @Conventions2
1419+
// AGGREGATE-SAME: %input_0: !firrtl.uint<8>
1420+
// AGGREGATE-NEXT: firrtl.reg
1421+
// AGGREGATE-SAME: !firrtl.uint<8>
1422+
firrtl.module private @Conventions2(in %input: !firrtl.vector<uint<8>, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector<uint<8>, 1>) attributes {convention = #firrtl<convention scalarized>, body_type_lowering = #firrtl<convention scalarized>}{
1423+
%r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector<uint<8>, 1>
1424+
firrtl.matchingconnect %r, %input : !firrtl.vector<uint<8>, 1>
1425+
firrtl.matchingconnect %port, %r : !firrtl.vector<uint<8>, 1>
1426+
}
1427+
// COMMON-LABEL: @Conventions3
1428+
// AGGREGATE-SAME: %input: !firrtl.vector<uint<8>, 1>
1429+
// AGGREGATE-NEXT: firrtl.reg
1430+
// AGGREGATE-SAME: !firrtl.vector<uint<8>, 1>
1431+
firrtl.module private @Conventions3(in %input: !firrtl.vector<uint<8>, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector<uint<8>, 1>) attributes {convention = #firrtl<convention internal>, body_type_lowering = #firrtl<convention internal>}{
1432+
%r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector<uint<8>, 1>
1433+
firrtl.matchingconnect %r, %input : !firrtl.vector<uint<8>, 1>
1434+
firrtl.matchingconnect %port, %r : !firrtl.vector<uint<8>, 1>
1435+
}
1436+
// COMMON-LABEL: @Conventions4
1437+
// AGGREGATE-SAME: %input: !firrtl.vector<uint<8>, 1>
1438+
// AGGREGATE-NEXT: firrtl.reg
1439+
// AGGREGATE-SAME: !firrtl.uint<8>
1440+
firrtl.module private @Conventions4(in %input: !firrtl.vector<uint<8>, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector<uint<8>, 1>) attributes {convention = #firrtl<convention internal>, body_type_lowering = #firrtl<convention scalarized>}{
1441+
%r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector<uint<8>, 1>
1442+
firrtl.matchingconnect %r, %input : !firrtl.vector<uint<8>, 1>
1443+
firrtl.matchingconnect %port, %r : !firrtl.vector<uint<8>, 1>
1444+
}
1445+
}

0 commit comments

Comments
 (0)