Commit fd5af5d: "WIP"
1 parent 7b88601

8 files changed (+116, -40 lines)

CMakeLists.txt
Lines changed: 2 additions & 2 deletions

@@ -1173,17 +1173,17 @@ if(BUILD_TEST)
   list(APPEND MULTIDEVICE_TEST_SRCS
     ${NVFUSER_ROOT}/tests/cpp/multidevice.cpp
     ${NVFUSER_ROOT}/tests/cpp/multidevice_transformer.cpp
-    ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir_overlap.cpp
     ${NVFUSER_ROOT}/tests/cpp/test_multidevice_communications.cpp
     ${NVFUSER_ROOT}/tests/cpp/test_multidevice_communicator.cpp
     ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir.cpp
+    ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir_overlap.cpp
+    ${NVFUSER_ROOT}/tests/cpp/test_multidevice_ipc.cpp
     ${NVFUSER_ROOT}/tests/cpp/test_multidevice_lower_communication.cpp
     ${NVFUSER_ROOT}/tests/cpp/test_multidevice_matmul.cpp
     ${NVFUSER_ROOT}/tests/cpp/test_multidevice_pipeline.cpp
     ${NVFUSER_ROOT}/tests/cpp/test_multidevice_sharding.cpp
     ${NVFUSER_ROOT}/tests/cpp/test_multidevice_stream_parallel_type.cpp
     ${NVFUSER_ROOT}/tests/cpp/test_multidevice_transformer.cpp
-    ${NVFUSER_ROOT}/tests/cpp/test_multidevice_ipc.cpp
   )
   add_test_without_main(test_multidevice "${MULTIDEVICE_TEST_SRCS}" "")
   list(APPEND TEST_BINARIES test_multidevice)

csrc/host_ir/evaluator.cpp
Lines changed: 1 addition & 1 deletion

@@ -777,7 +777,7 @@ void HostIrEvaluator::handle(ShardByStream* shard) {
   IterDomain* stream_id = *i;

   auto in_tensor = getKnownConcreteValue(shard->in()).as<at::Tensor>();
-  int64_t stream_index =
+  auto stream_index =
       expr_evaluator_.evaluate(shard->stream_index()).as<int64_t>();
   at::Tensor out_tensor =
       in_tensor

csrc/host_ir/host_ir.cpp
Lines changed: 1 addition & 1 deletion

@@ -271,7 +271,7 @@ Wait::Wait(IrBuilderPasskey passkey, Expr* expr)
   NVF_ERROR(
       (expr->isOneOf<Communication, P2PCommunication, EndCoalescing>()),
       expr,
-      "must be a Communication, a P2PCommunication, or a EndCoalescing");
+      " must be a Communication, a P2PCommunication, or a EndCoalescing");
 }

 NVFUSER_DEFINE_CLONE_AND_CREATE(Wait)

csrc/host_ir/lowering.cpp
Lines changed: 64 additions & 28 deletions

@@ -20,18 +20,22 @@ namespace nvfuser {
 namespace {

 struct LoopInfo {
-  hir::ForLoop* loop;
+  hir::ForLoop* loop = nullptr;

   // The Scope that owns `loop`. It's one level outer than `loop`'s body scope.
-  Scope* parent_scope;
+  Scope* parent_scope = nullptr;

   // The iterator that points to `loop`. This way, we can insert instructions,
   // e.g. Allocate, right before the loop.
   Scope::Iterator parent_insertion_point;
 };

 std::ostream& operator<<(std::ostream& os, const LoopInfo& loop_info) {
-  os << loop_info.loop->toInlineString();
+  if (loop_info.loop == nullptr) {
+    os << "<null>";
+  } else {
+    os << loop_info.loop->toInlineString();
+  }
   return os;
 }

@@ -131,7 +135,7 @@ Expr* cloneWithNewOperands(
   int64_t out_replaced = std::ranges::count_if(new_outs, maybe_replace);

   if (in_replaced == 0 && out_replaced == 0) {
-    return 0;
+    return e;
   }

   if (out_replaced > 0) {
@@ -151,6 +155,14 @@ void lowerSegment(
     hir::HostIrContainer& hic,
     LoopNest& loop_nest,
     IrCloner& ir_cloner) {
+  Scope& innermost_scope = loop_nest.innermostScope();
+  // FIXME: cleanup. innermost can return an empty LoopInfo when the nest is
+  // empty.
+  LoopInfo innermost;
+  if (!loop_nest.empty()) {
+    innermost = loop_nest.innermost();
+  }
+
   switch (group.schedulerType()) {
     case SchedulerType::Communication: {
       auto device_id = Communicator::getInstance().deviceId();
@@ -162,24 +174,50 @@ void lowerSegment(
       // without cloning the value again.
       Expr* e = ir_cloner.clone(group.exprs().front());

-      for (auto* c : convertSingleOpToCommunication(e, device_id)) {
+      // FIXME: should this be associated with the scope?
+      std::unordered_map<Val*, Val*> replacement_map;
+      for (Expr* c : convertSingleOpToCommunication(e, device_id)) {
         NVF_ERROR(
             c->isA<Communication>(),
             "Exprs in a Communication group should be Communication: ",
             c);
-        // Allocate the recv buffers of communications
         auto* communication = c->as<Communication>();
-        TensorView* tv = communication->out();
-        if (tv->getDeviceMesh().has(device_id)) {
-          auto* allocate =
-              IrBuilder::create<kir::Allocate>(tv, MemoryType::Global);
-          // TODO: allocation may have to go to the top level. See how
-          // SchedulerType::ExprEval handles allocations.
-          loop_nest.innermostScope().push_back(allocate);
+        TensorView* in = communication->in();
+        TensorView* out = communication->out();
+        if (getShardedIterDomain(in, ParallelType::Stream, DomainType::kLoop) !=
+                nullptr &&
+            getShardedIterDomain(
+                in, ParallelType::Stream, DomainType::kAllocation) == nullptr) {
+          auto [i, inserted] = replacement_map.try_emplace(
+              in, hir::shardByStream(in, innermost.loop->index()));
+          if (inserted) {
+            innermost_scope.push_back(i->second->definition());
+          }
         }
-        loop_nest.innermostScope().push_back(communication);
-        auto wait = IrBuilder::create<hir::Wait>(communication);
-        loop_nest.innermostScope().push_back(wait);
+
+        // Allocate the recv buffers of communications
+        auto* allocate =
+            IrBuilder::create<kir::Allocate>(out, MemoryType::Global);
+        if (getShardedIterDomain(
+                out, ParallelType::Stream, DomainType::kLoop) != nullptr &&
+            getShardedIterDomain(
+                out, ParallelType::Stream, DomainType::kAllocation) ==
+                nullptr) {
+          innermost.parent_scope->insert(
+              innermost.parent_insertion_point, allocate);
+          auto [i, inserted] = replacement_map.try_emplace(
+              out, hir::shardByStream(out, innermost.loop->index()));
+          NVF_ERROR(inserted);
+          innermost_scope.push_back(i->second->definition());
+        } else {
+          innermost_scope.push_back(allocate);
+        }
+
+        Expr* new_c = cloneWithNewOperands(c, replacement_map);
+        innermost_scope.push_back(new_c);
+
+        auto* wait = IrBuilder::create<hir::Wait>(new_c);
+        innermost_scope.push_back(wait);
       }
       break;
     }
@@ -211,14 +249,11 @@ void lowerSegment(
       // TensorViews.
       if (loop_nest.empty()) {
        for (Expr* e : exprs) {
-          loop_nest.innermostScope().push_back(e);
+          innermost_scope.push_back(e);
        }
        break;
       }

-      auto [for_loop, parent_scope, parent_insertion_point] =
-          loop_nest.innermost();
-
       std::unordered_map<Val*, Val*> replacement_map;
       for (Expr* e : exprs) {
         for (auto* in : ir_utils::filterByType<TensorView>(e->inputs())) {
@@ -228,9 +263,9 @@ void lowerSegment(
                   in, ParallelType::Stream, DomainType::kAllocation) ==
               nullptr) {
             auto [i, inserted] = replacement_map.try_emplace(
-                in, hir::shardByStream(in, for_loop->index()));
+                in, hir::shardByStream(in, innermost.loop->index()));
             if (inserted) {
-              for_loop->body().push_back(i->second->definition());
+              innermost_scope.push_back(i->second->definition());
             }
           }
         }
@@ -241,21 +276,22 @@ void lowerSegment(
                 nullptr) {
           auto* allocate =
               IrBuilder::create<kir::Allocate>(out, MemoryType::Global);
-          parent_scope->insert(parent_insertion_point, allocate);
+          innermost.parent_scope->insert(
+              innermost.parent_insertion_point, allocate);
           // Loop is stream parallelized but allocation is not. Therefore,
           // `out` should be allocated outside the loop.
           //
           // I use try_emplace here so shardByStream is called only when `out`
           // is missing.
           auto [i, inserted] = replacement_map.try_emplace(
-              out, hir::shardByStream(out, for_loop->index()));
+              out, hir::shardByStream(out, innermost.loop->index()));
           NVF_ERROR(inserted);
-          for_loop->body().push_back(i->second->definition());
+          innermost_scope.push_back(i->second->definition());
         }
       }

       Expr* new_e = cloneWithNewOperands(e, replacement_map);
-      for_loop->body().push_back(new_e);
+      innermost_scope.push_back(new_e);
     }
     break;
   }
@@ -280,7 +316,7 @@ void lowerSegment(
         auto* tv = out->as<TensorView>();
         auto* allocate =
             IrBuilder::create<kir::Allocate>(tv, MemoryType::Global);
-        loop_nest.innermostScope().push_back(allocate);
+        innermost_scope.push_back(allocate);
       }

       // Add the LaunchKernel instruction.
@@ -296,7 +332,7 @@ void lowerSegment(
           ins,
           outs,
           cache_id);
-      loop_nest.innermostScope().push_back(launch_kernel);
+      innermost_scope.push_back(launch_kernel);
     }
   } // switch
 } // lowerSegment
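The lowering change above threads a replacement map through the segment: a tensor that is stream-parallelized in its loop domain but not in its allocation domain gets a per-iteration view via hir::shardByStream, its allocation is hoisted in front of the loop, and each expression is then re-created against that map with cloneWithNewOperands (which, after the fix, returns the original expression untouched when nothing was remapped). The following standalone sketch illustrates only that map-and-clone pattern; Tensor, Expr, and the two helper functions below are illustrative stand-ins, not nvFuser's API.

// Standalone sketch of the replacement-map pattern used in lowerSegment().
// Everything here (Tensor, Expr, shardByStream, cloneWithNewOperands) is an
// illustrative stand-in, not nvFuser code.
#include <deque>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct Tensor {
  std::string name;
};

struct Expr {
  std::string op;
  std::vector<Tensor*> ins;
  std::vector<Tensor*> outs;
};

// Stand-in for hir::shardByStream: creates a per-iteration view of `t`.
// A deque keeps pointers stable as views are appended.
Tensor* shardByStream(Tensor* t, int loop_index, std::deque<Tensor>& views) {
  views.push_back({t->name + "[i" + std::to_string(loop_index) + "]"});
  return &views.back();
}

// Stand-in for cloneWithNewOperands: if no operand is remapped, the original
// expression is returned unchanged (mirroring the `return e` fix above);
// otherwise a copy with substituted operands is returned.
Expr cloneWithNewOperands(
    const Expr& e,
    const std::unordered_map<Tensor*, Tensor*>& replacement_map) {
  Expr copy = e;
  bool replaced = false;
  for (std::vector<Tensor*>* operands : {&copy.ins, &copy.outs}) {
    for (Tensor*& t : *operands) {
      if (auto it = replacement_map.find(t); it != replacement_map.end()) {
        t = it->second;
        replaced = true;
      }
    }
  }
  return replaced ? copy : e;
}

int main() {
  std::deque<Tensor> views;
  Tensor in{"in"};
  Tensor out{"out"};
  Expr linear{"linear", {&in}, {&out}};

  std::unordered_map<Tensor*, Tensor*> replacement_map;
  // try_emplace inserts only when the key is absent, so each tensor gets at
  // most one sharded stand-in per loop nest.
  auto [it, inserted] =
      replacement_map.try_emplace(&in, shardByStream(&in, 0, views));
  std::cout << "sharded " << it->first->name << " -> " << it->second->name
            << " (inserted: " << std::boolalpha << inserted << ")\n";

  Expr lowered = cloneWithNewOperands(linear, replacement_map);
  // Prints: linear(in[i0]) -> out
  std::cout << lowered.op << "(" << lowered.ins[0]->name << ") -> "
            << lowered.outs[0]->name << "\n";
}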

csrc/multidevice/communicator.h
Lines changed: 3 additions & 3 deletions

@@ -62,12 +62,12 @@ class NVF_API Communicator {
   }

   // returns the number of processes in the communicator
-  auto size() const {
+  int64_t size() const {
     return size_;
   }

   // returns the local number of processes in the communicator (within the node)
-  auto local_size() const {
+  int64_t local_size() const {
     return local_size_;
   }

@@ -89,7+89,7 @@
       const std::string& prefix = "");

   // returns the device associated with the current process
-  auto device() const {
+  at::Device device() const {
     return at::Device("cuda:" + std::to_string(local_rank_));
   }

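The header change swaps deduced `auto` return types for the concrete types they resolve to, so the public interface is readable in the header itself. A toy sketch of the difference, not the real Communicator:

// Toy illustration only. With `auto`, the return type is deduced from a
// private member and is invisible to readers of the header; spelling out
// int64_t documents the contract where callers look for it.
#include <cstdint>

class ToyCommunicator {
 public:
  // Before: auto size() const { return size_; }  // type hidden behind size_
  int64_t size() const {  // After: the returned type is part of the interface
    return size_;
  }

 private:
  int64_t size_ = 1;
};

int main() {
  ToyCommunicator c;
  return static_cast<int>(c.size()) - 1;  // returns 0
}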

csrc/runtime/fusion_kernel_runtime.cpp
Lines changed: 2 additions & 2 deletions

@@ -7,6 +7,8 @@
 // clang-format on
 #include <runtime/fusion_kernel_runtime.h>

+#include <c10/cuda/CUDAGuard.h>
+
 #include <fusion.h>
 #include <fusion_profiler.h>
 #include <fusion_segmenter.h>
@@ -25,8 +27,6 @@
 #include <serde/fusion_cache_generated.h>
 #include <type.h>

-#include <c10/cuda/CUDAGuard.h>
-
 namespace nvfuser {

 namespace {

tests/cpp/test_multidevice_stream_parallel_type.cpp
Lines changed: 0 additions & 3 deletions

@@ -5,8 +5,6 @@
  * SPDX-License-Identifier: BSD-3-Clause
  */
 // clang-format on
-#include <iterator>
-
 #include <cuda_profiler_api.h>

 #include <fusion.h>
@@ -24,7 +22,6 @@
 namespace nvfuser {

 using testing::ElementsAre;
-using testing::SizeIs;

 using MultiDeviceStreamParallelTypeTest = MultiDeviceTest;


tests/python/multidevice/test_overlap.py
Lines changed: 43 additions & 0 deletions

@@ -10,6 +10,49 @@
 from nvfuser_direct import DataType, FusionDefinition, CommunicatorBackend, TensorView


+@pytest.mark.mpi
+def test_row_parallel_linear_forward(multidevice_direct_test):
+    # This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward.
+    h, s, t = 2, 3, 6
+    d = multidevice_direct_test.size
+    if (h * 4) % d != 0:
+        pytest.skip(
+            f"Row-parallel linear requires {h * 4} to be divisible by world size {d}."
+        )
+    assert t % s == 0
+
+    mesh = nvfuser.multidevice.DeviceMesh(range(d))
+
+    with FusionDefinition() as fd:
+        inp = fd.define_tensor(
+            shape=[-1, h * 4], contiguity=True, dtype=DataType.BFloat16
+        )
+        weight = fd.define_tensor(
+            shape=[h, h * 4], contiguity=True, dtype=DataType.BFloat16
+        )
+        out = fd.ops.linear(inp, weight)
+        fd.add_output(out)
+
+        for tv in (inp, weight):
+            tv.set_device_mesh(mesh)
+
+        inp.split(0, s, inner_split=False)
+        inp.axis(0).parallelize(nvfuser.ParallelType.stream)
+        inp.split(2, d, inner_split=False)
+        inp.axis(2).parallelize(nvfuser.ParallelType.mesh_x)
+        weight.split(1, d, inner_split=False)
+        weight.axis(1).parallelize(nvfuser.ParallelType.mesh_x)
+
+    inp_ref = torch.randint(-2, 3, (t, h * 4), dtype=torch.int32).to(torch.bfloat16)
+    weight_ref = torch.randint(-2, 3, (h, h * 4), dtype=torch.int32).to(torch.bfloat16)
+    out_ref = torch.nn.functional.linear(inp_ref, weight_ref)
+
+    inp = (multidevice_direct_test.shard_tensor(inp_ref, -1, mesh),)
+    weight = (multidevice_direct_test.shard_tensor(weight_ref, -1, mesh),)
+    (out,) = fd.execute([inp, weight], _enable_options=["host_ir_lowering"])
+    torch.testing.assert_close(out.cpu(), out_ref)
+
+
 @pytest.mark.mpi
 @pytest.mark.parametrize("backend_type", [CommunicatorBackend.nccl])
 @pytest.mark.parametrize("s", [1, 8])
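The new test exercises a row-parallel linear layer: each device holds a slice of the inner (reduction) dimension of both the input and the weight, computes a partial linear locally, and the per-device partial results must then be summed across devices; that reduction is the communication the lowering change above places inside the stream loop. The single-process sketch below shows only why summing per-device partial products reproduces the unsharded result; the sizes are hypothetical and no nvFuser or distributed code is involved.

// Single-process sketch of row-parallel linear: out = inp * weight^T with the
// inner dimension k split across d "devices". Each device computes a partial
// product over its k-slice; summing the d partials equals the full result.
#include <cassert>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

// out[t][h] = sum over k in [k_begin, k_end) of inp[t][k] * weight[h][k].
Matrix partialLinear(
    const Matrix& inp, const Matrix& weight, int k_begin, int k_end) {
  Matrix out(inp.size(), std::vector<float>(weight.size(), 0.0f));
  for (size_t t = 0; t < inp.size(); ++t) {
    for (size_t h = 0; h < weight.size(); ++h) {
      for (int k = k_begin; k < k_end; ++k) {
        out[t][h] += inp[t][k] * weight[h][k];
      }
    }
  }
  return out;
}

int main() {
  const int t = 6, k = 8, h = 2, d = 4;  // hypothetical sizes; k % d == 0
  Matrix inp(t, std::vector<float>(k));
  Matrix weight(h, std::vector<float>(k));
  // Small integer values so the float sums are exact.
  for (int i = 0; i < t; ++i) {
    for (int j = 0; j < k; ++j) {
      inp[i][j] = static_cast<float>((i + j) % 3 - 1);
    }
  }
  for (int i = 0; i < h; ++i) {
    for (int j = 0; j < k; ++j) {
      weight[i][j] = static_cast<float>((i * j) % 3 - 1);
    }
  }

  // Unsharded reference.
  Matrix ref = partialLinear(inp, weight, 0, k);

  // "Allreduce": sum the d per-device partial results.
  Matrix sum(t, std::vector<float>(h, 0.0f));
  for (int rank = 0; rank < d; ++rank) {
    Matrix part =
        partialLinear(inp, weight, rank * (k / d), (rank + 1) * (k / d));
    for (int i = 0; i < t; ++i) {
      for (int j = 0; j < h; ++j) {
        sum[i][j] += part[i][j];
      }
    }
  }

  for (int i = 0; i < t; ++i) {
    for (int j = 0; j < h; ++j) {
      assert(sum[i][j] == ref[i][j]);
    }
  }
  return 0;
}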
