Skip to content

Commit

Permalink
Fixing DomainMap::areAllTargetIdsCoveredBy to handle broadcast to n…
Browse files Browse the repository at this point in the history
…on-root IDs (#3655)

Fixes #3653

This PR adds another projection to source ID after permissive mapping.
This allows us to correctly identify coverage when the broadcast IDs are
not source ID directly.
  • Loading branch information
jjsjann123 authored Jan 2, 2025
1 parent 6eb2bc0 commit ed5b3c9
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 25 deletions.
66 changes: 41 additions & 25 deletions csrc/scheduler/tools/domain_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,18 +178,15 @@ bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv)
bool DomainMap::areAllTargetIdsCoveredBy(
TensorView* target_tv,
TensorView* reference_tv) const {
auto get_source_iter_domains = [this](TensorView* tv) {
auto get_source_iter_domains = [this](const std::vector<IterDomain*>& ids) {
// traverse back to collect all disjoint set producer IDs for each ID in the
// logical domain of tv.
VectorOfUniqueEntries<std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>>
all_producer_sets;
std::for_each(
tv->getLogicalDomain().begin(),
tv->getLogicalDomain().end(),
[&](IterDomain* tv_logical_id) {
all_producer_sets.pushBack(
ca_map_.disjointSetOf(tv_logical_id, IdMappingMode::EXACT));
});
std::for_each(ids.begin(), ids.end(), [&](IterDomain* tv_logical_id) {
all_producer_sets.pushBack(
ca_map_.disjointSetOf(tv_logical_id, IdMappingMode::EXACT));
});
all_producer_sets.pushBack(
ca_map_.getAllDisjointSetProducers(all_producer_sets));

Expand All @@ -213,7 +210,8 @@ bool DomainMap::areAllTargetIdsCoveredBy(
// this contains all source iter domain that's covered by reference_tv, so
// it's safe for target_tv to have them.
std::unordered_set<IterDomain*> covered_source_ids;
for (IterDomain* source_id_ref : get_source_iter_domains(reference_tv)) {
for (IterDomain* source_id_ref :
get_source_iter_domains(reference_tv->getLogicalDomain())) {
covered_source_ids.insert(source_id_ref);
}
// It's safe to have unmapped broadcast IterDomain. There're quite a few tests
Expand All @@ -226,9 +224,10 @@ bool DomainMap::areAllTargetIdsCoveredBy(

// Note that ideally we should also be able to handle merge/split on
// broadcast IDs, so we should really move this skip inside the loop below
// `get_source_iter_domains(target_tv)` and skip broadcast source IDs.
// currently we have the issue that split/merge does not preserve expanded
// broadcasts, see issue: https://github.com/NVIDIA/Fuser/issues/1126
// `get_source_iter_domains(target_tv->getLogicalDomain())` and skip
// broadcast source IDs. currently we have the issue that split/merge does
// not preserve expanded broadcasts, see issue:
// https://github.com/NVIDIA/Fuser/issues/1126
covered_source_ids.insert(id_out);
}
}
Expand All @@ -248,21 +247,37 @@ bool DomainMap::areAllTargetIdsCoveredBy(
// https://github.com/NVIDIA/Fuser/issues/3542

// Check all source iter domain involved in producing target_tv
for (IterDomain* source_id_out : get_source_iter_domains(target_tv)) {
for (IterDomain* source_id_out :
get_source_iter_domains(target_tv->getLogicalDomain())) {
// NOTE: we use concrete id instead. This allows us to link indirect
// broadcast. So in the example below: T2[i0, i1] = T0[i0, b0] + T1[i0, i1]
// T3[i0, i9] = pad(T0[i0, b0])
// broadcast. So in the example below:
// input T0[
// T2[i0, i2*i3] = T0[i0, i2, i3]
// T3[i0, i2*i3] = T1[i0, b0] + T2[i0, i2*i3]
// T4[i0, i9] = pad(T1[i0, b0])
// We have i9 in T3
// -> source ID b0
// -> concrete map to i1
// -> concrete map to i2*i3
// -> source ID from i2*i3 to [i2, i3]
// So T3 is contained by T2. See test `PointwiseTest.DomainMapPad1`
auto concrete_source_id_out =
auto concrete_id_out =
ca_map_.getConcreteMappedID(source_id_out, IdMappingMode::PERMISSIVE);
// if we find any source_id_out that's not contained, it's possible our
// propagation would fail since transformation involving this iter domain
// can't be resolved.
if (!getMappedInputConcreteID(covered_source_ids, concrete_source_id_out)) {
return false;

// After mapping with PERMISSIVE map, `concrete_id_out` might no longer be a
// source ID. We project to source ID again from concrete_id_out. See test
// DomainMapBroadcastIssue3653
// In the example above. `i2*i3` is not a source ID. Hence we needed to go
// through another projection to source IDs in order to map it to
// covered_source_ids.
for (IterDomain* concrete_source_id_out :
get_source_iter_domains({concrete_id_out})) {
// if we find any source_id_out that's not contained, it's possible our
// propagation would fail since transformation involving this iter
// domain can't be resolved.
if (!getMappedInputConcreteID(
covered_source_ids, concrete_source_id_out)) {
return false;
}
}
}
return true;
Expand Down Expand Up @@ -359,7 +374,8 @@ IterDomain* DomainMap::anyMapped(
}

// Determine if output TensorView is a valid reference tensor for this fusion.
// The reference tensor must map to all the iterDomains in each input and output
// The reference tensor must map to all the iterDomains in each input and
// output
bool DomainMap::isValidReference(TensorView* tv) const {
for (auto input_tv : ir_utils::filterByType<TensorView>(fusion_->inputs())) {
if (input_tv->uses().empty()) {
Expand All @@ -372,8 +388,8 @@ bool DomainMap::isValidReference(TensorView* tv) const {
}
}
// The check on outputs are optional, transpose scheduler might propose a
// secondary reference that only applies to a subset of IO tensors. Ideally we
// should have a more robust check and consider the IO groups instead of
// secondary reference that only applies to a subset of IO tensors. Ideally
// we should have a more robust check and consider the IO groups instead of
// blindly skip outputs.
for (auto output_tv :
ir_utils::filterByType<TensorView>(fusion_->outputs())) {
Expand Down
40 changes: 40 additions & 0 deletions tests/cpp/test_pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1242,4 +1242,44 @@ TEST_F(PointwiseTest, DomainMapSlice1) {
testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__);
}

TEST_F(NVFuserTest, DomainMapBroadcastIssue3653) {
auto fusion_ptr = std::make_unique<Fusion>();
FusionGuard fg(fusion_ptr.get());
Fusion& fusion = *fusion_ptr;

auto tv0 = makeConcreteTensor({2, 4, 8});
fusion.addInput(tv0);
auto tv1 = makeConcreteTensor({2});
fusion.addInput(tv1);

auto tv2 = reshape(tv0, {2, 4, 8}, {2, 32});
auto tv3 = broadcast(tv1, {false, true});
auto tv4 = add(tv2, tv3);

// tv4 covers source IDs {2, 4, 8}.
fusion.addOutput(tv4);
// meanwhile, tv3's broadcast ID map through permissive to `32`, which is not
// directly contained by tv4's source IDs. This test ensures that we project
// the mapped ID back to its source IDs and correctly schedule this fusion as
// a single kernel.
fusion.addOutput(tv3);

DomainMapUnitTest domain_map(fusion_ptr.get());
EXPECT_TRUE(domain_map.isValidReference(tv4));

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto t0 = at::randn({2, 4, 8}, options);
auto t1 = at::randn({2}, options);
std::vector<c10::IValue> inputs({t0, t1});

FusionExecutorCache executor_cache(std::move(fusion_ptr));
auto out_tensors = executor_cache.runFusionWithInputs(inputs);

FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
NVF_CHECK(!runtime->isSegmented());

testValidate(
executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__);
}

} // namespace nvfuser

0 comments on commit ed5b3c9

Please sign in to comment.