@@ -332,12 +332,17 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
332332
333333 TypeLoweringVisitor (
334334 MLIRContext *context, PreserveAggregate::PreserveMode preserveAggregate,
335+ Convention bodyConvention,
335336 PreserveAggregate::PreserveMode memoryPreservationMode,
336337 SymbolTable &symTbl, const AttrCache &cache,
337338 const llvm::DenseMap<FModuleLike, Convention> &conventionTable)
338- : context(context), aggregatePreservationMode (preserveAggregate),
339+ : context(context), defaultAggregatePreservationMode (preserveAggregate),
339340 memoryPreservationMode (memoryPreservationMode), symTbl(symTbl),
340- cache(cache), conventionTable(conventionTable) {}
341+ cache(cache), conventionTable(conventionTable) {
342+ bodyAggregatePreservationMode = bodyConvention == Convention::Scalarized
343+ ? PreserveAggregate::None
344+ : defaultAggregatePreservationMode;
345+ }
341346 using FIRRTLVisitor<TypeLoweringVisitor, bool >::visitDecl;
342347 using FIRRTLVisitor<TypeLoweringVisitor, bool >::visitExpr;
343348 using FIRRTLVisitor<TypeLoweringVisitor, bool >::visitStmt;
@@ -422,7 +427,7 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
422427 Location errorLoc);
423428
424429 PreserveAggregate::PreserveMode
425- getPreservationModeForModule (FModuleLike moduleLike);
430+ getPreservationModeForPorts (FModuleLike moduleLike);
426431 Value getSubWhatever (Value val, size_t index);
427432
428433 size_t uniqueIdx = 0 ;
@@ -434,7 +439,8 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
434439 MLIRContext *context;
435440
436441 // / Aggregate preservation mode.
437- PreserveAggregate::PreserveMode aggregatePreservationMode;
442+ PreserveAggregate::PreserveMode defaultAggregatePreservationMode;
443+ PreserveAggregate::PreserveMode bodyAggregatePreservationMode;
438444 PreserveAggregate::PreserveMode memoryPreservationMode;
439445
440446 // / The builder is set and maintained in the main loop.
@@ -453,21 +459,21 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
453459};
454460} // namespace
455461
456- // / Return aggregate preservation mode for the module. If the module has a
462+ // / Return aggregate preservation mode for the module ports . If the module has a
457463// / scalarized linkage, then we may not preserve it's aggregate ports.
458464PreserveAggregate::PreserveMode
459- TypeLoweringVisitor::getPreservationModeForModule (FModuleLike module ) {
465+ TypeLoweringVisitor::getPreservationModeForPorts (FModuleLike module ) {
460466 auto lookup = conventionTable.find (module );
461467 if (lookup == conventionTable.end ())
462- return aggregatePreservationMode ;
468+ return defaultAggregatePreservationMode ;
463469 switch (lookup->second ) {
464470 case Convention::Scalarized:
465471 return PreserveAggregate::None;
466472 case Convention::Internal:
467- return aggregatePreservationMode ;
473+ return defaultAggregatePreservationMode ;
468474 }
469475 llvm_unreachable (" Unknown convention" );
470- return aggregatePreservationMode ;
476+ return defaultAggregatePreservationMode ;
471477}
472478
473479Value TypeLoweringVisitor::getSubWhatever (Value val, size_t index) {
@@ -636,7 +642,7 @@ bool TypeLoweringVisitor::lowerProducer(
636642 return false ;
637643 SmallVector<FlatBundleFieldEntry, 8 > fieldTypes;
638644
639- if (!peelType (srcFType, fieldTypes, aggregatePreservationMode ))
645+ if (!peelType (srcFType, fieldTypes, bodyAggregatePreservationMode ))
640646 return false ;
641647
642648 SmallVector<Value> lowered;
@@ -805,7 +811,7 @@ bool TypeLoweringVisitor::lowerArg(FModuleLike module, size_t argIndex,
805811 // Flatten any bundle types.
806812 SmallVector<FlatBundleFieldEntry> fieldTypes;
807813 auto srcType = type_cast<FIRRTLType>(newArgs[argIndex].pi .type );
808- if (!peelType (srcType, fieldTypes, getPreservationModeForModule (module )))
814+ if (!peelType (srcType, fieldTypes, getPreservationModeForPorts (module )))
809815 return false ;
810816
811817 // Ports with internalPath set cannot be lowered.
@@ -925,7 +931,7 @@ bool TypeLoweringVisitor::visitStmt(RefDefineOp op) {
925931 // Attempt to get the bundle types.
926932 SmallVector<FlatBundleFieldEntry> fields;
927933
928- if (!peelType (op.getDest ().getType (), fields, aggregatePreservationMode ))
934+ if (!peelType (op.getDest ().getType (), fields, bodyAggregatePreservationMode ))
929935 return false ;
930936
931937 // Loop over the leaf aggregates.
@@ -1458,7 +1464,7 @@ bool TypeLoweringVisitor::visitDecl(InstanceOp op) {
14581464 SmallVector<Direction> newDirs;
14591465 SmallVector<Attribute> newNames;
14601466 SmallVector<Attribute> newPortAnno;
1461- PreserveAggregate::PreserveMode mode = getPreservationModeForModule (
1467+ PreserveAggregate::PreserveMode mode = getPreservationModeForPorts (
14621468 cast<FModuleLike>(op.getReferencedOperation (symTbl)));
14631469
14641470 endFields.push_back (0 );
@@ -1662,9 +1668,15 @@ void LowerTypesPass::runOnOperation() {
16621668
16631669 // This lambda, executes in parallel for each Op within the circt.
16641670 auto lowerModules = [&](FModuleLike op) -> LogicalResult {
1671+ // Use body type lowering attribute if it exists, otherwise use internal.
1672+ Convention convention = Convention::Internal;
1673+ if (auto conventionAttr = dyn_cast_or_null<ConventionAttr>(
1674+ op->getDiscardableAttr (" body_type_lowering" )))
1675+ convention = conventionAttr.getValue ();
1676+
16651677 auto tl =
1666- TypeLoweringVisitor (&getContext (), preserveAggregate, preserveMemories ,
1667- symTbl, cache, conventionTable);
1678+ TypeLoweringVisitor (&getContext (), preserveAggregate, convention ,
1679+ preserveMemories, symTbl, cache, conventionTable);
16681680 tl.lowerModule (op);
16691681
16701682 return LogicalResult::failure (tl.isFailed ());
0 commit comments