@@ -20,18 +20,22 @@ namespace nvfuser {
2020namespace {
2121
2222struct LoopInfo {
23- hir::ForLoop* loop;
23+ hir::ForLoop* loop = nullptr ;
2424
2525 // The Scope that owns `loop`. It's one level outer than `loop`'s body scope.
26- Scope* parent_scope;
26+ Scope* parent_scope = nullptr ;
2727
2828 // The iterator that points to `loop`. This way, we can insert instructions,
2929 // e.g. Allocate, right before the loop.
3030 Scope::Iterator parent_insertion_point;
3131};
3232
3333std::ostream& operator <<(std::ostream& os, const LoopInfo& loop_info) {
34- os << loop_info.loop ->toInlineString ();
34+ if (loop_info.loop == nullptr ) {
35+ os << " <null>" ;
36+ } else {
37+ os << loop_info.loop ->toInlineString ();
38+ }
3539 return os;
3640}
3741
@@ -131,7 +135,7 @@ Expr* cloneWithNewOperands(
131135 int64_t out_replaced = std::ranges::count_if (new_outs, maybe_replace);
132136
133137 if (in_replaced == 0 && out_replaced == 0 ) {
134- return 0 ;
138+ return e ;
135139 }
136140
137141 if (out_replaced > 0 ) {
@@ -151,6 +155,14 @@ void lowerSegment(
151155 hir::HostIrContainer& hic,
152156 LoopNest& loop_nest,
153157 IrCloner& ir_cloner) {
158+ Scope& innermost_scope = loop_nest.innermostScope ();
159+ // FIXME: cleanup. innermost can return an empty LoopInfo when the nest is
160+ // empty.
161+ LoopInfo innermost;
162+ if (!loop_nest.empty ()) {
163+ innermost = loop_nest.innermost ();
164+ }
165+
154166 switch (group.schedulerType ()) {
155167 case SchedulerType::Communication: {
156168 auto device_id = Communicator::getInstance ().deviceId ();
@@ -162,24 +174,50 @@ void lowerSegment(
162174 // without cloning the value again.
163175 Expr* e = ir_cloner.clone (group.exprs ().front ());
164176
165- for (auto * c : convertSingleOpToCommunication (e, device_id)) {
177+ // FIXME: should this be associated with the scope?
178+ std::unordered_map<Val*, Val*> replacement_map;
179+ for (Expr* c : convertSingleOpToCommunication (e, device_id)) {
166180 NVF_ERROR (
167181 c->isA <Communication>(),
168182 " Exprs in a Communication group should be Communication: " ,
169183 c);
170- // Allocate the recv buffers of communications
171184 auto * communication = c->as <Communication>();
172- TensorView* tv = communication->out ();
173- if (tv->getDeviceMesh ().has (device_id)) {
174- auto * allocate =
175- IrBuilder::create<kir::Allocate>(tv, MemoryType::Global);
176- // TODO: allocation may have to go to the top level. See how
177- // SchedulerType::ExprEval handles allocations.
178- loop_nest.innermostScope ().push_back (allocate);
185+ TensorView* in = communication->in ();
186+ TensorView* out = communication->out ();
187+ if (getShardedIterDomain (in, ParallelType::Stream, DomainType::kLoop ) !=
188+ nullptr &&
189+ getShardedIterDomain (
190+ in, ParallelType::Stream, DomainType::kAllocation ) == nullptr ) {
191+ auto [i, inserted] = replacement_map.try_emplace (
192+ in, hir::shardByStream (in, innermost.loop ->index ()));
193+ if (inserted) {
194+ innermost_scope.push_back (i->second ->definition ());
195+ }
179196 }
180- loop_nest.innermostScope ().push_back (communication);
181- auto wait = IrBuilder::create<hir::Wait>(communication);
182- loop_nest.innermostScope ().push_back (wait);
197+
198+ // Allocate the recv buffers of communications
199+ auto * allocate =
200+ IrBuilder::create<kir::Allocate>(out, MemoryType::Global);
201+ if (getShardedIterDomain (
202+ out, ParallelType::Stream, DomainType::kLoop ) != nullptr &&
203+ getShardedIterDomain (
204+ out, ParallelType::Stream, DomainType::kAllocation ) ==
205+ nullptr ) {
206+ innermost.parent_scope ->insert (
207+ innermost.parent_insertion_point , allocate);
208+ auto [i, inserted] = replacement_map.try_emplace (
209+ out, hir::shardByStream (out, innermost.loop ->index ()));
210+ NVF_ERROR (inserted);
211+ innermost_scope.push_back (i->second ->definition ());
212+ } else {
213+ innermost_scope.push_back (allocate);
214+ }
215+
216+ Expr* new_c = cloneWithNewOperands (c, replacement_map);
217+ innermost_scope.push_back (new_c);
218+
219+ auto * wait = IrBuilder::create<hir::Wait>(new_c);
220+ innermost_scope.push_back (wait);
183221 }
184222 break ;
185223 }
@@ -211,14 +249,11 @@ void lowerSegment(
211249 // TensorViews.
212250 if (loop_nest.empty ()) {
213251 for (Expr* e : exprs) {
214- loop_nest. innermostScope () .push_back (e);
252+ innermost_scope .push_back (e);
215253 }
216254 break ;
217255 }
218256
219- auto [for_loop, parent_scope, parent_insertion_point] =
220- loop_nest.innermost ();
221-
222257 std::unordered_map<Val*, Val*> replacement_map;
223258 for (Expr* e : exprs) {
224259 for (auto * in : ir_utils::filterByType<TensorView>(e->inputs ())) {
@@ -228,9 +263,9 @@ void lowerSegment(
228263 in, ParallelType::Stream, DomainType::kAllocation ) ==
229264 nullptr ) {
230265 auto [i, inserted] = replacement_map.try_emplace (
231- in, hir::shardByStream (in, for_loop ->index ()));
266+ in, hir::shardByStream (in, innermost. loop ->index ()));
232267 if (inserted) {
233- for_loop-> body () .push_back (i->second ->definition ());
268+ innermost_scope .push_back (i->second ->definition ());
234269 }
235270 }
236271 }
@@ -241,21 +276,22 @@ void lowerSegment(
241276 nullptr ) {
242277 auto * allocate =
243278 IrBuilder::create<kir::Allocate>(out, MemoryType::Global);
244- parent_scope->insert (parent_insertion_point, allocate);
279+ innermost.parent_scope ->insert (
280+ innermost.parent_insertion_point , allocate);
245281 // Loop is stream parallelized but allocation is not. Therefore,
246282 // `out` should be allocated outside the loop.
247283 //
248284 // I use try_emplace here so shardByStream is called only when `out`
249285 // is missing.
250286 auto [i, inserted] = replacement_map.try_emplace (
251- out, hir::shardByStream (out, for_loop ->index ()));
287+ out, hir::shardByStream (out, innermost. loop ->index ()));
252288 NVF_ERROR (inserted);
253- for_loop-> body () .push_back (i->second ->definition ());
289+ innermost_scope .push_back (i->second ->definition ());
254290 }
255291 }
256292
257293 Expr* new_e = cloneWithNewOperands (e, replacement_map);
258- for_loop-> body () .push_back (new_e);
294+ innermost_scope .push_back (new_e);
259295 }
260296 break ;
261297 }
@@ -280,7 +316,7 @@ void lowerSegment(
280316 auto * tv = out->as <TensorView>();
281317 auto * allocate =
282318 IrBuilder::create<kir::Allocate>(tv, MemoryType::Global);
283- loop_nest. innermostScope () .push_back (allocate);
319+ innermost_scope .push_back (allocate);
284320 }
285321
286322 // Add the LaunchKernel instruction.
@@ -296,7 +332,7 @@ void lowerSegment(
296332 ins,
297333 outs,
298334 cache_id);
299- loop_nest. innermostScope () .push_back (launch_kernel);
335+ innermost_scope .push_back (launch_kernel);
300336 }
301337 } // switch
302338} // lowerSegment
0 commit comments