Skip to content

Commit 02e0055

Browse files
Fix how we decide whether to create ShardByStream for inputs (#5562)
## Summary
- Add a helper to fetch the requested TensorView domain, and let `haveDifferentShardings` take explicit producer and consumer `DomainType` arguments.
- Update all callers, including the resharding passes and the stream unit test, to pass the desired domain types.
- Fix a bug in host IR lowering: it should inspect an input's allocation (not loop) domain when deciding whether to insert `shardByStream`.

## Testing
- Not run (not requested).

---------

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
1 parent 8b08d20 commit 02e0055

File tree

4 files changed

+108
-43
lines changed

4 files changed

+108
-43
lines changed

csrc/host_ir/lowering.cpp

Lines changed: 62 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -84,29 +84,45 @@ std::ostream& operator<<(std::ostream& os, const LoopNest& loop_nest) {
8484
return os;
8585
}
8686

87+
int numParallelIterDomains(const TensorView* tv) {
88+
return std::ranges::count_if(
89+
tv->getLoopDomain(), [](IterDomain* id) { return id->isParallelized(); });
90+
}
91+
92+
template <typename R>
93+
TensorView* findMostParallelTensorView(const R& range) {
94+
TensorView* reference = nullptr;
95+
int max_parallel_count = -1;
96+
for (TensorView* tv : range) {
97+
auto parallel_count = numParallelIterDomains(tv);
98+
if (parallel_count > max_parallel_count) {
99+
max_parallel_count = parallel_count;
100+
reference = tv;
101+
}
102+
}
103+
return reference;
104+
}
105+
87106
// Finds the TensorView in the group whose loop domain has the most parallel
88107
// types and returns its loop domain.
89-
const std::vector<IterDomain*>& findReferenceLoopDomain(
108+
const std::vector<IterDomain*>& findMostParallelLoopDomain(
90109
const SegmentedGroup& group) {
91-
TensorView* reference_tv = nullptr;
110+
TensorView* reference = nullptr;
92111
int max_parallel_count = -1;
93-
for (auto* expr : group.exprs()) {
94-
for (auto* tv : ir_utils::filterByType<TensorView>(expr->outputs())) {
95-
auto loop_domain = tv->getLoopDomain();
96-
int parallel_count = 0;
97-
for (auto* id : loop_domain) {
98-
if (id->isParallelized()) {
99-
parallel_count++;
100-
}
101-
}
102-
if (parallel_count > max_parallel_count) {
103-
max_parallel_count = parallel_count;
104-
reference_tv = tv;
105-
}
112+
for (Expr* expr : group.exprs()) {
113+
TensorView* tv = findMostParallelTensorView(
114+
ir_utils::filterByType<TensorView>(expr->outputs()));
115+
if (tv == nullptr) {
116+
continue;
117+
}
118+
auto parallel_count = numParallelIterDomains(tv);
119+
if (parallel_count > max_parallel_count) {
120+
max_parallel_count = parallel_count;
121+
reference = tv;
106122
}
107123
}
108-
NVF_ERROR(reference_tv != nullptr);
109-
return reference_tv->getLoopDomain();
124+
NVF_ERROR(reference != nullptr, "Can't find any TensorView in ", &group);
125+
return reference->getLoopDomain();
110126
}
111127

112128
// Returns a new Expr with the inputs and outputs replaced by the replacement
@@ -217,11 +233,34 @@ void lowerSegment(
217233
std::unordered_map<Val*, Val*> replacement_map;
218234
for (Expr* e : exprs) {
219235
for (auto* in : ir_utils::filterByType<TensorView>(e->inputs())) {
220-
if (getShardedIterDomain(
221-
in, ParallelType::Stream, DomainType::kLoop) != nullptr &&
222-
getShardedIterDomain(
223-
in, ParallelType::Stream, DomainType::kAllocation) ==
224-
nullptr) {
236+
// A loop domain should go with an Expr rather than each individual
237+
// output TensorView. Before this is fixed, pick the most parallel
238+
// output TensorView as a proxy.
239+
TensorView* out = findMostParallelTensorView(
240+
ir_utils::filterByType<TensorView>(e->outputs()));
241+
if (out == nullptr) {
242+
continue;
243+
}
244+
// Check whether in's **allocation** and out's loop are sharded on
245+
// ParallelType::Stream consistently. If not, insert a ShardByStream.
246+
//
247+
// Consider the following example:
248+
// ```
249+
// in: [m, k] w: [k, n] # logical/allocation
250+
// |
251+
// | matmul
252+
// v
253+
// out: [m, n] logical
254+
// / \.
255+
// s m/s loop
256+
// ```
257+
// `in` needs to be sharded by stream regardless of its loop domain.
258+
if (haveDifferentShardings(
259+
in,
260+
DomainType::kAllocation,
261+
out,
262+
DomainType::kLoop,
263+
{ParallelType::Stream})) {
225264
auto [i, inserted] = replacement_map.try_emplace(
226265
in, hir::shardByStream(in, for_loop->index()));
227266
if (inserted) {
@@ -345,7 +384,7 @@ std::unique_ptr<hir::HostIrContainer> lowerSegmentedFusionToHostIr(
345384
for (SegmentedGroup* group :
346385
prepareRuntimeOrder(segmented_fusion).group_run_order) {
347386
const std::vector<IterDomain*>& curr_ref_loop =
348-
findReferenceLoopDomain(*group);
387+
findMostParallelLoopDomain(*group);
349388
const int64_t inline_position =
350389
computeInlinePosition(prev_ref_loop, curr_ref_loop, id_model);
351390
while (loop_nest.size() > inline_position) {

csrc/multidevice/utils.cpp

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,26 @@ std::ostream& operator<<(std::ostream& os, DomainType domain_type) {
4343
std::unreachable();
4444
}
4545

46+
namespace {
47+
48+
const std::vector<IterDomain*>& getDomainOf(
49+
const TensorView* tv,
50+
DomainType domain_type) {
51+
switch (domain_type) {
52+
case DomainType::kRoot:
53+
return tv->getMaybeRootDomain();
54+
case DomainType::kLogical:
55+
return tv->getLogicalDomain();
56+
case DomainType::kLoop:
57+
return tv->getLoopDomain();
58+
case DomainType::kAllocation:
59+
return tv->getMaybeAllocationDomain();
60+
}
61+
std::unreachable();
62+
}
63+
64+
} // namespace
65+
4666
bool isSharded(const TensorView* tv) {
4767
bool is_sharded = false;
4868
for (IterDomain* id : tv->getLoopDomain()) {
@@ -214,20 +234,7 @@ IterDomain* getShardedIterDomain(
214234
const TensorView* tv,
215235
const ParallelType parallel_type,
216236
const DomainType domain_type) {
217-
const std::vector<IterDomain*>& domain =
218-
[&]() -> const std::vector<IterDomain*>& {
219-
switch (domain_type) {
220-
case DomainType::kRoot:
221-
return tv->getMaybeRootDomain();
222-
case DomainType::kLogical:
223-
return tv->getLogicalDomain();
224-
case DomainType::kLoop:
225-
return tv->getLoopDomain();
226-
case DomainType::kAllocation:
227-
return tv->getMaybeAllocationDomain();
228-
}
229-
std::unreachable();
230-
}();
237+
const auto& domain = getDomainOf(tv, domain_type);
231238

232239
for (IterDomain* id : domain | TensorDomain::kNoReductions) {
233240
if (id->getParallelType() == parallel_type) {
@@ -318,7 +325,9 @@ std::unordered_set<IterDomain*> getInputsInTargetDomain(
318325

319326
bool haveDifferentShardings(
320327
const TensorView* producer,
328+
DomainType producer_domain_type,
321329
const TensorView* consumer,
330+
DomainType consumer_domain_type,
322331
const std::unordered_set<ParallelType>& parallel_types) {
323332
// cpu scalars are not parallelized
324333
if (producer->isCpuScalar() || consumer->isCpuScalar()) {
@@ -342,6 +351,9 @@ bool haveDifferentShardings(
342351
return true;
343352
}
344353

354+
const auto& producer_domain = getDomainOf(producer, producer_domain_type);
355+
const auto& consumer_domain = getDomainOf(consumer, consumer_domain_type);
356+
345357
// Special handling of SelectOp for a quick fix
346358
// TODO: work on a proper implementation
347359
if (consumer->definition()->isA<SelectOp>()) {
@@ -373,8 +385,8 @@ bool haveDifferentShardings(
373385
.mapBroadcast(false)
374386
.mapConsumerToProducer();
375387
return !std::all_of(
376-
consumer->getLoopDomain().begin(),
377-
consumer->getLoopDomain().end(),
388+
consumer_domain.begin(),
389+
consumer_domain.end(),
378390
[&c2p, &parallel_types](IterDomain* c_id) {
379391
auto p_id = c2p.at(c_id);
380392
auto p_id_pt = p_id->getParallelType();
@@ -455,9 +467,9 @@ bool haveDifferentShardings(
455467
// optimization, we create indices only for those that parallel_types depend
456468
// on.
457469
std::unordered_map<ParallelType, IterDomain*> p_parallel_type_to_id =
458-
mapDeviceAndStreamParallelTypeToId(producer->getLoopDomain());
470+
mapDeviceAndStreamParallelTypeToId(producer_domain);
459471
std::unordered_map<ParallelType, IterDomain*> c_parallel_type_to_id =
460-
mapDeviceAndStreamParallelTypeToId(consumer->getLoopDomain());
472+
mapDeviceAndStreamParallelTypeToId(consumer_domain);
461473
for (const auto parallel_type : parallel_types) {
462474
if (IterDomain* p_loop_id =
463475
getOrDefault(p_parallel_type_to_id, parallel_type)) {
@@ -548,6 +560,14 @@ bool haveDifferentShardings(
548560
return false;
549561
}
550562

563+
bool haveDifferentShardings(
564+
const TensorView* producer,
565+
const TensorView* consumer,
566+
const std::unordered_set<ParallelType>& parallel_types) {
567+
return haveDifferentShardings(
568+
producer, DomainType::kLoop, consumer, DomainType::kLoop, parallel_types);
569+
}
570+
551571
bool isResharding(const Expr* expr) {
552572
FUSER_PERF_SCOPE("isResharding");
553573

csrc/multidevice/utils.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ NVF_API bool isResharding(const Expr* expr);
4949

5050
// Returns whether two tensors have different shardings. Expect a
5151
// producer/consumer relationship between the arguments.
52+
bool haveDifferentShardings(
53+
const TensorView* producer,
54+
DomainType producer_domain_type,
55+
const TensorView* consumer,
56+
DomainType consumer_domain_type,
57+
const std::unordered_set<ParallelType>& parallel_types);
58+
59+
// Same as the above but checks loop domains for both producer and consumer.
5260
bool haveDifferentShardings(
5361
const TensorView* producer,
5462
const TensorView* consumer,

tests/cpp/test_stream.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,6 @@ TEST_F(StreamTest, Matmul) {
7878
fusion->addInput(w);
7979
fusion->addOutput(out);
8080

81-
w->outer_split(1, c);
82-
w->axis(1)->parallelize(ParallelType::Stream);
8381
out->outer_split(1, c);
8482
out->axis(1)->parallelize(ParallelType::Stream);
8583
}

0 commit comments

Comments
 (0)