Skip to content

Commit 2160b4d

Browse files
committed
[FIRRTL] Extend Convention to specify type lowering for body
Extend the FIRRTL Convention attribute to separately specify type lowering behavior for module ports and module body. This allows more fine-grained control over how types are lowered in different contexts. The Convention attribute now takes two parameters: - Port convention: Controls how module ports are lowered - Body convention: Controls how types within the module body are lowered Updates the syntax from: #firrtl<convention internal> to: #firrtl.convention<internal, internal> This change enables modules to have different type lowering strategies for their interfaces and their bodies. For example, a module could preserve aggregate types in its ports while scalarizing them in its body. Updates all relevant tests and code to use the new two-parameter Convention format.
1 parent 0713332 commit 2160b4d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+317
-172
lines changed

include/circt/Dialect/FIRRTL/FIRRTLEnums.td

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,13 @@ def Convention : I32EnumAttr<"Convention", "lowering convention", [
4444
let genSpecializedAttr = 0;
4545
}
4646

47-
def ConventionAttr : EnumAttr<FIRRTLDialect, Convention, "convention">;
47+
// Pair of port and body conventions
48+
def ConventionAttr : AttrDef<FIRRTLDialect, "Convention"> {
49+
let parameters = (ins "Convention":$portConvention, "Convention":$bodyConvention);
50+
let summary = "Pair of port and body conventions";
51+
let mnemonic = "convention";
52+
let assemblyFormat = "`<` $portConvention `,` $bodyConvention `>`";
53+
}
4854

4955
//===----------------------------------------------------------------------===//
5056
// Layer Lowering Conventions

include/circt/Dialect/FIRRTL/FIRRTLOpInterfaces.td

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,17 @@ def FModuleLike : OpInterface<"FModuleLike", [Symbol, PortList, InstanceGraphMod
3232
//===------------------------------------------------------------------===//
3333

3434
InterfaceMethod<"Get the module's instantiation convention",
35-
"ConventionAttr", "getConventionAttr">,
35+
"ConventionAttr", "getConventionAttr", (ins), [{}],
36+
/*defaultImplementation=*/[{
37+
return ConventionAttr::get($_op.getContext(), $_op.getPortConvention(),
38+
$_op.getBodyConvention());
39+
}]>,
3640

37-
InterfaceMethod<"Get the module's instantiation convention",
38-
"Convention", "getConvention">,
41+
InterfaceMethod<"Get the module's port convention",
42+
"Convention", "getPortConvention">,
43+
44+
InterfaceMethod<"Get the module's body convention",
45+
"Convention", "getBodyConvention">,
3946

4047
//===------------------------------------------------------------------===//
4148
// Enabled (AKA Required) Layers

include/circt/Dialect/FIRRTL/FIRRTLStructure.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ class FIRRTLModuleLike<string mnemonic, list<Trait> traits = []> :
126126
}
127127

128128

129-
def FModuleOp : FIRRTLModuleLike<"module", [SingleBlock, NoTerminator]> {
129+
def FModuleOp : FIRRTLModuleLike<"module", [SingleBlock, NoTerminator, DeclareOpInterfaceMethods<FModuleLike>]> {
130130
let summary = "FIRRTL Module";
131131
let description = [{
132132
The "firrtl.module" operation represents a Verilog module, including a given

lib/CAPI/Dialect/FIRRTL.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -174,19 +174,19 @@ MlirType firrtlTypeGetMaskType(MlirType type) {
174174
//===----------------------------------------------------------------------===//
175175

176176
MlirAttribute firrtlAttrGetConvention(MlirContext ctx,
177-
FIRRTLConvention convention) {
178-
Convention value;
179-
180-
switch (convention) {
181-
case FIRRTL_CONVENTION_INTERNAL:
182-
value = Convention::Internal;
183-
break;
184-
case FIRRTL_CONVENTION_SCALARIZED:
185-
value = Convention::Scalarized;
186-
break;
187-
}
188-
189-
return wrap(ConventionAttr::get(unwrap(ctx), value));
177+
FIRRTLConvention portConvention,
178+
FIRRTLConvention bodyConvention) {
179+
auto getConvention = [&](FIRRTLConvention convention) {
180+
switch (convention) {
181+
case FIRRTL_CONVENTION_INTERNAL:
182+
return Convention::Internal;
183+
case FIRRTL_CONVENTION_SCALARIZED:
184+
return Convention::Scalarized;
185+
}
186+
};
187+
188+
return wrap(ConventionAttr::get(unwrap(ctx), getConvention(portConvention),
189+
getConvention(bodyConvention)));
190190
}
191191

192192
MlirAttribute firrtlAttrGetPortDirs(MlirContext ctx, size_t count,

lib/Dialect/FIRRTL/FIRRTLOps.cpp

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1351,7 +1351,8 @@ static void printFModuleLikeOp(OpAsmPrinter &p, FModuleLike op) {
13511351
"sym_name", "portDirections", "portTypes", "portAnnotations",
13521352
"portSymbols", "portLocations", "parameters", visibilityAttrName};
13531353

1354-
if (op.getConvention() == Convention::Internal)
1354+
if (op.getPortConvention() == Convention::Internal &&
1355+
op.getBodyConvention() == Convention::Internal)
13551356
omittedAttrs.push_back("convention");
13561357

13571358
// We can omit the portNames if they were able to be printed as properly as
@@ -1535,9 +1536,10 @@ ParseResult FModuleOp::parse(OpAsmParser &parser, OperationState &result) {
15351536
if (parseFModuleLikeOp(parser, result, /*hasSSAIdentifiers=*/true))
15361537
return failure();
15371538
if (!result.attributes.get("convention"))
1538-
result.addAttribute(
1539-
"convention",
1540-
ConventionAttr::get(result.getContext(), Convention::Internal));
1539+
result.addAttribute("convention",
1540+
ConventionAttr::get(result.getContext(),
1541+
Convention::Internal,
1542+
Convention::Internal));
15411543
if (!result.attributes.get("layers"))
15421544
result.addAttribute("layers", ArrayAttr::get(parser.getContext(), {}));
15431545
return success();
@@ -1547,9 +1549,10 @@ ParseResult FExtModuleOp::parse(OpAsmParser &parser, OperationState &result) {
15471549
if (parseFModuleLikeOp(parser, result, /*hasSSAIdentifiers=*/false))
15481550
return failure();
15491551
if (!result.attributes.get("convention"))
1550-
result.addAttribute(
1551-
"convention",
1552-
ConventionAttr::get(result.getContext(), Convention::Internal));
1552+
result.addAttribute("convention",
1553+
ConventionAttr::get(result.getContext(),
1554+
Convention::Internal,
1555+
Convention::Internal));
15531556
return success();
15541557
}
15551558

@@ -1767,17 +1770,13 @@ ArrayAttr FMemModuleOp::getParameters() { return {}; }
17671770

17681771
ArrayAttr FModuleOp::getParameters() { return {}; }
17691772

1770-
Convention FIntModuleOp::getConvention() { return Convention::Internal; }
1773+
Convention FIntModuleOp::getPortConvention() { return Convention::Internal; }
17711774

1772-
ConventionAttr FIntModuleOp::getConventionAttr() {
1773-
return ConventionAttr::get(getContext(), getConvention());
1774-
}
1775+
Convention FIntModuleOp::getBodyConvention() { return Convention::Internal; }
17751776

1776-
Convention FMemModuleOp::getConvention() { return Convention::Internal; }
1777+
Convention FMemModuleOp::getPortConvention() { return Convention::Internal; }
17771778

1778-
ConventionAttr FMemModuleOp::getConventionAttr() {
1779-
return ConventionAttr::get(getContext(), getConvention());
1780-
}
1779+
Convention FMemModuleOp::getBodyConvention() { return Convention::Internal; }
17811780

17821781
//===----------------------------------------------------------------------===//
17831782
// ClassLike Helpers
@@ -2056,12 +2055,26 @@ void ClassOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
20562055
::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
20572056
}
20582057

2059-
Convention ClassOp::getConvention() { return Convention::Internal; }
2058+
Convention FModuleOp::getPortConvention() {
2059+
return getConventionAttr().getPortConvention();
2060+
}
2061+
2062+
Convention FModuleOp::getBodyConvention() {
2063+
return getConventionAttr().getBodyConvention();
2064+
}
2065+
2066+
Convention FExtModuleOp::getPortConvention() {
2067+
return getConventionAttr().getPortConvention();
2068+
}
20602069

2061-
ConventionAttr ClassOp::getConventionAttr() {
2062-
return ConventionAttr::get(getContext(), getConvention());
2070+
Convention FExtModuleOp::getBodyConvention() {
2071+
return getConventionAttr().getBodyConvention();
20632072
}
20642073

2074+
Convention ClassOp::getPortConvention() { return Convention::Internal; }
2075+
2076+
Convention ClassOp::getBodyConvention() { return Convention::Internal; }
2077+
20652078
ArrayAttr ClassOp::getParameters() { return {}; }
20662079

20672080
ArrayAttr ClassOp::getPortAnnotationsAttr() {
@@ -2142,11 +2155,9 @@ void ExtClassOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
21422155
::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
21432156
}
21442157

2145-
Convention ExtClassOp::getConvention() { return Convention::Internal; }
2158+
Convention ExtClassOp::getPortConvention() { return Convention::Internal; }
21462159

2147-
ConventionAttr ExtClassOp::getConventionAttr() {
2148-
return ConventionAttr::get(getContext(), getConvention());
2149-
}
2160+
Convention ExtClassOp::getBodyConvention() { return Convention::Internal; }
21502161

21512162
ArrayAttr ExtClassOp::getLayersAttr() {
21522163
return ArrayAttr::get(getContext(), {});

lib/Dialect/FIRRTL/Import/FIRParser.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5237,7 +5237,8 @@ ParseResult FIRCircuitParser::parseExtModule(CircuitOp circuit,
52375237
getConstants().options.scalarizeExtModules
52385238
? Convention::Scalarized
52395239
: Convention::Internal;
5240-
auto conventionAttr = ConventionAttr::get(getContext(), convention);
5240+
auto conventionAttr =
5241+
ConventionAttr::get(getContext(), convention, Convention::Internal);
52415242
auto annotations = ArrayAttr::get(getContext(), {});
52425243
auto extModuleOp = builder.create<FExtModuleOp>(
52435244
info.getLoc(), name, conventionAttr, portList, defName, annotations,
@@ -5325,7 +5326,9 @@ ParseResult FIRCircuitParser::parseModule(CircuitOp circuit, bool isPublic,
53255326
convention = Convention::Scalarized;
53265327
if (!isPublic && getConstants().options.scalarizeInternalModules)
53275328
convention = Convention::Scalarized;
5328-
auto conventionAttr = ConventionAttr::get(getContext(), convention);
5329+
// Use Internal as body convention.
5330+
auto conventionAttr =
5331+
ConventionAttr::get(getContext(), convention, Convention::Internal);
53295332
auto builder = circuit.getBodyBuilder();
53305333
auto moduleOp = builder.create<FModuleOp>(info.getLoc(), name, conventionAttr,
53315334
portList, annotations, layers);

lib/Dialect/FIRRTL/Transforms/ExtractInstances.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -983,7 +983,9 @@ void ExtractInstancesPass::groupInstances() {
983983
// Create the wrapper module.
984984
auto wrapper = builder.create<FModuleOp>(
985985
builder.getUnknownLoc(), wrapperModuleName,
986-
ConventionAttr::get(builder.getContext(), Convention::Internal), ports);
986+
ConventionAttr::get(builder.getContext(), Convention::Internal,
987+
Convention::Internal),
988+
ports);
987989
SymbolTable::setSymbolVisibility(wrapper, SymbolTable::Visibility::Private);
988990

989991
// Instantiate the wrapper module in the parent and replace uses of the

lib/Dialect/FIRRTL/Transforms/LowerAnnotations.cpp

Lines changed: 61 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ static LogicalResult applyDUTAnno(const AnnoPathValue &target,
272272
static std::optional<Convention> parseConvention(llvm::StringRef str) {
273273
return ::llvm::StringSwitch<::std::optional<Convention>>(str)
274274
.Case("scalarized", Convention::Scalarized)
275+
.Case("internal", Convention::Internal)
275276
.Default(std::nullopt);
276277
}
277278

@@ -293,25 +294,75 @@ static LogicalResult applyConventionAnno(const AnnoPathValue &target,
293294
if (!target.isLocal())
294295
return error() << "must be local";
295296

296-
auto conventionStrAttr =
297-
tryGetAs<StringAttr>(anno, anno, "convention", loc, conventionAnnoClass);
298-
if (!conventionStrAttr)
297+
auto getConventionAttr = [&](StringRef name) -> FailureOr<Convention> {
298+
auto conventionName =
299+
tryGetAs<StringAttr>(anno, anno, name, loc, conventionAnnoClass);
300+
if (!conventionName)
301+
return failure();
302+
303+
auto conventionOpt = parseConvention(conventionName.getValue());
304+
if (!conventionOpt)
305+
return error() << "unknown convention " << conventionName.getValue();
306+
307+
return *conventionOpt;
308+
};
309+
auto portConvention = getConventionAttr("portConvention");
310+
auto bodyConvention = getConventionAttr("bodyConvention");
311+
if (failed(portConvention) || failed(bodyConvention))
299312
return failure();
300313

301-
auto conventionStr = conventionStrAttr.getValue();
302-
auto conventionOpt = parseConvention(conventionStr);
303-
if (!conventionOpt)
304-
return error() << "unknown convention " << conventionStr;
314+
if (*portConvention == Convention::Internal &&
315+
*bodyConvention == Convention::Internal) {
316+
// Convention is internal by default so there is nothing to change
317+
return success();
318+
}
319+
320+
auto includeHierarchy = anno.getAs<BoolAttr>("includeHierarchy");
321+
auto convention =
322+
ConventionAttr::get(op->getContext(), *portConvention, *bodyConvention);
305323

306-
auto convention = *conventionOpt;
324+
bool isBothScalarized = *portConvention == Convention::Scalarized &&
325+
*bodyConvention == Convention::Scalarized;
307326

327+
auto setConvention = [&](FModuleOp fmodule) {
328+
if (isBothScalarized) // Fast path.
329+
{
330+
if (convention != fmodule.getConvention())
331+
fmodule.setConventionAttr(convention);
332+
} else {
333+
// We prioritize scalarized over internal.
334+
auto getStrongerConvention = [&](Convention c1,
335+
Convention c2) -> Convention {
336+
return c1 == Convention::Scalarized ? c1 : c2;
337+
};
338+
// Update both port and body conventions.
339+
auto newConvention = ConventionAttr::get(
340+
convention.getContext(),
341+
getStrongerConvention(convention.getPortConvention(),
342+
fmodule.getConvention().getPortConvention()),
343+
getStrongerConvention(convention.getBodyConvention(),
344+
fmodule.getConvention().getBodyConvention()));
345+
if (newConvention != fmodule.getConvention())
346+
fmodule.setConventionAttr(newConvention);
347+
}
348+
};
308349
if (auto moduleOp = dyn_cast<FModuleOp>(op)) {
309-
moduleOp.setConvention(convention);
350+
if (includeHierarchy && includeHierarchy.getValue()) {
351+
for (auto *node :
352+
llvm::post_order(state.instancePathCache.instanceGraph[moduleOp])) {
353+
if (node)
354+
if (auto fmodule =
355+
dyn_cast_or_null<firrtl::FModuleOp>(*node->getModule()))
356+
setConvention(fmodule);
357+
}
358+
} else {
359+
setConvention(moduleOp);
360+
}
310361
return success();
311362
}
312363

313364
if (auto extModuleOp = dyn_cast<FExtModuleOp>(op)) {
314-
extModuleOp.setConvention(convention);
365+
extModuleOp.setConventionAttr(convention);
315366
return success();
316367
}
317368

lib/Dialect/FIRRTL/Transforms/LowerLayers.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,9 @@ FModuleOp LowerLayersPass::buildNewModule(OpBuilder &builder,
209209
llvm::sys::SmartScopedLock<true> instrumentationLock(*circuitMutex);
210210
FModuleOp newModule = builder.create<FModuleOp>(
211211
location, builder.getStringAttr(namehint),
212-
ConventionAttr::get(builder.getContext(), Convention::Internal), ports,
213-
ArrayAttr{});
212+
ConventionAttr::get(builder.getContext(), Convention::Internal,
213+
Convention::Internal),
214+
ports, ArrayAttr{});
214215
if (auto dir = getOutputFile(layerBlock.getLayerNameAttr())) {
215216
assert(dir.isDirectory());
216217
newModule->setAttr("output_file", dir);

lib/Dialect/FIRRTL/Transforms/LowerMemory.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,8 @@ void LowerMemoryPass::lowerMemory(MemOp mem, const FirMemory &summary,
244244
OpBuilder b(mem->getParentOfType<FModuleOp>());
245245
auto wrapper = b.create<FModuleOp>(
246246
mem->getLoc(), wrapperName,
247-
ConventionAttr::get(context, Convention::Internal), ports);
247+
ConventionAttr::get(context, Convention::Internal, Convention::Internal),
248+
ports);
248249
SymbolTable::setSymbolVisibility(wrapper, SymbolTable::Visibility::Private);
249250

250251
// Create an instance of the external memory module. The instance has the

0 commit comments

Comments
 (0)