diff --git a/CMakeLists.txt b/CMakeLists.txt index 758fe426db2..7aed3adc11e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1173,17 +1173,17 @@ if(BUILD_TEST) list(APPEND MULTIDEVICE_TEST_SRCS ${NVFUSER_ROOT}/tests/cpp/multidevice.cpp ${NVFUSER_ROOT}/tests/cpp/multidevice_transformer.cpp - ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir_overlap.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_communications.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_communicator.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir.cpp + ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir_overlap.cpp + ${NVFUSER_ROOT}/tests/cpp/test_multidevice_ipc.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_lower_communication.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_matmul.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_pipeline.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_sharding.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_stream_parallel_type.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_transformer.cpp - ${NVFUSER_ROOT}/tests/cpp/test_multidevice_ipc.cpp ) add_test_without_main(test_multidevice "${MULTIDEVICE_TEST_SRCS}" "") list(APPEND TEST_BINARIES test_multidevice) diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp index 82949a531e2..11f88f223fd 100644 --- a/csrc/host_ir/evaluator.cpp +++ b/csrc/host_ir/evaluator.cpp @@ -482,14 +482,15 @@ void HostIrEvaluator::handle(MatmulOp* matmul) { TensorView* b = matmul->inB(); TensorView* out = matmul->out(); - if (expr_evaluator_.isKnown(out)) { - auto t_a = getKnownConcreteValue(a).as(); - auto t_b = getKnownConcreteValue(b).as(); - auto t_out = getKnownConcreteValue(out).as(); - at::matmul_out(t_out, t_a, t_b); - } else { + if (!matmul->outputIsPreallocated()) { unhandled(matmul); + return; } + + auto t_a = getKnownConcreteValue(a).as(); + auto t_b = getKnownConcreteValue(b).as(); + auto t_out = getKnownConcreteValue(out).as(); + at::matmul_out(t_out, t_a, t_b); } void HostIrEvaluator::handle(LinearOp* linear) { @@ -497,9 +498,8 @@ void HostIrEvaluator::handle(LinearOp* linear) { auto* weight = linear->inB()->as(); auto* out = linear->out()->as(); - if (!expr_evaluator_.isKnown(out)) { - unhandled(linear); - return; + if (!linear->outputIsPreallocated()) { + return unhandled(linear); } auto in_tensor = getKnownConcreteValue(in).as(); @@ -753,7 +753,7 @@ void HostIrEvaluator::handle(ShardByStream* shard) { IterDomain* stream_id = *i; auto in_tensor = getKnownConcreteValue(shard->in()).as(); - int64_t stream_index = + auto stream_index = expr_evaluator_.evaluate(shard->stream_index()).as(); at::Tensor out_tensor = in_tensor diff --git a/csrc/host_ir/lowering.cpp b/csrc/host_ir/lowering.cpp index 90a32c25748..8eec3309c72 100644 --- a/csrc/host_ir/lowering.cpp +++ b/csrc/host_ir/lowering.cpp @@ -20,10 +20,10 @@ namespace nvfuser { namespace { struct LoopInfo { - hir::ForLoop* loop; + hir::ForLoop* loop = nullptr; // The Scope that owns `loop`. It's one level outer than `loop`'s body scope. - Scope* parent_scope; + Scope* parent_scope = nullptr; // The iterator that points to `loop`. This way, we can insert instructions, // e.g. Allocate, right before the loop. @@ -31,7 +31,11 @@ struct LoopInfo { }; std::ostream& operator<<(std::ostream& os, const LoopInfo& loop_info) { - os << loop_info.loop->toInlineString(); + if (loop_info.loop == nullptr) { + os << ""; + } else { + os << loop_info.loop->toInlineString(); + } return os; } @@ -114,7 +118,10 @@ const std::vector& findReferenceLoopDomain( // Expr. Expr* cloneWithNewOperands( Expr* e, - const std::unordered_map& replacement_map) { + const std::unordered_map& replacement_map, + bool output_is_preallocated) { + NVF_ERROR(!e->outputIsPreallocated()); + auto maybe_replace = [&](Val*& x) -> bool { Val* new_x = getOrDefault(replacement_map, x); if (new_x == nullptr) { @@ -132,11 +139,16 @@ Expr* cloneWithNewOperands( std::vector new_outs = e->outputs(); replaced += std::ranges::count_if(new_outs, maybe_replace); - if (replaced == 0) { + if (replaced == 0 && !output_is_preallocated) { return e; } - return e->newObjectFunc()(e->container(), new_ins, new_outs, e->attributes()); + Expr* new_e = + e->newObjectFunc()(e->container(), new_ins, new_outs, e->attributes()); + if (output_is_preallocated) { + new_e = new_e->withOutputPreallocated(); + } + return new_e; } void lowerSegment( @@ -146,6 +158,14 @@ void lowerSegment( hir::HostIrContainer& hic, LoopNest& loop_nest, IrCloner& ir_cloner) { + Scope& innermost_scope = loop_nest.innermostScope(); + // FIXME: cleanup. innermost can return an empty LoopInfo when the nest is + // empty. + LoopInfo innermost; + if (!loop_nest.empty()) { + innermost = loop_nest.innermost(); + } + switch (group.schedulerType()) { case SchedulerType::Communication: { auto device_id = Communicator::getInstance().deviceId(); @@ -157,24 +177,50 @@ void lowerSegment( // without cloning the value again. Expr* e = ir_cloner.clone(group.exprs().front()); - for (auto* c : convertSingleOpToCommunication(e, device_id)) { + // FIXME: should this be associated with the scope? + std::unordered_map replacement_map; + for (Expr* c : convertSingleOpToCommunication(e, device_id)) { NVF_ERROR( c->isA(), "Exprs in a Communication group should be Communication: ", c); - // Allocate the recv buffers of communications auto* communication = c->as(); - TensorView* tv = communication->out(); - if (tv->getDeviceMesh().has(device_id)) { - auto* allocate = - IrBuilder::create(tv, MemoryType::Global); - // TODO: allocation may have to go to the top level. See how - // SchedulerType::ExprEval handles allocations. - loop_nest.innermostScope().push_back(allocate); + TensorView* in = communication->in(); + TensorView* out = communication->out(); + if (getShardedIterDomain(in, ParallelType::Stream, DomainType::kLoop) != + nullptr && + getShardedIterDomain( + in, ParallelType::Stream, DomainType::kAllocation) == nullptr) { + auto [i, inserted] = replacement_map.try_emplace( + in, hir::shardByStream(in, innermost.loop->index())); + if (inserted) { + innermost_scope.push_back(i->second->definition()); + } } - loop_nest.innermostScope().push_back(communication); - auto wait = IrBuilder::create(communication); - loop_nest.innermostScope().push_back(wait); + + // Allocate the recv buffers of communications + auto* allocate = + IrBuilder::create(out, MemoryType::Global); + if (getShardedIterDomain( + out, ParallelType::Stream, DomainType::kLoop) != nullptr && + getShardedIterDomain( + out, ParallelType::Stream, DomainType::kAllocation) == + nullptr) { + innermost.parent_scope->insert( + innermost.parent_insertion_point, allocate); + auto [i, inserted] = replacement_map.try_emplace( + out, hir::shardByStream(out, innermost.loop->index())); + NVF_ERROR(inserted); + innermost_scope.push_back(i->second->definition()); + } else { + innermost_scope.push_back(allocate); + } + + Expr* new_c = cloneWithNewOperands(c, replacement_map, true); + innermost_scope.push_back(new_c); + + auto* wait = IrBuilder::create(new_c); + innermost_scope.push_back(wait); } break; } @@ -206,14 +252,11 @@ void lowerSegment( // TensorViews. if (loop_nest.empty()) { for (Expr* e : exprs) { - loop_nest.innermostScope().push_back(e); + innermost_scope.push_back(e); } break; } - auto [for_loop, parent_scope, parent_insertion_point] = - loop_nest.innermost(); - std::unordered_map replacement_map; for (Expr* e : exprs) { for (auto* in : ir_utils::filterByType(e->inputs())) { @@ -223,34 +266,38 @@ void lowerSegment( in, ParallelType::Stream, DomainType::kAllocation) == nullptr) { auto [i, inserted] = replacement_map.try_emplace( - in, hir::shardByStream(in, for_loop->index())); + in, hir::shardByStream(in, innermost.loop->index())); if (inserted) { - for_loop->body().push_back(i->second->definition()); + innermost_scope.push_back(i->second->definition()); } } } + bool output_is_preallocated = false; for (auto* out : ir_utils::filterByType(e->outputs())) { if (getShardedIterDomain( out, ParallelType::Stream, DomainType::kAllocation) == nullptr) { auto* allocate = IrBuilder::create(out, MemoryType::Global); - parent_scope->insert(parent_insertion_point, allocate); + output_is_preallocated = true; + innermost.parent_scope->insert( + innermost.parent_insertion_point, allocate); // Loop is stream parallelized but allocation is not. Therefore, // `out` should be allocated outside the loop. // // I use try_emplace here so shardByStream is called only when `out` // is missing. auto [i, inserted] = replacement_map.try_emplace( - out, hir::shardByStream(out, for_loop->index())); + out, hir::shardByStream(out, innermost.loop->index())); NVF_ERROR(inserted); - for_loop->body().push_back(i->second->definition()); + innermost_scope.push_back(i->second->definition()); } } - Expr* new_e = cloneWithNewOperands(e, replacement_map); - for_loop->body().push_back(new_e); + Expr* new_e = + cloneWithNewOperands(e, replacement_map, output_is_preallocated); + innermost_scope.push_back(new_e); } break; } @@ -275,7 +322,7 @@ void lowerSegment( auto* tv = out->as(); auto* allocate = IrBuilder::create(tv, MemoryType::Global); - loop_nest.innermostScope().push_back(allocate); + innermost_scope.push_back(allocate); } // Add the LaunchKernel instruction. @@ -291,7 +338,7 @@ void lowerSegment( ins, outs, cache_id); - loop_nest.innermostScope().push_back(launch_kernel); + innermost_scope.push_back(launch_kernel); } } // switch } // lowerSegment diff --git a/csrc/host_ir/pass/stream_parallel_type.cpp b/csrc/host_ir/pass/stream_parallel_type.cpp index 7b9a60a5eeb..1c135c8d1a7 100644 --- a/csrc/host_ir/pass/stream_parallel_type.cpp +++ b/csrc/host_ir/pass/stream_parallel_type.cpp @@ -475,6 +475,7 @@ std::list processForLoopBodies( ir_utils::filterByType(body_expr->outputs())) { processTensor(body_expr, output, tensor_index); } + body_expr = body_expr->withOutputPreallocated(); new_loop_body.push_back(body_expr); } } diff --git a/csrc/ir/base_nodes.cpp b/csrc/ir/base_nodes.cpp index 26dc0d31339..d06860bba17 100644 --- a/csrc/ir/base_nodes.cpp +++ b/csrc/ir/base_nodes.cpp @@ -253,6 +253,7 @@ std::optional Val::getDataType() const { // after inputs and outputs are registered with the Expr Expr::Expr(IrBuilderPasskey passkey) : Statement(passkey) {} +// FIXME: Should this constructor copy the output_is_preallocated_ flag? Expr::Expr(const Expr* src, IrCloner* ir_cloner) : Statement(src, ir_cloner), attributes_(ir_cloner->clone(src->attributes_)), @@ -270,12 +271,13 @@ Expr::Expr( outputs_(std::move(outputs)) {} Expr* Expr::shallowCopy() const { - auto result = + Expr* result = newObjectFunc()(ir_container_, inputs(), outputs(), attributes()); if (container()->isA()) { result->predicate_ = predicate_; result->write_predicate_ = write_predicate_; } + result->output_is_preallocated_ = output_is_preallocated_; return result; } @@ -383,6 +385,11 @@ Expr* Expr::withWritePredicate(kir::Predicate* predicate) { return result; } +Expr* Expr::withOutputPreallocated() { + output_is_preallocated_ = true; + return this; +} + std::vector Expr::evaluate( const ExpressionEvaluator& ee, const std::vector& inputs) const { diff --git a/csrc/ir/base_nodes.h b/csrc/ir/base_nodes.h index d3ebdc5807c..e42afa8dba3 100644 --- a/csrc/ir/base_nodes.h +++ b/csrc/ir/base_nodes.h @@ -599,6 +599,12 @@ class NVF_API Expr : public Statement { // TODO: Protect based on being in kernel container Expr* withWritePredicate(kir::Predicate* write_predicate); + bool outputIsPreallocated() const { + return output_is_preallocated_; + } + + Expr* withOutputPreallocated(); + // Get the name of an expression virtual const char* getOpString() const = 0; @@ -660,6 +666,8 @@ class NVF_API Expr : public Statement { // Only used for reduction-related expressions kir::Predicate* write_predicate_ = nullptr; + + bool output_is_preallocated_ = false; }; template diff --git a/csrc/multidevice/communicator.h b/csrc/multidevice/communicator.h index f422b90086e..9379bedf4d6 100644 --- a/csrc/multidevice/communicator.h +++ b/csrc/multidevice/communicator.h @@ -62,12 +62,12 @@ class NVF_API Communicator { } // returns the number of processes in the communicator - auto size() const { + int64_t size() const { return size_; } // returns the local number of processes in the communicator (within the node) - auto local_size() const { + int64_t local_size() const { return local_size_; } @@ -89,7 +89,7 @@ class NVF_API Communicator { const std::string& prefix = ""); // returns the device associated with the current process - auto device() const { + at::Device device() const { return at::Device("cuda:" + std::to_string(local_rank_)); } diff --git a/csrc/runtime/fusion_kernel_runtime.cpp b/csrc/runtime/fusion_kernel_runtime.cpp index 962b540d675..9b2edc4150c 100644 --- a/csrc/runtime/fusion_kernel_runtime.cpp +++ b/csrc/runtime/fusion_kernel_runtime.cpp @@ -7,6 +7,8 @@ // clang-format on #include +#include + #include #include #include @@ -25,8 +27,6 @@ #include #include -#include - namespace nvfuser { namespace { diff --git a/tests/cpp/test_host_ir_evaluator.cpp b/tests/cpp/test_host_ir_evaluator.cpp index bebbc13c5a3..744425d7dd4 100644 --- a/tests/cpp/test_host_ir_evaluator.cpp +++ b/tests/cpp/test_host_ir_evaluator.cpp @@ -159,7 +159,8 @@ TEST_F(HostIrEvaluatorTest, MatmulInLoop) { // By default, MatmulOp is computed by ExpressionEvaluator so it appears in // host IR. - auto* mm = IrBuilder::create(loop_out, in, loop_w); + auto* mm = IrBuilder::create(loop_out, in, loop_w) + ->withOutputPreallocated(); for_loop->body().push_back(mm); hic->pushBackTopLevelExprs(allocate_out); diff --git a/tests/cpp/test_host_ir_stream_lowering.cpp b/tests/cpp/test_host_ir_stream_lowering.cpp index 6f2c322867a..e17c2bdd5cb 100644 --- a/tests/cpp/test_host_ir_stream_lowering.cpp +++ b/tests/cpp/test_host_ir_stream_lowering.cpp @@ -6,9 +6,6 @@ */ // clang-format on -#include -#include - #include #include diff --git a/tests/cpp/test_host_irs.cpp b/tests/cpp/test_host_irs.cpp index b932cc9332d..2292d858671 100644 --- a/tests/cpp/test_host_irs.cpp +++ b/tests/cpp/test_host_irs.cpp @@ -874,7 +874,8 @@ TEST_F(MatmulHostIrTest, HostIrMatmulOut) { TensorView* tv0 = makeContigTensor(3); TensorView* tv1 = makeContigTensor(3); TensorView* tv2 = makeContigTensor(3); - auto* matmul = IrBuilder::create(tv2, tv0, tv1); + auto* matmul = + IrBuilder::create(tv2, tv0, tv1)->withOutputPreallocated(); hic->addInput(tv0); hic->addInput(tv1); @@ -956,7 +957,8 @@ TEST_F(LinearHostIrTest, HostIrLinearOut) { TensorView* bias = makeContigTensor(1); TensorView* out = makeContigTensor(3); - auto linear_op = IrBuilder::create(out, in, weight, bias); + auto* linear_op = IrBuilder::create(out, in, weight, bias) + ->withOutputPreallocated(); hic->addInput(in); hic->addInput(weight); diff --git a/tests/cpp/test_multidevice_stream_parallel_type.cpp b/tests/cpp/test_multidevice_stream_parallel_type.cpp index 17a2739b8c6..8280a394262 100644 --- a/tests/cpp/test_multidevice_stream_parallel_type.cpp +++ b/tests/cpp/test_multidevice_stream_parallel_type.cpp @@ -5,8 +5,6 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#include - #include #include @@ -24,7 +22,6 @@ namespace nvfuser { using testing::ElementsAre; -using testing::SizeIs; using MultiDeviceStreamParallelTypeTest = MultiDeviceTest; diff --git a/tests/python/multidevice/test_overlap.py b/tests/python/multidevice/test_overlap.py index 30a8d8d83ea..0db27b32697 100644 --- a/tests/python/multidevice/test_overlap.py +++ b/tests/python/multidevice/test_overlap.py @@ -10,6 +10,49 @@ from nvfuser_direct import DataType, FusionDefinition, CommunicatorBackend, TensorView +@pytest.mark.mpi +def test_row_parallel_linear_forward(multidevice_direct_test): + # This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward. + h, s, t = 2, 3, 6 + d = multidevice_direct_test.size + if (h * 4) % d != 0: + pytest.skip( + f"Row-parallel linear requires {h * 4} to be divisible by world size {d}." + ) + assert t % s == 0 + + mesh = nvfuser.multidevice.DeviceMesh(range(d)) + + with FusionDefinition() as fd: + inp = fd.define_tensor( + shape=[-1, h * 4], contiguity=True, dtype=DataType.BFloat16 + ) + weight = fd.define_tensor( + shape=[h, h * 4], contiguity=True, dtype=DataType.BFloat16 + ) + out = fd.ops.linear(inp, weight) + fd.add_output(out) + + for tv in (inp, weight): + tv.set_device_mesh(mesh) + + inp.split(0, s, inner_split=False) + inp.axis(0).parallelize(nvfuser.ParallelType.stream) + inp.split(2, d, inner_split=False) + inp.axis(2).parallelize(nvfuser.ParallelType.mesh_x) + weight.split(1, d, inner_split=False) + weight.axis(1).parallelize(nvfuser.ParallelType.mesh_x) + + inp_ref = torch.randint(-2, 3, (t, h * 4), dtype=torch.int32).to(torch.bfloat16) + weight_ref = torch.randint(-2, 3, (h, h * 4), dtype=torch.int32).to(torch.bfloat16) + out_ref = torch.nn.functional.linear(inp_ref, weight_ref) + + inp = (multidevice_direct_test.shard_tensor(inp_ref, -1, mesh),) + weight = (multidevice_direct_test.shard_tensor(weight_ref, -1, mesh),) + (out,) = fd.execute([inp, weight], _enable_options=["host_ir_lowering"]) + torch.testing.assert_close(out.cpu(), out_ref) + + @pytest.mark.mpi @pytest.mark.parametrize("backend_type", [CommunicatorBackend.nccl]) @pytest.mark.parametrize("s", [1, 8])