Skip to content

Commit

Permalink
[FIRRTL] Extend Convention to specify type lowering for body
Browse files Browse the repository at this point in the history
Extend the FIRRTL Convention attribute to separately specify type lowering
behavior for module ports and module body. This allows more fine-grained
control over how types are lowered in different contexts.

The Convention attribute now takes two parameters:
- Port convention: Controls how module ports are lowered
- Body convention: Controls how types within the module body are lowered

Updates the syntax from:
  #firrtl<convention internal>
to:
  #firrtl.convention<internal, internal>

This change enables modules to have different type lowering strategies
for their interfaces and their bodies. For example, a module
could preserve aggregate types in its ports while scalarizing them in
its body.

Updates all relevant tests and code to use the new two-parameter
Convention format.
  • Loading branch information
uenoku committed Oct 30, 2024
1 parent 0713332 commit 2160b4d
Show file tree
Hide file tree
Showing 41 changed files with 317 additions and 172 deletions.
8 changes: 7 additions & 1 deletion include/circt/Dialect/FIRRTL/FIRRTLEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,13 @@ def Convention : I32EnumAttr<"Convention", "lowering convention", [
let genSpecializedAttr = 0;
}

def ConventionAttr : EnumAttr<FIRRTLDialect, Convention, "convention">;
// Pair of port and body conventions
def ConventionAttr : AttrDef<FIRRTLDialect, "Convention"> {
let parameters = (ins "Convention":$portConvention, "Convention":$bodyConvention);
let summary = "Pair of port and body conventions";
let mnemonic = "convention";
let assemblyFormat = "`<` $portConvention `,` $bodyConvention `>`";
}

//===----------------------------------------------------------------------===//
// Layer Lowering Conventions
Expand Down
13 changes: 10 additions & 3 deletions include/circt/Dialect/FIRRTL/FIRRTLOpInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,17 @@ def FModuleLike : OpInterface<"FModuleLike", [Symbol, PortList, InstanceGraphMod
//===------------------------------------------------------------------===//

InterfaceMethod<"Get the module's instantiation convention",
"ConventionAttr", "getConventionAttr">,
"ConventionAttr", "getConventionAttr", (ins), [{}],
/*defaultImplementation=*/[{
return ConventionAttr::get($_op.getContext(), $_op.getPortConvention(),
$_op.getBodyConvention());
}]>,

InterfaceMethod<"Get the module's instantiation convention",
"Convention", "getConvention">,
InterfaceMethod<"Get the module's port convention",
"Convention", "getPortConvention">,

InterfaceMethod<"Get the module's body convention",
"Convention", "getBodyConvention">,

//===------------------------------------------------------------------===//
// Enabled (AKA Required) Layers
Expand Down
2 changes: 1 addition & 1 deletion include/circt/Dialect/FIRRTL/FIRRTLStructure.td
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class FIRRTLModuleLike<string mnemonic, list<Trait> traits = []> :
}


def FModuleOp : FIRRTLModuleLike<"module", [SingleBlock, NoTerminator]> {
def FModuleOp : FIRRTLModuleLike<"module", [SingleBlock, NoTerminator, DeclareOpInterfaceMethods<FModuleLike>]> {
let summary = "FIRRTL Module";
let description = [{
The "firrtl.module" operation represents a Verilog module, including a given
Expand Down
26 changes: 13 additions & 13 deletions lib/CAPI/Dialect/FIRRTL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,19 +174,19 @@ MlirType firrtlTypeGetMaskType(MlirType type) {
//===----------------------------------------------------------------------===//

MlirAttribute firrtlAttrGetConvention(MlirContext ctx,
FIRRTLConvention convention) {
Convention value;

switch (convention) {
case FIRRTL_CONVENTION_INTERNAL:
value = Convention::Internal;
break;
case FIRRTL_CONVENTION_SCALARIZED:
value = Convention::Scalarized;
break;
}

return wrap(ConventionAttr::get(unwrap(ctx), value));
FIRRTLConvention portConvention,
FIRRTLConvention bodyConvention) {
auto getConvention = [&](FIRRTLConvention convention) {
switch (convention) {
case FIRRTL_CONVENTION_INTERNAL:
return Convention::Internal;
case FIRRTL_CONVENTION_SCALARIZED:
return Convention::Scalarized;
}
};

return wrap(ConventionAttr::get(unwrap(ctx), getConvention(portConvention),
getConvention(bodyConvention)));
}

MlirAttribute firrtlAttrGetPortDirs(MlirContext ctx, size_t count,
Expand Down
55 changes: 33 additions & 22 deletions lib/Dialect/FIRRTL/FIRRTLOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1351,7 +1351,8 @@ static void printFModuleLikeOp(OpAsmPrinter &p, FModuleLike op) {
"sym_name", "portDirections", "portTypes", "portAnnotations",
"portSymbols", "portLocations", "parameters", visibilityAttrName};

if (op.getConvention() == Convention::Internal)
if (op.getPortConvention() == Convention::Internal &&
op.getBodyConvention() == Convention::Internal)
omittedAttrs.push_back("convention");

// We can omit the portNames if they were able to be printed as properly as
Expand Down Expand Up @@ -1535,9 +1536,10 @@ ParseResult FModuleOp::parse(OpAsmParser &parser, OperationState &result) {
if (parseFModuleLikeOp(parser, result, /*hasSSAIdentifiers=*/true))
return failure();
if (!result.attributes.get("convention"))
result.addAttribute(
"convention",
ConventionAttr::get(result.getContext(), Convention::Internal));
result.addAttribute("convention",
ConventionAttr::get(result.getContext(),
Convention::Internal,
Convention::Internal));
if (!result.attributes.get("layers"))
result.addAttribute("layers", ArrayAttr::get(parser.getContext(), {}));
return success();
Expand All @@ -1547,9 +1549,10 @@ ParseResult FExtModuleOp::parse(OpAsmParser &parser, OperationState &result) {
if (parseFModuleLikeOp(parser, result, /*hasSSAIdentifiers=*/false))
return failure();
if (!result.attributes.get("convention"))
result.addAttribute(
"convention",
ConventionAttr::get(result.getContext(), Convention::Internal));
result.addAttribute("convention",
ConventionAttr::get(result.getContext(),
Convention::Internal,
Convention::Internal));
return success();
}

Expand Down Expand Up @@ -1767,17 +1770,13 @@ ArrayAttr FMemModuleOp::getParameters() { return {}; }

ArrayAttr FModuleOp::getParameters() { return {}; }

Convention FIntModuleOp::getConvention() { return Convention::Internal; }
Convention FIntModuleOp::getPortConvention() { return Convention::Internal; }

ConventionAttr FIntModuleOp::getConventionAttr() {
return ConventionAttr::get(getContext(), getConvention());
}
Convention FIntModuleOp::getBodyConvention() { return Convention::Internal; }

Convention FMemModuleOp::getConvention() { return Convention::Internal; }
Convention FMemModuleOp::getPortConvention() { return Convention::Internal; }

ConventionAttr FMemModuleOp::getConventionAttr() {
return ConventionAttr::get(getContext(), getConvention());
}
Convention FMemModuleOp::getBodyConvention() { return Convention::Internal; }

//===----------------------------------------------------------------------===//
// ClassLike Helpers
Expand Down Expand Up @@ -2056,12 +2055,26 @@ void ClassOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
}

Convention ClassOp::getConvention() { return Convention::Internal; }
Convention FModuleOp::getPortConvention() {
return getConventionAttr().getPortConvention();
}

Convention FModuleOp::getBodyConvention() {
return getConventionAttr().getBodyConvention();
}

Convention FExtModuleOp::getPortConvention() {
return getConventionAttr().getPortConvention();
}

ConventionAttr ClassOp::getConventionAttr() {
return ConventionAttr::get(getContext(), getConvention());
Convention FExtModuleOp::getBodyConvention() {
return getConventionAttr().getBodyConvention();
}

Convention ClassOp::getPortConvention() { return Convention::Internal; }

Convention ClassOp::getBodyConvention() { return Convention::Internal; }

ArrayAttr ClassOp::getParameters() { return {}; }

ArrayAttr ClassOp::getPortAnnotationsAttr() {
Expand Down Expand Up @@ -2142,11 +2155,9 @@ void ExtClassOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
}

Convention ExtClassOp::getConvention() { return Convention::Internal; }
Convention ExtClassOp::getPortConvention() { return Convention::Internal; }

ConventionAttr ExtClassOp::getConventionAttr() {
return ConventionAttr::get(getContext(), getConvention());
}
Convention ExtClassOp::getBodyConvention() { return Convention::Internal; }

ArrayAttr ExtClassOp::getLayersAttr() {
return ArrayAttr::get(getContext(), {});
Expand Down
7 changes: 5 additions & 2 deletions lib/Dialect/FIRRTL/Import/FIRParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5237,7 +5237,8 @@ ParseResult FIRCircuitParser::parseExtModule(CircuitOp circuit,
getConstants().options.scalarizeExtModules
? Convention::Scalarized
: Convention::Internal;
auto conventionAttr = ConventionAttr::get(getContext(), convention);
auto conventionAttr =
ConventionAttr::get(getContext(), convention, Convention::Internal);
auto annotations = ArrayAttr::get(getContext(), {});
auto extModuleOp = builder.create<FExtModuleOp>(
info.getLoc(), name, conventionAttr, portList, defName, annotations,
Expand Down Expand Up @@ -5325,7 +5326,9 @@ ParseResult FIRCircuitParser::parseModule(CircuitOp circuit, bool isPublic,
convention = Convention::Scalarized;
if (!isPublic && getConstants().options.scalarizeInternalModules)
convention = Convention::Scalarized;
auto conventionAttr = ConventionAttr::get(getContext(), convention);
// Use Internal as body convention.
auto conventionAttr =
ConventionAttr::get(getContext(), convention, Convention::Internal);
auto builder = circuit.getBodyBuilder();
auto moduleOp = builder.create<FModuleOp>(info.getLoc(), name, conventionAttr,
portList, annotations, layers);
Expand Down
4 changes: 3 additions & 1 deletion lib/Dialect/FIRRTL/Transforms/ExtractInstances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -983,7 +983,9 @@ void ExtractInstancesPass::groupInstances() {
// Create the wrapper module.
auto wrapper = builder.create<FModuleOp>(
builder.getUnknownLoc(), wrapperModuleName,
ConventionAttr::get(builder.getContext(), Convention::Internal), ports);
ConventionAttr::get(builder.getContext(), Convention::Internal,
Convention::Internal),
ports);
SymbolTable::setSymbolVisibility(wrapper, SymbolTable::Visibility::Private);

// Instantiate the wrapper module in the parent and replace uses of the
Expand Down
71 changes: 61 additions & 10 deletions lib/Dialect/FIRRTL/Transforms/LowerAnnotations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ static LogicalResult applyDUTAnno(const AnnoPathValue &target,
static std::optional<Convention> parseConvention(llvm::StringRef str) {
return ::llvm::StringSwitch<::std::optional<Convention>>(str)
.Case("scalarized", Convention::Scalarized)
.Case("internal", Convention::Internal)
.Default(std::nullopt);
}

Expand All @@ -293,25 +294,75 @@ static LogicalResult applyConventionAnno(const AnnoPathValue &target,
if (!target.isLocal())
return error() << "must be local";

auto conventionStrAttr =
tryGetAs<StringAttr>(anno, anno, "convention", loc, conventionAnnoClass);
if (!conventionStrAttr)
auto getConventionAttr = [&](StringRef name) -> FailureOr<Convention> {
auto conventionName =
tryGetAs<StringAttr>(anno, anno, name, loc, conventionAnnoClass);
if (!conventionName)
return failure();

auto conventionOpt = parseConvention(conventionName.getValue());
if (!conventionOpt)
return error() << "unknown convention " << conventionName.getValue();

return *conventionOpt;
};
auto portConvention = getConventionAttr("portConvention");
auto bodyConvention = getConventionAttr("bodyConvention");
if (failed(portConvention) || failed(bodyConvention))
return failure();

auto conventionStr = conventionStrAttr.getValue();
auto conventionOpt = parseConvention(conventionStr);
if (!conventionOpt)
return error() << "unknown convention " << conventionStr;
if (*portConvention == Convention::Internal &&
*bodyConvention == Convention::Internal) {
// Convention is internal by default so there is nothing to change
return success();
}

auto includeHierarchy = anno.getAs<BoolAttr>("includeHierarchy");
auto convention =
ConventionAttr::get(op->getContext(), *portConvention, *bodyConvention);

auto convention = *conventionOpt;
bool isBothScalarized = *portConvention == Convention::Scalarized &&
*bodyConvention == Convention::Scalarized;

auto setConvention = [&](FModuleOp fmodule) {
if (isBothScalarized) // Fast path.
{
if (convention != fmodule.getConvention())
fmodule.setConventionAttr(convention);
} else {
// We prioritize scalarized over internal.
auto getStrongerConvention = [&](Convention c1,
Convention c2) -> Convention {
return c1 == Convention::Scalarized ? c1 : c2;
};
// Update both port and body conventions.
auto newConvention = ConventionAttr::get(
convention.getContext(),
getStrongerConvention(convention.getPortConvention(),
fmodule.getConvention().getPortConvention()),
getStrongerConvention(convention.getBodyConvention(),
fmodule.getConvention().getBodyConvention()));
if (newConvention != fmodule.getConvention())
fmodule.setConventionAttr(newConvention);
}
};
if (auto moduleOp = dyn_cast<FModuleOp>(op)) {
moduleOp.setConvention(convention);
if (includeHierarchy && includeHierarchy.getValue()) {
for (auto *node :
llvm::post_order(state.instancePathCache.instanceGraph[moduleOp])) {
if (node)
if (auto fmodule =
dyn_cast_or_null<firrtl::FModuleOp>(*node->getModule()))
setConvention(fmodule);
}
} else {
setConvention(moduleOp);
}
return success();
}

if (auto extModuleOp = dyn_cast<FExtModuleOp>(op)) {
extModuleOp.setConvention(convention);
extModuleOp.setConventionAttr(convention);
return success();
}

Expand Down
5 changes: 3 additions & 2 deletions lib/Dialect/FIRRTL/Transforms/LowerLayers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,9 @@ FModuleOp LowerLayersPass::buildNewModule(OpBuilder &builder,
llvm::sys::SmartScopedLock<true> instrumentationLock(*circuitMutex);
FModuleOp newModule = builder.create<FModuleOp>(
location, builder.getStringAttr(namehint),
ConventionAttr::get(builder.getContext(), Convention::Internal), ports,
ArrayAttr{});
ConventionAttr::get(builder.getContext(), Convention::Internal,
Convention::Internal),
ports, ArrayAttr{});
if (auto dir = getOutputFile(layerBlock.getLayerNameAttr())) {
assert(dir.isDirectory());
newModule->setAttr("output_file", dir);
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/FIRRTL/Transforms/LowerMemory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,8 @@ void LowerMemoryPass::lowerMemory(MemOp mem, const FirMemory &summary,
OpBuilder b(mem->getParentOfType<FModuleOp>());
auto wrapper = b.create<FModuleOp>(
mem->getLoc(), wrapperName,
ConventionAttr::get(context, Convention::Internal), ports);
ConventionAttr::get(context, Convention::Internal, Convention::Internal),
ports);
SymbolTable::setSymbolVisibility(wrapper, SymbolTable::Visibility::Private);

// Create an instance of the external memory module. The instance has the
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/FIRRTL/Transforms/LowerSignatures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ void LowerSignaturesPass::runOnOperation() {
auto circuit = getOperation();

for (auto mod : circuit.getOps<FModuleLike>()) {
if (lowerModuleSignature(mod, mod.getConvention(), cache,
if (lowerModuleSignature(mod, mod.getPortConvention(), cache,
portMap[mod.getNameAttr()])
.failed())
return signalPassFailure();
Expand Down
Loading

0 comments on commit 2160b4d

Please sign in to comment.