From 9f916efeded3f9a9de7a79eb1ff12bfa335fc692 Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Wed, 12 Nov 2025 14:31:10 -0800
Subject: [PATCH 1/6] WIP

---
 csrc/host_ir/container.cpp |   4 -
 csrc/host_ir/container.h   |   7 +-
 csrc/host_ir/lowering.cpp  | 165 +++++++++++++++++++++++++++++++++++--
 3 files changed, 162 insertions(+), 14 deletions(-)

diff --git a/csrc/host_ir/container.cpp b/csrc/host_ir/container.cpp
index d1a7291e286..5b4827d848f 100644
--- a/csrc/host_ir/container.cpp
+++ b/csrc/host_ir/container.cpp
@@ -36,10 +36,6 @@ std::ostream& HostIrContainer::print(std::ostream& os) const {
   return os;
 }
 
-const Scope& HostIrContainer::topLevel() const {
-  return top_level_;
-}
-
 void HostIrContainer::resetTopLevelExprs(std::list<Expr*> exprs) {
   top_level_.mutableExprs() = std::move(exprs);
 }
diff --git a/csrc/host_ir/container.h b/csrc/host_ir/container.h
index 87e859d88c6..f797b3ab827 100644
--- a/csrc/host_ir/container.h
+++ b/csrc/host_ir/container.h
@@ -29,7 +29,12 @@ class HostIrContainer final : public Fusion {
   // Print to an output stream
   std::ostream& print(std::ostream& os) const;
 
-  const Scope& topLevel() const;
+  const Scope& topLevel() const {
+    return top_level_;
+  }
+  Scope& topLevel() {
+    return top_level_;
+  }
   const Scope::ExprList& topLevelExprs() const {
     return topLevel().exprs();
   }
diff --git a/csrc/host_ir/lowering.cpp b/csrc/host_ir/lowering.cpp
index f71280d0db3..85b70218a8c 100644
--- a/csrc/host_ir/lowering.cpp
+++ b/csrc/host_ir/lowering.cpp
@@ -17,6 +17,74 @@ namespace nvfuser {
 
+namespace hir {
+// FIXME: rename to LoopInfo?
+// FIXME: can this be nested in LoopNest?
+struct Frame {
+  ForLoop* loop;
+  Scope* parent_scope;
+  Scope::Iterator parent_insertion_point;
+
+  friend std::ostream& operator<<(std::ostream& os, const Frame& frame);
+};
+
+class LoopNest {
+ private:
+ public:
+  LoopNest(Scope& top_level) : top_level_(top_level) {}
+
+  int64_t size() const {
+    return std::ssize(frames_);
+  }
+
+  bool empty() const {
+    return frames_.empty();
+  }
+
+  void closeLoop() {
+    NVF_ERROR(!empty());
+    frames_.pop_back();
+  }
+
+  Frame& innermost() {
+    NVF_ERROR(!empty());
+    return frames_.back();
+  }
+
+  Scope& innermostScope() {
+    return empty() ? top_level_ : innermost().loop->body();
+  }
+
+  ForLoop* openLoop(IterDomain* id) {
+    Scope& parent_scope = innermostScope();
+    auto* for_loop = ForLoop::createFromIterDomain(id);
+    frames_.push_back(
+        {for_loop, &parent_scope, parent_scope.push_back(for_loop)});
+    return for_loop;
+  }
+
+  friend std::ostream& operator<<(std::ostream& os, const LoopNest& loop_nest);
+
+ private:
+  std::vector<Frame> frames_;
+  Scope& top_level_;
+};
+
+std::ostream& operator<<(std::ostream& os, const Frame& frame) {
+  os << frame.loop->toInlineString();
+  return os;
+}
+
+std::ostream& operator<<(std::ostream& os, const LoopNest& loop_nest) {
+  os << "LoopNest:" << std::endl;
+  for (const auto& frame : loop_nest.frames_) {
+    indent(os, 1) << frame << frame.loop->toString() << std::endl;
+  }
+  return os;
+}
+
+} // namespace hir
+
 namespace {
 // Finds the stream-parallelized IterDomain in the loop domain of a TensorView,
 // or nullptr if not found. This is different from `getShardedIterDomain(tv,
@@ -47,6 +115,7 @@ void lowerSegment(
     const AliasInfoMap& aliases,
     const LaunchParams& launch_params,
     hir::HostIrContainer& hic,
+    hir::LoopNest& loop_nest,
     IrCloner& ir_cloner) {
   switch (group.schedulerType()) {
     case SchedulerType::Communication: {
@@ -70,11 +139,12 @@ void lowerSegment(
         if (tv->getDeviceMesh().has(device_id)) {
           auto* allocate =
               IrBuilder::create<kir::Allocate>(tv, MemoryType::Global);
-          hic.pushBackTopLevelExprs(allocate);
+          // FIXME: allocation may have to go to the top level.
+          loop_nest.innermostScope().push_back(allocate);
         }
-        hic.pushBackTopLevelExprs(communication);
+        loop_nest.innermostScope().push_back(communication);
         auto wait = IrBuilder::create<hir::Wait>(communication);
-        hic.pushBackTopLevelExprs(wait);
+        loop_nest.innermostScope().push_back(wait);
       }
       break;
     }
@@ -109,13 +179,13 @@ void lowerSegment(
       IterDomain* stream_id = findStreamIterDomain(outs);
       if (stream_id == nullptr) {
         for (Expr* e : exprs) {
-          hic.pushBackTopLevelExprs(e);
+          loop_nest.innermostScope().push_back(e);
         }
         break;
       }
 
-      auto* for_loop = hir::ForLoop::createFromIterDomain(stream_id);
-      auto top_level_insertion_point = hic.pushBackTopLevelExprs(for_loop);
+      auto [for_loop, parent_scope, parent_insertion_point] =
+          loop_nest.innermost();
 
      std::unordered_map<Val*, Val*> replacement_map;
      for (Expr* e : exprs) {
@@ -134,9 +204,12 @@ void lowerSegment(
        if (getShardedIterDomain(out, ParallelType::Stream) == nullptr) {
          auto* allocate =
              IrBuilder::create<kir::Allocate>(out, MemoryType::Global);
-          hic.insertExprBefore(top_level_insertion_point, allocate);
+          parent_scope->insert(parent_insertion_point, allocate);
          // Loop is stream parallelized but allocation is not. Therefore,
          // `out` should be allocated outside the loop.
+          //
+          // I use try_emplace here so shardByStream is called only when `out`
+          // is missing.
          auto [i, inserted] = replacement_map.try_emplace(
              out, hir::shardByStream(out, for_loop->index()));
          NVF_ERROR(inserted);
@@ -187,7 +260,7 @@ void lowerSegment(
        auto* tv = out->as<TensorView>();
        auto* allocate =
            IrBuilder::create<kir::Allocate>(tv, MemoryType::Global);
-        hic.pushBackTopLevelExprs(allocate);
+        loop_nest.innermostScope().push_back(allocate);
      }
 
      // Add the LaunchKernel instruction.
@@ -203,10 +276,56 @@ void lowerSegment(
          ins,
          outs,
          cache_id);
-      hic.pushBackTopLevelExprs(launch_kernel);
+      loop_nest.innermostScope().push_back(launch_kernel);
    }
  } // switch
} // lowerSegment
+
+// Finds the TensorView in the group whose loop domain has the most parallel
+// types and returns its loop domain.
+const std::vector<IterDomain*>& findReferenceLoopDomain(
+    const SegmentedGroup& group) {
+  TensorView* reference_tv = nullptr;
+  int max_parallel_count = -1;
+  for (auto* expr : group.exprs()) {
+    for (auto* tv : ir_utils::filterByType<TensorView>(expr->outputs())) {
+      auto loop_domain = tv->getLoopDomain();
+      int parallel_count = 0;
+      for (auto* id : loop_domain) {
+        if (id->isParallelized()) {
+          parallel_count++;
+        }
+      }
+      if (parallel_count > max_parallel_count) {
+        max_parallel_count = parallel_count;
+        reference_tv = tv;
+      }
+    }
+  }
+  NVF_ERROR(reference_tv != nullptr);
+  return reference_tv->getLoopDomain();
+}
+
+int64_t computeInlinePosition(
+    const std::vector<IterDomain*>& prev_ref_loop,
+    const std::vector<IterDomain*>& curr_ref_loop,
+    const IdModel& id_model) {
+  const auto& exact_graph = id_model.idGraph(IdMappingMode::EXACT);
+  int64_t inline_position = 0;
+  for (auto [prev_id, curr_id] : zip(prev_ref_loop, curr_ref_loop)) {
+    if (prev_id->getParallelType() != curr_id->getParallelType()) {
+      break;
+    }
+
+    if (!exact_graph.disjointValSets().strictAreMapped(prev_id, curr_id)) {
+      break;
+    }
+
+    inline_position++;
+  }
+
+  return inline_position;
+}
 } // namespace
 
 std::unique_ptr<hir::HostIrContainer> lowerSegmentedFusionToHostIr(
@@ -227,14 +346,42 @@ std::unique_ptr<hir::HostIrContainer> lowerSegmentedFusionToHostIr(
     hic->addKernelExecutor(std::unique_ptr<KernelExecutor>(ke));
   }
 
+  hir::LoopNest loop_nest(hic->topLevel());
+
+  IdModel id_model(segmented_fusion.completeFusion(), /*build_graphs=*/false);
+  id_model.buildExactGraph();
+
+  std::vector<IterDomain*> prev_ref_loop;
   for (SegmentedGroup* group :
        prepareRuntimeOrder(segmented_fusion).group_run_order) {
+    // Compute the inline position.
+    const std::vector<IterDomain*>& curr_ref_loop =
+        findReferenceLoopDomain(*group);
+    // Compute the inline position based on parallel type and ID mapping.
+
+    const int64_t inline_position =
+        computeInlinePosition(prev_ref_loop, curr_ref_loop, id_model);
+
+    while (loop_nest.size() > inline_position) {
+      loop_nest.closeLoop();
+    }
+    while (loop_nest.size() < std::ssize(curr_ref_loop) &&
+           curr_ref_loop.at(loop_nest.size())->isStream()) {
+      auto* stream_id = ir_cloner.clone(curr_ref_loop.at(loop_nest.size()));
+      loop_nest.openLoop(stream_id);
+    }
+
+    // FIXME: consider making HostIrLowering a class so `hic` and
+    // `segmented_fusion` can be made global.
     lowerSegment(
         *group,
         segmented_fusion.completeFusion()->getOutputAliases(),
         launch_params_per_segment.at(group->groupId()),
         *hic,
+        loop_nest,
         ir_cloner);
+
+    prev_ref_loop = std::move(curr_ref_loop);
   }
 
   hir_pass::InsertDeallocations().runPass(hic.get());

From 9cc28d1128910a6652a715221ee5e4280a7169e5 Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Thu, 13 Nov 2025 12:54:23 -0800
Subject: [PATCH 2/6] Test two matmuls

---
 csrc/runtime/allocations.cpp |  3 ++-
 tests/cpp/test_stream.cpp    | 45 ++++++++++++++++++++++++++++++++++++
 2 files changed, 47 insertions(+), 1 deletion(-)

diff --git a/csrc/runtime/allocations.cpp b/csrc/runtime/allocations.cpp
index 4bf20f3fc72..85b36125027 100644
--- a/csrc/runtime/allocations.cpp
+++ b/csrc/runtime/allocations.cpp
@@ -829,7 +829,8 @@ std::pair<std::vector<Val*>, std::vector<Val*>> inferAllocationShape(
       continue;
     }
 
-    if (id->isDeviceDim()) {
+    // FIXME: should this be isParallelized?
+    if (id->isDeviceDim() || id->isStream()) {
       symbolic_sizes.push_back(id->container()->oneVal());
     } else {
       symbolic_sizes.push_back(id->getMaybeExpandedExtent());
diff --git a/tests/cpp/test_stream.cpp b/tests/cpp/test_stream.cpp
index 68b6d2296f7..281fc6ffb9a 100644
--- a/tests/cpp/test_stream.cpp
+++ b/tests/cpp/test_stream.cpp
@@ -110,6 +110,51 @@ TEST_F(StreamTest, Matmul) {
       __FILE__);
 }
 
+TEST_F(StreamTest, TwoMatmuls) {
+  constexpr int64_t c = 3;
+
+  auto fusion = std::make_unique<Fusion>();
+  {
+    FusionGuard fg(fusion.get());
+    TensorView* in = makeSymbolicTensor(2);
+    TensorView* w1 = makeSymbolicTensor(2);
+    TensorView* w2 = makeSymbolicTensor(2);
+    TensorView* out = matmul(in, w1);
+    out = matmul(out, w2);
+    fusion->addInput(in);
+    fusion->addInput(w1);
+    fusion->addInput(w2);
+    fusion->addOutput(out);
+
+    in->outer_split(0, c);
+    in->axis(0)->parallelize(ParallelType::Stream);
+  }
+
+  {
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
+    at::Tensor in = at::randn({c * 2, 3}, options);
+    at::Tensor w1 = at::randn({3, 5}, options);
+    at::Tensor w2 = at::randn({5, 3}, options);
+
+    // With NVFUSER_DUMP=host_ir, you'll see the host IR container like the
+    // following:
+    // clang-format off
+    // %HostIrContainer { (T0_g_float[iS0{i0}, iS1{i2}], T1_g_float[istreamIdx7{3}, iS11{i2}, iS8{( ceilDiv(i4, 3) )}]) -> (T2_g_float[istreamIdx9{3}, iS4{i0}, iS10{( ceilDiv(i4, 3) )}, rS6{i2}]) :
+    //   FOR i18 from 0 to 3:
+    //     T2_g_float[istreamIdx9{3}, iS4{i0}, iS10{( ceilDiv(i4, 3) )}, rS6{i2}]
+    //        = matmul(T0_g_float[iS0{i0}, iS1{i2}],
+    //                 T1_g_float[istreamIdx7{3}, iS11{i2}, iS8{( ceilDiv(i4, 3) )}])
+    // } // %HostIrContainer
+    // clang-format on
+    FusionExecutorCache executor_cache(std::move(fusion));
+    auto out =
+        executor_cache.runFusionWithInputs({in, w1, w2})[0].as<at::Tensor>();
+
+    testValidate(
+        executor_cache.fusion(), {out}, {in, w1, w2}, __LINE__, __FILE__);
+  }
+}
+
 TEST_F(StreamTest, HaveDifferentShardings) {
   Fusion fusion;
   FusionGuard fg(&fusion);

From a10162aae1254ee11b9a039acbbd65a13fcf2c36 Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Thu, 13 Nov 2025 13:16:19 -0800
Subject: [PATCH 3/6] Hack for MultiDeviceExecutor

---
 csrc/runtime/allocations.cpp | 21 +++++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/csrc/runtime/allocations.cpp b/csrc/runtime/allocations.cpp
index 85b36125027..b78c946571d 100644
--- a/csrc/runtime/allocations.cpp
+++ b/csrc/runtime/allocations.cpp
@@ -829,12 +829,21 @@ std::pair<std::vector<Val*>, std::vector<Val*>> inferAllocationShape(
       continue;
     }
 
-    // FIXME: should this be isParallelized?
-    if (id->isDeviceDim() || id->isStream()) {
-      symbolic_sizes.push_back(id->container()->oneVal());
-    } else {
-      symbolic_sizes.push_back(id->getMaybeExpandedExtent());
-    }
+    symbolic_sizes.push_back([&]() {
+      if (id->isDeviceDim()) {
+        return id->container()->oneVal();
+      }
+
+      if (id->isStream()) {
+        // Hack for MultiDeviceExecutor.
+ if (std::ranges::find(tv->getLogicalDomain(), id) == + tv->getLogicalDomain().end()) { + return id->container()->oneVal(); + } + } + return id->getMaybeExpandedExtent(); + }()); + if (id->hasExpandedExtent()) { NVF_ERROR( id->isBroadcast(), From 5c861d7f0bbd52d4b7d34636f9c1699efe1ab892 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 13 Nov 2025 16:31:30 -0800 Subject: [PATCH 4/6] Cleanup --- csrc/host_ir/lowering.cpp | 53 ++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 29 deletions(-) diff --git a/csrc/host_ir/lowering.cpp b/csrc/host_ir/lowering.cpp index 85b70218a8c..983273ecfc0 100644 --- a/csrc/host_ir/lowering.cpp +++ b/csrc/host_ir/lowering.cpp @@ -17,48 +17,51 @@ namespace nvfuser { -namespace hir { -// FIXME: rename to LoopInfo? -// FIXME: can this be nested in LoopNest? -struct Frame { - ForLoop* loop; +namespace { + +struct LoopInfo { + hir::ForLoop* loop; Scope* parent_scope; Scope::Iterator parent_insertion_point; - friend std::ostream& operator<<(std::ostream& os, const Frame& frame); + friend std::ostream& operator<<(std::ostream& os, const LoopInfo& loop_info); }; +std::ostream& operator<<(std::ostream& os, const LoopInfo& loop_info) { + os << loop_info.loop->toInlineString(); + return os; +} + class LoopNest { - private: public: LoopNest(Scope& top_level) : top_level_(top_level) {} int64_t size() const { - return std::ssize(frames_); + return std::ssize(loop_infos_); } bool empty() const { - return frames_.empty(); + return loop_infos_.empty(); } void closeLoop() { NVF_ERROR(!empty()); - frames_.pop_back(); + loop_infos_.pop_back(); } - Frame& innermost() { + const LoopInfo& innermost() const { NVF_ERROR(!empty()); - return frames_.back(); + return loop_infos_.back(); } - Scope& innermostScope() { + Scope& innermostScope() const { return empty() ? top_level_ : innermost().loop->body(); } - ForLoop* openLoop(IterDomain* id) { + hir::ForLoop* openLoop(IterDomain* id) { Scope& parent_scope = innermostScope(); - auto* for_loop = ForLoop::createFromIterDomain(id); - frames_.push_back( + auto* for_loop = hir::ForLoop::createFromIterDomain(id); + loop_infos_.push_back( {for_loop, &parent_scope, parent_scope.push_back(for_loop)}); return for_loop; } @@ -66,26 +69,18 @@ class LoopNest { friend std::ostream& operator<<(std::ostream& os, const LoopNest& loop_nest); private: - std::vector frames_; + std::vector loop_infos_; Scope& top_level_; }; -std::ostream& operator<<(std::ostream& os, const Frame& frame) { - os << frame.loop->toInlineString(); - return os; -} - std::ostream& operator<<(std::ostream& os, const LoopNest& loop_nest) { os << "LoopNest:" << std::endl; - for (const auto& frame : loop_nest.frames_) { - indent(os, 1) << frame << frame.loop->toString() << std::endl; + for (const auto& loop_info : loop_nest.loop_infos_) { + indent(os, 1) << loop_info << std::endl; } return os; } -} // namespace hir - -namespace { // Finds the stream-parallelized IterDomain in the loop domain of a TensorView, // or nullptr if not found. This is different from `getShardedIterDomain(tv, // ParallelType::Stream)`, which searches the allocation domain. 
Consider @@ -115,7 +110,7 @@ void lowerSegment( const AliasInfoMap& aliases, const LaunchParams& launch_params, hir::HostIrContainer& hic, - hir::LoopNest& loop_nest, + LoopNest& loop_nest, IrCloner& ir_cloner) { switch (group.schedulerType()) { case SchedulerType::Communication: { @@ -346,7 +341,7 @@ std::unique_ptr lowerSegmentedFusionToHostIr( hic->addKernelExecutor(std::unique_ptr(ke)); } - hir::LoopNest loop_nest(hic->topLevel()); + LoopNest loop_nest(hic->topLevel()); IdModel id_model(segmented_fusion.completeFusion(), /*build_graphs=*/false); id_model.buildExactGraph(); From 7d26a00a966ca6d7baf5a0290631ab082bbff09a Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 13 Nov 2025 16:59:25 -0800 Subject: [PATCH 5/6] Comment --- csrc/host_ir/lowering.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/host_ir/lowering.cpp b/csrc/host_ir/lowering.cpp index 983273ecfc0..a27f35dab4e 100644 --- a/csrc/host_ir/lowering.cpp +++ b/csrc/host_ir/lowering.cpp @@ -134,7 +134,8 @@ void lowerSegment( if (tv->getDeviceMesh().has(device_id)) { auto* allocate = IrBuilder::create(tv, MemoryType::Global); - // FIXME: allocation may have to go to the top level. + // TODO: allocation may have to go to the top level. See how + // SchedulerType::ExprEval handles allocations. loop_nest.innermostScope().push_back(allocate); } loop_nest.innermostScope().push_back(communication); From 547da865cff71e250c314b133ce5a72e78fde328 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 13 Nov 2025 18:05:28 -0800 Subject: [PATCH 6/6] Refactor host IR lowering --- csrc/host_ir/lowering.cpp | 450 ++++++++++++++++++++------------------ 1 file changed, 233 insertions(+), 217 deletions(-) diff --git a/csrc/host_ir/lowering.cpp b/csrc/host_ir/lowering.cpp index a27f35dab4e..af4114c1a61 100644 --- a/csrc/host_ir/lowering.cpp +++ b/csrc/host_ir/lowering.cpp @@ -14,6 +14,7 @@ #include #include #include +#include namespace nvfuser { @@ -105,178 +106,6 @@ IterDomain* findStreamIterDomain(const std::vector& outs) { return nullptr; } -void lowerSegment( - const SegmentedGroup& group, - const AliasInfoMap& aliases, - const LaunchParams& launch_params, - hir::HostIrContainer& hic, - LoopNest& loop_nest, - IrCloner& ir_cloner) { - switch (group.schedulerType()) { - case SchedulerType::Communication: { - auto device_id = Communicator::getInstance().deviceId(); - NVF_ERROR_EQ( - group.exprs().size(), - 1, - "Communication segments must contain only one Expr."); - // If a value is already cloned, IrCloner::clone returns the cloned value - // without cloning the value again. - Expr* e = ir_cloner.clone(group.exprs().front()); - - for (auto* c : convertSingleOpToCommunication(e, device_id)) { - NVF_ERROR( - c->isA(), - "Exprs in a Communication group should be Communication: ", - c); - // Allocate the recv buffers of communications - auto* communication = c->as(); - TensorView* tv = communication->out(); - if (tv->getDeviceMesh().has(device_id)) { - auto* allocate = - IrBuilder::create(tv, MemoryType::Global); - // TODO: allocation may have to go to the top level. See how - // SchedulerType::ExprEval handles allocations. 
- loop_nest.innermostScope().push_back(allocate); - } - loop_nest.innermostScope().push_back(communication); - auto wait = IrBuilder::create(communication); - loop_nest.innermostScope().push_back(wait); - } - break; - } - case SchedulerType::ExprEval: { - // Pseudocode: - // clang-format off - // ``` - // if no expressions are stream parallelized: - // append the list to the top level - // return - // - // create a new, empty for loop - // for each expression in the segment: - // for each input TensorView of that expression: - // if it's allocated outside the loop: - // shard it by stream - // for each output TensorView of that expression: - // if it needs to be allocated outside the loop: - // create an Allocate before the for loop - // shard it by stream - // add the expression to the loop body with the maybe-sharded inputs and outputs - // ``` - // clang-format on - const std::vector& exprs = - ir_cloner.clone(group.stablyOrderedExprs()); - - std::vector outs = ir_cloner.clone(group.outputs()); - // All expressions in the group are expected to be stream parallelized in - // the same way. So it's safe to find the stream IterDomain from any of - // them. Ideally, loop domains should be tied to expressions not - // TensorViews. - IterDomain* stream_id = findStreamIterDomain(outs); - if (stream_id == nullptr) { - for (Expr* e : exprs) { - loop_nest.innermostScope().push_back(e); - } - break; - } - - auto [for_loop, parent_scope, parent_insertion_point] = - loop_nest.innermost(); - - std::unordered_map replacement_map; - for (Expr* e : exprs) { - for (auto* in : ir_utils::filterByType(e->inputs())) { - if (findStreamIterDomain(in) != nullptr && - getShardedIterDomain(in, ParallelType::Stream) == nullptr) { - auto [i, inserted] = replacement_map.try_emplace( - in, hir::shardByStream(in, for_loop->index())); - if (inserted) { - for_loop->body().push_back(i->second->definition()); - } - } - } - - for (auto* out : ir_utils::filterByType(e->outputs())) { - if (getShardedIterDomain(out, ParallelType::Stream) == nullptr) { - auto* allocate = - IrBuilder::create(out, MemoryType::Global); - parent_scope->insert(parent_insertion_point, allocate); - // Loop is stream parallelized but allocation is not. Therefore, - // `out` should be allocated outside the loop. - // - // I use try_emplace here so shardByStream is called only when `out` - // is missing. - auto [i, inserted] = replacement_map.try_emplace( - out, hir::shardByStream(out, for_loop->index())); - NVF_ERROR(inserted); - for_loop->body().push_back(i->second->definition()); - } - } - - std::vector new_inputs; - std::transform( - e->inputs().begin(), - e->inputs().end(), - std::back_inserter(new_inputs), - [&replacement_map](Val* input) { - return getOrDefault(replacement_map, input, input); - }); - std::vector new_outputs; - std::transform( - e->outputs().begin(), - e->outputs().end(), - std::back_inserter(new_outputs), - [&replacement_map](Val* output) { - return getOrDefault(replacement_map, output, output); - }); - Expr* new_e = e->newObjectFunc()( - e->container(), new_inputs, new_outputs, e->attributes()); - for_loop->body().push_back(new_e); - } - break; - } - default: { - std::vector ins = ir_cloner.clone(group.inputs()); - std::vector outs = ir_cloner.clone(group.outputs()); - - // Allocate the output TensorViews. 
- for (auto* out : outs) { - NVF_ERROR( - out->isA(), - "Output must be a TensorView but got ", - out); - const AliasInfo& alias = aliases.get(out); - NVF_ERROR_EQ( - alias.type, - AllocationType::New, - "Output ", - out->toString(), - " must not be an alias, got ", - alias); - auto* tv = out->as(); - auto* allocate = - IrBuilder::create(tv, MemoryType::Global); - loop_nest.innermostScope().push_back(allocate); - } - - // Add the LaunchKernel instruction. - const int group_id = group.groupId(); - KernelExecutor& ke = hic.getKernelExecutor(group_id); - // Needed for KernelExecutor. Should be removed once #4927 is fixed. - auto* cache_id = - IrBuilder::create("cacheId", DataType::UInt64); - auto launch_kernel = IrBuilder::create( - group_id, - launch_params, - ke.compiledKernel()->compileParams(), - ins, - outs, - cache_id); - loop_nest.innermostScope().push_back(launch_kernel); - } - } // switch -} // lowerSegment - // Finds the TensorView in the group whose loop domain has the most parallel // types and returns its loop domain. const std::vector& findReferenceLoopDomain( @@ -322,64 +151,251 @@ int64_t computeInlinePosition( return inline_position; } -} // namespace -std::unique_ptr lowerSegmentedFusionToHostIr( - const SegmentedFusion& segmented_fusion, - const std::vector& launch_params_per_segment, - std::vector>& executors) { - auto hic = std::make_unique(); - IrCloner ir_cloner = - Fusion::copy(segmented_fusion.completeFusion(), hic.get()); +class HostIrLowering { + public: + HostIrLowering( + const SegmentedFusion& segmented_fusion, + const std::vector& launch_params_per_segment, + std::vector>& executors, + hir::HostIrContainer& hic) + : segmented_fusion_(segmented_fusion), + launch_params_per_segment_(launch_params_per_segment), + hic_(hic), + ir_cloner_(&hic), + loop_nest_(hic.topLevel()) { + ir_cloner_ = Fusion::copy(segmented_fusion.completeFusion(), &hic_); + + for (auto& executor : executors) { + if (executor == nullptr) { + continue; + } + auto* ke = executor.release()->as(); + hic_.addKernelExecutor(std::unique_ptr(ke)); + } + } + + void lower() { + FusionGuard fg(&hic_); + + IdModel id_model( + segmented_fusion_.completeFusion(), /*build_graphs=*/false); + id_model.buildExactGraph(); - FusionGuard fg(hic.get()); + std::vector prev_ref_loop; + for (SegmentedGroup* group : + prepareRuntimeOrder(segmented_fusion_).group_run_order) { + // Compute the inline position. + const std::vector& curr_ref_loop = + findReferenceLoopDomain(*group); + // Compute the inline position based on parallel type and ID mapping. 
- for (auto& executor : executors) { - if (executor == nullptr) { - continue; + const int64_t inline_position = + computeInlinePosition(prev_ref_loop, curr_ref_loop, id_model); + + while (loop_nest_.size() > inline_position) { + loop_nest_.closeLoop(); + } + while (loop_nest_.size() < std::ssize(curr_ref_loop) && + curr_ref_loop.at(loop_nest_.size())->isStream()) { + auto* stream_id = ir_cloner_.clone(curr_ref_loop.at(loop_nest_.size())); + loop_nest_.openLoop(stream_id); + } + + lowerSegment(*group, launch_params_per_segment_.at(group->groupId())); + + prev_ref_loop = std::move(curr_ref_loop); } - auto* ke = executor.release()->as(); - hic->addKernelExecutor(std::unique_ptr(ke)); } - LoopNest loop_nest(hic->topLevel()); + void lowerSegment( + const SegmentedGroup& group, + const LaunchParams& launch_params) { + switch (group.schedulerType()) { + case SchedulerType::Communication: { + auto device_id = Communicator::getInstance().deviceId(); + NVF_ERROR_EQ( + group.exprs().size(), + 1, + "Communication segments must contain only one Expr."); + // If a value is already cloned, IrCloner::clone returns the cloned + // value without cloning the value again. + Expr* e = ir_cloner_.clone(group.exprs().front()); + + for (auto* c : convertSingleOpToCommunication(e, device_id)) { + NVF_ERROR( + c->isA(), + "Exprs in a Communication group should be Communication: ", + c); + // Allocate the recv buffers of communications + auto* communication = c->as(); + TensorView* tv = communication->out(); + if (tv->getDeviceMesh().has(device_id)) { + auto* allocate = + IrBuilder::create(tv, MemoryType::Global); + // TODO: allocation may have to go to the top level. See how + // SchedulerType::ExprEval handles allocations. + loop_nest_.innermostScope().push_back(allocate); + } + loop_nest_.innermostScope().push_back(communication); + auto wait = IrBuilder::create(communication); + loop_nest_.innermostScope().push_back(wait); + } + break; + } + case SchedulerType::ExprEval: { + // Pseudocode: + // clang-format off + // ``` + // if no expressions are stream parallelized: + // append the list to the top level + // return + // + // create a new, empty for loop + // for each expression in the segment: + // for each input TensorView of that expression: + // if it's allocated outside the loop: + // shard it by stream + // for each output TensorView of that expression: + // if it needs to be allocated outside the loop: + // create an Allocate before the for loop + // shard it by stream + // add the expression to the loop body with the maybe-sharded inputs and outputs + // ``` + // clang-format on + const std::vector& exprs = + ir_cloner_.clone(group.stablyOrderedExprs()); + + std::vector outs = ir_cloner_.clone(group.outputs()); + // All expressions in the group are expected to be stream parallelized + // in the same way. So it's safe to find the stream IterDomain from any + // of them. Ideally, loop domains should be tied to expressions not + // TensorViews. + IterDomain* stream_id = findStreamIterDomain(outs); + if (stream_id == nullptr) { + for (Expr* e : exprs) { + loop_nest_.innermostScope().push_back(e); + } + break; + } - IdModel id_model(segmented_fusion.completeFusion(), /*build_graphs=*/false); - id_model.buildExactGraph(); + auto [for_loop, parent_scope, parent_insertion_point] = + loop_nest_.innermost(); - std::vector prev_ref_loop; - for (SegmentedGroup* group : - prepareRuntimeOrder(segmented_fusion).group_run_order) { - // Compute the inline position. 
- const std::vector& curr_ref_loop = - findReferenceLoopDomain(*group); - // Compute the inline position based on parallel type and ID mapping. + std::unordered_map replacement_map; + for (Expr* e : exprs) { + for (auto* in : ir_utils::filterByType(e->inputs())) { + if (findStreamIterDomain(in) != nullptr && + getShardedIterDomain(in, ParallelType::Stream) == nullptr) { + auto [i, inserted] = replacement_map.try_emplace( + in, hir::shardByStream(in, for_loop->index())); + if (inserted) { + for_loop->body().push_back(i->second->definition()); + } + } + } - const int64_t inline_position = - computeInlinePosition(prev_ref_loop, curr_ref_loop, id_model); + for (auto* out : ir_utils::filterByType(e->outputs())) { + if (getShardedIterDomain(out, ParallelType::Stream) == nullptr) { + auto* allocate = + IrBuilder::create(out, MemoryType::Global); + parent_scope->insert(parent_insertion_point, allocate); + // Loop is stream parallelized but allocation is not. Therefore, + // `out` should be allocated outside the loop. + // + // I use try_emplace here so shardByStream is called only when + // `out` is missing. + auto [i, inserted] = replacement_map.try_emplace( + out, hir::shardByStream(out, for_loop->index())); + NVF_ERROR(inserted); + for_loop->body().push_back(i->second->definition()); + } + } - while (loop_nest.size() > inline_position) { - loop_nest.closeLoop(); - } - while (loop_nest.size() < std::ssize(curr_ref_loop) && - curr_ref_loop.at(loop_nest.size())->isStream()) { - auto* stream_id = ir_cloner.clone(curr_ref_loop.at(loop_nest.size())); - loop_nest.openLoop(stream_id); - } + std::vector new_inputs; + std::transform( + e->inputs().begin(), + e->inputs().end(), + std::back_inserter(new_inputs), + [&replacement_map](Val* input) { + return getOrDefault(replacement_map, input, input); + }); + std::vector new_outputs; + std::transform( + e->outputs().begin(), + e->outputs().end(), + std::back_inserter(new_outputs), + [&replacement_map](Val* output) { + return getOrDefault(replacement_map, output, output); + }); + Expr* new_e = e->newObjectFunc()( + e->container(), new_inputs, new_outputs, e->attributes()); + for_loop->body().push_back(new_e); + } + break; + } + default: { + std::vector ins = ir_cloner_.clone(group.inputs()); + std::vector outs = ir_cloner_.clone(group.outputs()); + + // Allocate the output TensorViews. + for (auto* out : outs) { + NVF_ERROR( + out->isA(), + "Output must be a TensorView but got ", + out); + const AliasInfo& alias = + segmented_fusion_.completeFusion()->getOutputAliases().get(out); + NVF_ERROR_EQ( + alias.type, + AllocationType::New, + "Output ", + out->toString(), + " must not be an alias, got ", + alias); + auto* tv = out->as(); + auto* allocate = + IrBuilder::create(tv, MemoryType::Global); + loop_nest_.innermostScope().push_back(allocate); + } - // FIXME: consider making HostIrLowering a class so `hic` and - // `segmented_fusion` can be made global. - lowerSegment( - *group, - segmented_fusion.completeFusion()->getOutputAliases(), - launch_params_per_segment.at(group->groupId()), - *hic, - loop_nest, - ir_cloner); - - prev_ref_loop = std::move(curr_ref_loop); + // Add the LaunchKernel instruction. + const int group_id = group.groupId(); + KernelExecutor& ke = hic_.getKernelExecutor(group_id); + // Needed for KernelExecutor. Should be removed once #4927 is fixed. 
+ auto* cache_id = + IrBuilder::create("cacheId", DataType::UInt64); + auto launch_kernel = IrBuilder::create( + group_id, + launch_params, + ke.compiledKernel()->compileParams(), + ins, + outs, + cache_id); + loop_nest_.innermostScope().push_back(launch_kernel); + } + } // switch } + private: + const SegmentedFusion& segmented_fusion_; + const std::vector& launch_params_per_segment_; + hir::HostIrContainer& hic_; + IrCloner ir_cloner_; + LoopNest loop_nest_; +}; + +} // namespace + +std::unique_ptr lowerSegmentedFusionToHostIr( + const SegmentedFusion& segmented_fusion, + const std::vector& launch_params_per_segment, + std::vector>& executors) { + auto hic = std::make_unique(); + + HostIrLowering(segmented_fusion, launch_params_per_segment, executors, *hic) + .lower(); + hir_pass::InsertDeallocations().runPass(hic.get()); return hic;