@@ -339,12 +339,17 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
339339
340340 TypeLoweringVisitor (
341341 MLIRContext *context, PreserveAggregate::PreserveMode preserveAggregate,
342+ Convention bodyConvention,
342343 PreserveAggregate::PreserveMode memoryPreservationMode,
343344 SymbolTable &symTbl, const AttrCache &cache,
344345 const llvm::DenseMap<FModuleLike, Convention> &conventionTable)
345- : context(context), aggregatePreservationMode (preserveAggregate),
346+ : context(context), defaultAggregatePreservationMode (preserveAggregate),
346347 memoryPreservationMode (memoryPreservationMode), symTbl(symTbl),
347- cache(cache), conventionTable(conventionTable) {}
348+ cache(cache), conventionTable(conventionTable) {
349+ bodyAggregatePreservationMode = bodyConvention == Convention::Scalarized
350+ ? PreserveAggregate::None
351+ : defaultAggregatePreservationMode;
352+ }
348353 using FIRRTLVisitor<TypeLoweringVisitor, bool >::visitDecl;
349354 using FIRRTLVisitor<TypeLoweringVisitor, bool >::visitExpr;
350355 using FIRRTLVisitor<TypeLoweringVisitor, bool >::visitStmt;
@@ -429,7 +434,7 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
429434 Location errorLoc);
430435
431436 PreserveAggregate::PreserveMode
432- getPreservationModeForModule (FModuleLike moduleLike);
437+ getPreservationModeForPorts (FModuleLike moduleLike);
433438 Value getSubWhatever (Value val, size_t index);
434439
435440 size_t uniqueIdx = 0 ;
@@ -441,7 +446,8 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
441446 MLIRContext *context;
442447
443448 // / Aggregate preservation mode.
444- PreserveAggregate::PreserveMode aggregatePreservationMode;
449+ PreserveAggregate::PreserveMode defaultAggregatePreservationMode;
450+ PreserveAggregate::PreserveMode bodyAggregatePreservationMode;
445451 PreserveAggregate::PreserveMode memoryPreservationMode;
446452
447453 // / The builder is set and maintained in the main loop.
@@ -460,21 +466,21 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
460466};
461467} // namespace
462468
463- // / Return aggregate preservation mode for the module. If the module has a
469+ // / Return aggregate preservation mode for the module ports . If the module has a
464470// / scalarized linkage, then we may not preserve it's aggregate ports.
465471PreserveAggregate::PreserveMode
466- TypeLoweringVisitor::getPreservationModeForModule (FModuleLike module ) {
472+ TypeLoweringVisitor::getPreservationModeForPorts (FModuleLike module ) {
467473 auto lookup = conventionTable.find (module );
468474 if (lookup == conventionTable.end ())
469- return aggregatePreservationMode ;
475+ return defaultAggregatePreservationMode ;
470476 switch (lookup->second ) {
471477 case Convention::Scalarized:
472478 return PreserveAggregate::None;
473479 case Convention::Internal:
474- return aggregatePreservationMode ;
480+ return defaultAggregatePreservationMode ;
475481 }
476482 llvm_unreachable (" Unknown convention" );
477- return aggregatePreservationMode ;
483+ return defaultAggregatePreservationMode ;
478484}
479485
480486Value TypeLoweringVisitor::getSubWhatever (Value val, size_t index) {
@@ -643,7 +649,7 @@ bool TypeLoweringVisitor::lowerProducer(
643649 return false ;
644650 SmallVector<FlatBundleFieldEntry, 8 > fieldTypes;
645651
646- if (!peelType (srcFType, fieldTypes, aggregatePreservationMode ))
652+ if (!peelType (srcFType, fieldTypes, bodyAggregatePreservationMode ))
647653 return false ;
648654
649655 SmallVector<Value> lowered;
@@ -809,7 +815,7 @@ bool TypeLoweringVisitor::lowerArg(FModuleLike module, size_t argIndex,
809815 // Flatten any bundle types.
810816 SmallVector<FlatBundleFieldEntry> fieldTypes;
811817 auto srcType = type_cast<FIRRTLType>(newArgs[argIndex].pi .type );
812- if (!peelType (srcType, fieldTypes, getPreservationModeForModule (module )))
818+ if (!peelType (srcType, fieldTypes, getPreservationModeForPorts (module )))
813819 return false ;
814820
815821 // Ports with internalPath set cannot be lowered.
@@ -929,7 +935,7 @@ bool TypeLoweringVisitor::visitStmt(RefDefineOp op) {
929935 // Attempt to get the bundle types.
930936 SmallVector<FlatBundleFieldEntry> fields;
931937
932- if (!peelType (op.getDest ().getType (), fields, aggregatePreservationMode ))
938+ if (!peelType (op.getDest ().getType (), fields, bodyAggregatePreservationMode ))
933939 return false ;
934940
935941 // Loop over the leaf aggregates.
@@ -1454,7 +1460,7 @@ bool TypeLoweringVisitor::visitDecl(InstanceOp op) {
14541460 SmallVector<Direction> newDirs;
14551461 SmallVector<Attribute> newNames;
14561462 SmallVector<Attribute> newPortAnno;
1457- PreserveAggregate::PreserveMode mode = getPreservationModeForModule (
1463+ PreserveAggregate::PreserveMode mode = getPreservationModeForPorts (
14581464 cast<FModuleLike>(op.getReferencedOperation (symTbl)));
14591465
14601466 endFields.push_back (0 );
@@ -1667,9 +1673,15 @@ void LowerTypesPass::runOnOperation() {
16671673
16681674 // This lambda, executes in parallel for each Op within the circt.
16691675 auto lowerModules = [&](FModuleLike op) -> LogicalResult {
1676+ // Use body type lowering attribute if it exists, otherwise use internal.
1677+ Convention convention = Convention::Internal;
1678+ if (auto conventionAttr = dyn_cast_or_null<ConventionAttr>(
1679+ op->getDiscardableAttr (" body_type_lowering" )))
1680+ convention = conventionAttr.getValue ();
1681+
16701682 auto tl =
1671- TypeLoweringVisitor (&getContext (), preserveAggregate, preserveMemories ,
1672- symTbl, cache, conventionTable);
1683+ TypeLoweringVisitor (&getContext (), preserveAggregate, convention ,
1684+ preserveMemories, symTbl, cache, conventionTable);
16731685 tl.lowerModule (op);
16741686
16751687 return LogicalResult::failure (tl.isFailed ());
0 commit comments