Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIRRTL] Add a new FIRRTL annotation to specify type lowering behavior of module body #7751

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions docs/Dialects/FIRRTL/FIRRTLAnnotations.md
Original file line number Diff line number Diff line change
Expand Up @@ -323,11 +323,12 @@ Example:

### Convention

| Property | Type | Description |
| ---------- | ------ | --------------------------------------- |
| class | string | `circt.ConventionAnnotation` |
| convention | string | `scalarized` |
| target | string | Reference target |
| Property | Type | Description |
| ---------------- | ------ | ---------------------------------------------------- |
| class | string | `circt.ConventionAnnotation` |
| convention | string | `scalarized` |
| target | string | Reference target |
| includeHierarchy | bool | Apply the convention to all modules in the hierarchy |

Specify the port convention for a module. The port convention controls how a
module's ports are transformed, and how that module can be instantiated, in the
Expand All @@ -337,11 +338,41 @@ The options are:
- `scalarized`: Convert aggregate ports (i.e. vector or bundles) into multiple
ground-typed ports.

`includeHierarchy` is optional and defaults to `false`, meaning that the
convention is applied only to the specified module. If `includeHierarchy` is
`true`, the convention is applied to all modules in the hierarchy. If there are
multiple annotation instances that specify conventions, the `scalarized` convention
takes precedence over the `internal` convention.

```json
{
"class": "circt.ConventionAnnotation",
"convention": "scalarized",
"target": "~Foo|Bar/d:Baz"
"target": "~Foo|Bar",
"includeHierarchy": true
}
```

### BodyTypeLoweringAnnotation

| Property | Type | Description |
| ---------------- | ------ | ---------------------------------- |
| class | string | `circt.BodyTypeLoweringAnnotation` |
| convention | string | See `Convention` annotation |
| target | string | See `Convention` annotation |
| includeHierarchy | bool | See `Convention` annotation |

Specify the type lowering option for module internal signals.
This is similar to the `Convention` annotation, but for internal signals
rather than module ports. Refer to the `Convention` annotation for each
property description.

```json
{
"class": "circt.BodyTypeLoweringAnnotation",
"convention": "scalarized",
"target": "~Foo|Bar",
"includeHierarchy": true
}
```

Expand Down
2 changes: 2 additions & 0 deletions include/circt/Dialect/FIRRTL/AnnotationDetails.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ constexpr const char *rawAnnotations = "rawAnnotations";
//===----------------------------------------------------------------------===//

constexpr const char *conventionAnnoClass = "circt.ConventionAnnotation";
constexpr const char *typeLoweringAnnoClass =
"circt.BodyTypeLoweringAnnotation";
constexpr const char *dontTouchAnnoClass =
"firrtl.transforms.DontTouchAnnotation";
constexpr const char *enumComponentAnnoClass =
Expand Down
48 changes: 41 additions & 7 deletions lib/Dialect/FIRRTL/Transforms/LowerAnnotations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,14 +275,17 @@ static std::optional<Convention> parseConvention(llvm::StringRef str) {
.Default(std::nullopt);
}

static LogicalResult applyConventionAnno(const AnnoPathValue &target,
DictionaryAttr anno,
ApplyState &state) {
template <bool IsConventionAnno>
static LogicalResult
applyConventionOrTypeLoweringAnno(const AnnoPathValue &target,
DictionaryAttr anno, ApplyState &state) {
auto *op = target.ref.getOp();
auto loc = op->getLoc();
auto error = [&]() {
auto diag = mlir::emitError(loc);
diag << "circuit.ConventionAnnotation ";
diag << (IsConventionAnno ? "circuit.ConventionAnnotation "
: "circuit.TypeLoweringAnnotation ")
<< " ";
return diag;
};

Expand All @@ -305,13 +308,41 @@ static LogicalResult applyConventionAnno(const AnnoPathValue &target,

auto convention = *conventionOpt;

if (convention == Convention::Internal)
// Convention is internal by default so there is nothing to change
return success();

auto includeHierarchy = anno.getAs<BoolAttr>("includeHierarchy");
auto conventionAttr = ConventionAttr::get(op->getContext(), convention);
auto setConvention = [&](Operation *moduleOp) {
TypeSwitch<Operation *>(moduleOp)
.Case<FModuleOp, FExtModuleOp>([&](auto moduleOp) {
if (IsConventionAnno)
moduleOp.setConventionAttr(conventionAttr);
else
moduleOp->setDiscardableAttr("body_type_lowering", conventionAttr);
})
.Default([](auto) {});
};

if (auto moduleOp = dyn_cast<FModuleOp>(op)) {
moduleOp.setConvention(convention);
if (includeHierarchy && includeHierarchy.getValue()) {
// If includeHierarchy is true, update the convention for all modules in
// the hierarchy.
for (auto *node :
llvm::post_order(state.instancePathCache.instanceGraph[moduleOp])) {
if (node && isa<FModuleOp, FExtModuleOp>(*node->getModule()))
setConvention(node->getModule());
}
} else {
// Update the convention.
setConvention(moduleOp);
}
return success();
}

if (auto extModuleOp = dyn_cast<FExtModuleOp>(op)) {
extModuleOp.setConvention(convention);
setConvention(extModuleOp);
return success();
}

Expand Down Expand Up @@ -563,7 +594,10 @@ static llvm::StringMap<AnnoRecord> annotationRecords{{
{omirTrackerAnnoClass, {stdResolve, applyWithoutTarget<true>}},
{omirFileAnnoClass, NoTargetAnnotation},
// Miscellaneous Annotations
{conventionAnnoClass, {stdResolve, applyConventionAnno}},
{conventionAnnoClass,
{stdResolve, applyConventionOrTypeLoweringAnno<true>}},
{typeLoweringAnnoClass,
{stdResolve, applyConventionOrTypeLoweringAnno<false>}},
{dontTouchAnnoClass,
{stdResolve, applyWithoutTarget<true, true, WireOp, NodeOp, RegOp,
RegResetOp, InstanceOp, MemOp, CombMemOp,
Expand Down
42 changes: 27 additions & 15 deletions lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,12 +339,17 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {

TypeLoweringVisitor(
MLIRContext *context, PreserveAggregate::PreserveMode preserveAggregate,
Convention bodyConvention,
PreserveAggregate::PreserveMode memoryPreservationMode,
SymbolTable &symTbl, const AttrCache &cache,
const llvm::DenseMap<FModuleLike, Convention> &conventionTable)
: context(context), aggregatePreservationMode(preserveAggregate),
: context(context), defaultAggregatePreservationMode(preserveAggregate),
memoryPreservationMode(memoryPreservationMode), symTbl(symTbl),
cache(cache), conventionTable(conventionTable) {}
cache(cache), conventionTable(conventionTable) {
bodyAggregatePreservationMode = bodyConvention == Convention::Scalarized
? PreserveAggregate::None
: defaultAggregatePreservationMode;
}
using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitDecl;
using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitExpr;
using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitStmt;
Expand Down Expand Up @@ -429,7 +434,7 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
Location errorLoc);

PreserveAggregate::PreserveMode
getPreservationModeForModule(FModuleLike moduleLike);
getPreservationModeForPorts(FModuleLike moduleLike);
Value getSubWhatever(Value val, size_t index);

size_t uniqueIdx = 0;
Expand All @@ -441,7 +446,8 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
MLIRContext *context;

/// Aggregate preservation mode.
PreserveAggregate::PreserveMode aggregatePreservationMode;
PreserveAggregate::PreserveMode defaultAggregatePreservationMode;
PreserveAggregate::PreserveMode bodyAggregatePreservationMode;
PreserveAggregate::PreserveMode memoryPreservationMode;

/// The builder is set and maintained in the main loop.
Expand All @@ -460,21 +466,21 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
};
} // namespace

/// Return aggregate preservation mode for the module. If the module has a
/// Return aggregate preservation mode for the module ports. If the module has a
/// scalarized linkage, then we may not preserve it's aggregate ports.
PreserveAggregate::PreserveMode
TypeLoweringVisitor::getPreservationModeForModule(FModuleLike module) {
TypeLoweringVisitor::getPreservationModeForPorts(FModuleLike module) {
auto lookup = conventionTable.find(module);
if (lookup == conventionTable.end())
return aggregatePreservationMode;
return defaultAggregatePreservationMode;
switch (lookup->second) {
case Convention::Scalarized:
return PreserveAggregate::None;
case Convention::Internal:
return aggregatePreservationMode;
return defaultAggregatePreservationMode;
}
llvm_unreachable("Unknown convention");
return aggregatePreservationMode;
return defaultAggregatePreservationMode;
}

Value TypeLoweringVisitor::getSubWhatever(Value val, size_t index) {
Expand Down Expand Up @@ -643,7 +649,7 @@ bool TypeLoweringVisitor::lowerProducer(
return false;
SmallVector<FlatBundleFieldEntry, 8> fieldTypes;

if (!peelType(srcFType, fieldTypes, aggregatePreservationMode))
if (!peelType(srcFType, fieldTypes, bodyAggregatePreservationMode))
return false;

SmallVector<Value> lowered;
Expand Down Expand Up @@ -809,7 +815,7 @@ bool TypeLoweringVisitor::lowerArg(FModuleLike module, size_t argIndex,
// Flatten any bundle types.
SmallVector<FlatBundleFieldEntry> fieldTypes;
auto srcType = type_cast<FIRRTLType>(newArgs[argIndex].pi.type);
if (!peelType(srcType, fieldTypes, getPreservationModeForModule(module)))
if (!peelType(srcType, fieldTypes, getPreservationModeForPorts(module)))
return false;

// Ports with internalPath set cannot be lowered.
Expand Down Expand Up @@ -929,7 +935,7 @@ bool TypeLoweringVisitor::visitStmt(RefDefineOp op) {
// Attempt to get the bundle types.
SmallVector<FlatBundleFieldEntry> fields;

if (!peelType(op.getDest().getType(), fields, aggregatePreservationMode))
if (!peelType(op.getDest().getType(), fields, bodyAggregatePreservationMode))
return false;

// Loop over the leaf aggregates.
Expand Down Expand Up @@ -1454,7 +1460,7 @@ bool TypeLoweringVisitor::visitDecl(InstanceOp op) {
SmallVector<Direction> newDirs;
SmallVector<Attribute> newNames;
SmallVector<Attribute> newPortAnno;
PreserveAggregate::PreserveMode mode = getPreservationModeForModule(
PreserveAggregate::PreserveMode mode = getPreservationModeForPorts(
cast<FModuleLike>(op.getReferencedOperation(symTbl)));

endFields.push_back(0);
Expand Down Expand Up @@ -1667,9 +1673,15 @@ void LowerTypesPass::runOnOperation() {

// This lambda, executes in parallel for each Op within the circt.
auto lowerModules = [&](FModuleLike op) -> LogicalResult {
// Use body type lowering attribute if it exists, otherwise use internal.
Convention convention = Convention::Internal;
if (auto conventionAttr = dyn_cast_or_null<ConventionAttr>(
op->getDiscardableAttr("body_type_lowering")))
convention = conventionAttr.getValue();

auto tl =
TypeLoweringVisitor(&getContext(), preserveAggregate, preserveMemories,
symTbl, cache, conventionTable);
TypeLoweringVisitor(&getContext(), preserveAggregate, convention,
preserveMemories, symTbl, cache, conventionTable);
tl.lowerModule(op);

return LogicalResult::failure(tl.isFailed());
Expand Down
23 changes: 21 additions & 2 deletions test/Dialect/FIRRTL/annotations.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -734,14 +734,33 @@ firrtl.circuit "Test" attributes {rawAnnotations = [
// -----

firrtl.circuit "Test" attributes {rawAnnotations =[
{class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized"}
{class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized"},
{class = "circt.BodyTypeLoweringAnnotation", target = "~Test|Test", convention = "scalarized"}
]} {
// CHECK: attributes {convention = #firrtl<convention scalarized>}
// CHECK: attributes {body_type_lowering = #firrtl<convention scalarized>, convention = #firrtl<convention scalarized>}
firrtl.module @Test() attributes {convention = #firrtl<convention internal>} {}
}

// -----

firrtl.circuit "Test" attributes {rawAnnotations = [
{class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized", includeHierarchy = true},
{class = "circt.BodyTypeLoweringAnnotation", target = "~Test|Test", convention = "scalarized", includeHierarchy = true}
]} {
// CHECK: @Test() attributes {body_type_lowering = #firrtl<convention scalarized>, convention = #firrtl<convention scalarized>}
firrtl.module @Test() attributes {convention = #firrtl<convention internal>} {
firrtl.instance child @Child()
}

// CHECK: @Child() attributes {body_type_lowering = #firrtl<convention scalarized>, convention = #firrtl<convention scalarized>}
firrtl.module @Child() attributes {convention = #firrtl<convention internal>} {}

// CHECK: @Child2() {
firrtl.module @Child2() attributes {convention = #firrtl<convention internal>} {}
}

// -----

firrtl.circuit "Test" attributes {rawAnnotations =[
{class = "chisel3.ModulePrefixAnnotation", target = "~Test|Test>comb", prefix = "Prefix_"},
{class = "chisel3.ModulePrefixAnnotation", target = "~Test|Test>seq", prefix = "Prefix_"},
Expand Down
39 changes: 39 additions & 0 deletions test/Dialect/FIRRTL/lower-types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1404,3 +1404,42 @@ firrtl.circuit "UnrealizedConversion" {
firrtl.matchingconnect %w, %b : !firrtl.bundle<data: uint<64>, tag: uint<1>>
}
}

firrtl.circuit "Conventions1" {
// COMMON-LABEL: @Conventions1
// AGGREGATE-SAME: %input_0
// AGGREGATE-NEXT: firrtl.reg
// AGGREGATE-SAME: !firrtl.vector<uint<8>, 1>
firrtl.module public @Conventions1(in %input: !firrtl.vector<uint<8>, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector<uint<8>, 1>) attributes {convention = #firrtl<convention scalarized>, body_type_lowering = #firrtl<convention internal>}{
%r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %r, %input : !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %port, %r : !firrtl.vector<uint<8>, 1>
}
// COMMON-LABEL: @Conventions2
// AGGREGATE-SAME: %input_0: !firrtl.uint<8>
// AGGREGATE-NEXT: firrtl.reg
// AGGREGATE-SAME: !firrtl.uint<8>
firrtl.module private @Conventions2(in %input: !firrtl.vector<uint<8>, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector<uint<8>, 1>) attributes {convention = #firrtl<convention scalarized>, body_type_lowering = #firrtl<convention scalarized>}{
%r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %r, %input : !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %port, %r : !firrtl.vector<uint<8>, 1>
}
// COMMON-LABEL: @Conventions3
// AGGREGATE-SAME: %input: !firrtl.vector<uint<8>, 1>
// AGGREGATE-NEXT: firrtl.reg
// AGGREGATE-SAME: !firrtl.vector<uint<8>, 1>
firrtl.module private @Conventions3(in %input: !firrtl.vector<uint<8>, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector<uint<8>, 1>) attributes {convention = #firrtl<convention internal>, body_type_lowering = #firrtl<convention internal>}{
%r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %r, %input : !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %port, %r : !firrtl.vector<uint<8>, 1>
}
// COMMON-LABEL: @Conventions4
// AGGREGATE-SAME: %input: !firrtl.vector<uint<8>, 1>
// AGGREGATE-NEXT: firrtl.reg
// AGGREGATE-SAME: !firrtl.uint<8>
firrtl.module private @Conventions4(in %input: !firrtl.vector<uint<8>, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector<uint<8>, 1>) attributes {convention = #firrtl<convention internal>, body_type_lowering = #firrtl<convention scalarized>}{
%r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %r, %input : !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %port, %r : !firrtl.vector<uint<8>, 1>
}
}
Loading