@@ -268,17 +268,14 @@ static std::optional<Convention> parseConvention(llvm::StringRef str) {
268268 .Default (std::nullopt );
269269}
270270
271- template <bool IsConventionAnno>
272- static LogicalResult
273- applyConventionOrTypeLoweringAnno (const AnnoPathValue &target,
274- DictionaryAttr anno, ApplyState &state) {
271+ static LogicalResult applyConventionAnno (const AnnoPathValue &target,
272+ DictionaryAttr anno,
273+ ApplyState &state) {
275274 auto *op = target.ref .getOp ();
276275 auto loc = op->getLoc ();
277276 auto error = [&]() {
278277 auto diag = mlir::emitError (loc);
279- diag << (IsConventionAnno ? " circuit.ConventionAnnotation "
280- : " circuit.TypeLoweringAnnotation " )
281- << " " ;
278+ diag << " circuit.ConventionAnnotation " ;
282279 return diag;
283280 };
284281
@@ -301,47 +298,85 @@ applyConventionOrTypeLoweringAnno(const AnnoPathValue &target,
301298
302299 auto convention = *conventionOpt;
303300
304- if (convention == Convention::Internal)
305- // Convention is internal by default so there is nothing to change
306- return success ();
307-
308- auto includeHierarchy = anno.getAs <BoolAttr>(" includeHierarchy" );
309- auto conventionAttr = ConventionAttr::get (op->getContext (), convention);
310- auto setConvention = [&](Operation *moduleOp) {
311- TypeSwitch<Operation *>(moduleOp)
312- .Case <FModuleOp, FExtModuleOp>([&](auto moduleOp) {
313- if (IsConventionAnno)
314- moduleOp.setConventionAttr (conventionAttr);
315- else
316- moduleOp->setDiscardableAttr (" body_type_lowering" , conventionAttr);
317- })
318- .Default ([](auto ) {});
319- };
320-
321301 if (auto moduleOp = dyn_cast<FModuleOp>(op)) {
322- if (includeHierarchy && includeHierarchy.getValue ()) {
323- // If includeHierarchy is true, update the convention for all modules in
324- // the hierarchy.
325- for (auto *node :
326- llvm::post_order (state.instancePathCache .instanceGraph [moduleOp])) {
327- if (node && isa<FModuleOp, FExtModuleOp>(*node->getModule ()))
328- setConvention (node->getModule ());
329- }
330- } else {
331- // Update the convention.
332- setConvention (moduleOp);
333- }
302+ moduleOp.setConvention (convention);
334303 return success ();
335304 }
336305
337306 if (auto extModuleOp = dyn_cast<FExtModuleOp>(op)) {
338- setConvention (extModuleOp );
307+ extModuleOp. setConvention (convention );
339308 return success ();
340309 }
341310
342311 return error () << " can only target to a module or extmodule" ;
343312}
344313
314+ static LogicalResult applyBodyTypeLoweringAnno (const AnnoPathValue &target,
315+ DictionaryAttr anno,
316+ ApplyState &state) {
317+ auto *op = target.ref .getOp ();
318+ auto loc = op->getLoc ();
319+ auto error = [&]() {
320+ auto diag = mlir::emitError (loc);
321+ diag << typeLoweringAnnoClass;
322+ return diag;
323+ };
324+
325+ auto opTarget = dyn_cast<OpAnnoTarget>(target.ref );
326+ if (!opTarget)
327+ return error () << " must target a module object" ;
328+
329+ if (!target.isLocal ())
330+ return error () << " must be local" ;
331+
332+ auto moduleOp = dyn_cast<FModuleOp>(op);
333+
334+ if (!moduleOp)
335+ return error () << " can only target to a module" ;
336+
337+ auto conventionStrAttr =
338+ tryGetAs<StringAttr>(anno, anno, " convention" , loc, conventionAnnoClass);
339+
340+ if (!conventionStrAttr)
341+ return failure ();
342+
343+ auto conventionStr = conventionStrAttr.getValue ();
344+ auto conventionOpt = parseConvention (conventionStr);
345+ if (!conventionOpt)
346+ return error () << " unknown convention " << conventionStr;
347+
348+ auto convention = *conventionOpt;
349+
350+ if (convention == Convention::Internal)
351+ // Convention is internal by default so there is nothing to change
352+ return success ();
353+
354+ auto conventionAttr = ConventionAttr::get (op->getContext (), convention);
355+
356+ // `includeHierarchy` only valid in BodyTypeLowering.
357+ bool includeHierarchy = false ;
358+ if (auto includeHierarchyAttr = tryGetAs<BoolAttr>(
359+ anno, anno, " includeHierarchy" , loc, conventionAnnoClass))
360+ includeHierarchy = includeHierarchyAttr.getValue ();
361+
362+ if (includeHierarchy) {
363+ // If includeHierarchy is true, update the convention for all modules in
364+ // the hierarchy.
365+ for (auto *node :
366+ llvm::post_order (state.instancePathCache .instanceGraph [moduleOp])) {
367+ if (!node)
368+ continue ;
369+ if (auto fmodule = dyn_cast<FModuleOp>(*node->getModule ()))
370+ fmodule->setAttr (" body_type_lowering" , conventionAttr);
371+ }
372+ } else {
373+ // Update the convention.
374+ moduleOp->setAttr (" body_type_lowering" , conventionAttr);
375+ }
376+
377+ return success ();
378+ }
379+
345380static LogicalResult applyModulePrefixAnno (const AnnoPathValue &target,
346381 DictionaryAttr anno,
347382 ApplyState &state) {
@@ -583,10 +618,8 @@ static llvm::StringMap<AnnoRecord> annotationRecords{{
583618 {memTapPortClass, {stdResolve, applyWithoutTarget<true >}},
584619 {memTapBlackboxClass, {stdResolve, applyWithoutTarget<true >}},
585620 // Miscellaneous Annotations
586- {conventionAnnoClass,
587- {stdResolve, applyConventionOrTypeLoweringAnno<true >}},
588- {typeLoweringAnnoClass,
589- {stdResolve, applyConventionOrTypeLoweringAnno<false >}},
621+ {conventionAnnoClass, {stdResolve, applyConventionAnno}},
622+ {typeLoweringAnnoClass, {stdResolve, applyBodyTypeLoweringAnno}},
590623 {dontTouchAnnoClass,
591624 {stdResolve, applyWithoutTarget<true , true , WireOp, NodeOp, RegOp,
592625 RegResetOp, InstanceOp, MemOp, CombMemOp,
0 commit comments