Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing DomainMap::areAllTargetIdsCoveredBy to handle broadcast to non-root IDs #3655

Merged
merged 7 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 31 additions & 22 deletions csrc/scheduler/tools/domain_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,18 +178,15 @@ bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv)
bool DomainMap::areAllTargetIdsCoveredBy(
TensorView* target_tv,
TensorView* reference_tv) const {
auto get_source_iter_domains = [this](TensorView* tv) {
auto get_source_iter_domains = [this](const std::vector<IterDomain*>& ids) {
// traverse back to collect all disjoint set producer IDs for each ID in the
// logical domain of tv.
VectorOfUniqueEntries<std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>>
all_producer_sets;
std::for_each(
tv->getLogicalDomain().begin(),
tv->getLogicalDomain().end(),
[&](IterDomain* tv_logical_id) {
all_producer_sets.pushBack(
ca_map_.disjointSetOf(tv_logical_id, IdMappingMode::EXACT));
});
std::for_each(ids.begin(), ids.end(), [&](IterDomain* tv_logical_id) {
all_producer_sets.pushBack(
ca_map_.disjointSetOf(tv_logical_id, IdMappingMode::EXACT));
});
all_producer_sets.pushBack(
ca_map_.getAllDisjointSetProducers(all_producer_sets));

Expand All @@ -213,7 +210,8 @@ bool DomainMap::areAllTargetIdsCoveredBy(
// this contains all source iter domain that's covered by reference_tv, so
// it's safe for target_tv to have them.
std::unordered_set<IterDomain*> covered_source_ids;
for (IterDomain* source_id_ref : get_source_iter_domains(reference_tv)) {
for (IterDomain* source_id_ref :
get_source_iter_domains(reference_tv->getLogicalDomain())) {
covered_source_ids.insert(source_id_ref);
}
// It's safe to have unmapped broadcast IterDomain. There're quite a few tests
Expand All @@ -226,9 +224,10 @@ bool DomainMap::areAllTargetIdsCoveredBy(

// Note that ideally we should also be able to handle merge/split on
// broadcast IDs, so we should really move this skip inside the loop below
// `get_source_iter_domains(target_tv)` and skip broadcast source IDs.
// currently we have the issue that split/merge does not preserve expanded
// broadcasts, see issue: https://github.com/NVIDIA/Fuser/issues/1126
// `get_source_iter_domains(target_tv->getLogicalDomain())` and skip
// broadcast source IDs. currently we have the issue that split/merge does
// not preserve expanded broadcasts, see issue:
// https://github.com/NVIDIA/Fuser/issues/1126
covered_source_ids.insert(id_out);
}
}
Expand All @@ -248,21 +247,30 @@ bool DomainMap::areAllTargetIdsCoveredBy(
// https://github.com/NVIDIA/Fuser/issues/3542

// Check all source iter domain involved in producing target_tv
for (IterDomain* source_id_out : get_source_iter_domains(target_tv)) {
for (IterDomain* source_id_out :
get_source_iter_domains(target_tv->getLogicalDomain())) {
// NOTE: we use concrete id instead. This allows us to link indirect
// broadcast. So in the example below: T2[i0, i1] = T0[i0, b0] + T1[i0, i1]
// T3[i0, i9] = pad(T0[i0, b0])
// We have i9 in T3
// -> source ID b0
// -> concrete map to i1
// So T3 is contained by T2. See test `PointwiseTest.DomainMapPad1`
auto concrete_source_id_out =
auto concrete_id_out =
ca_map_.getConcreteMappedID(source_id_out, IdMappingMode::PERMISSIVE);
// if we find any source_id_out that's not contained, it's possible our
// propagation would fail since transformation involving this iter domain
// can't be resolved.
if (!getMappedInputConcreteID(covered_source_ids, concrete_source_id_out)) {
return false;

// After mapping with PERMISSIVE map, `concrete_id_out` might no longer be a
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the patch. I added another call to get_source_iter_domain({concrete_id_out})

Copy link
Collaborator

@naoyam naoyam Dec 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, for a given logical ID, we take its source IDs, then for each source ID, take the source IDs of its concrete ID. I understand it would work for the repro, but I wonder why it would need to take source IDs twice. For example, would it work if we took first the concrete ID of a logical ID and then its source IDs?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to take the first source ID of the logical ID for the pad example to work, i.e. in the comment above from line 252 - 258.

T2[i0, i1] = T0[i0, b0] + T1[i0, i1]
T3[i0, i9] = pad(T0[i0, b0])

We want to have T2 as the reference TV and we want to be able to map T3 to T2. For i9 in T3. I wanted to first map to source ID, so it can be map to b0, which would concretize to i1.

So my naive justification of going through source IDs twice is that: the first source ID call allows us to trace back to broadcast IDs that went through a resize; while the second source ID calls allow us to resolve transformations on the concrete IDs.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine with this PR as long as #3653 is fixed, but I'm feeling there's something not quite right here, probably because T2 and T3 shouldn't be actually mapped. They could have completely different sizes, and neither T2 nor T3 should not be considered representative.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with you.

We have some existing tests that asserts the behavior, which might not be what we would want to support moving forward.

I'll update the comment here per suggestion and merge this PR as-is just so we can patch the issue.

Meanwhile, I think the case where you are making is that, we shouldn't map through PERMISSIVE mapping in the first place here. I'll start a draft PR with that to see how many tests are affected. We can discuss what's the next step then.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if it's something we should worry about at this moment. No test should fail functionally since it's a segmentation issue. It would be just a performance difference, but unless we have a benchmark that has this particular pattern, it doesn't make much sense to discuss what we should do.

// source ID. We project to source ID again from concrete_id_out. See test
// DomainMapBroadcastIssue3653
for (IterDomain* concrete_source_id_out :
get_source_iter_domains({concrete_id_out})) {
// if we find any source_id_out that's not contained, it's possible our
// propagation would fail since transformation involving this iter
// domain can't be resolved.
if (!getMappedInputConcreteID(
covered_source_ids, concrete_source_id_out)) {
return false;
}
}
}
return true;
Expand Down Expand Up @@ -359,7 +367,8 @@ IterDomain* DomainMap::anyMapped(
}

// Determine if output TensorView is a valid reference tensor for this fusion.
// The reference tensor must map to all the iterDomains in each input and output
// The reference tensor must map to all the iterDomains in each input and
// output
bool DomainMap::isValidReference(TensorView* tv) const {
for (auto input_tv : ir_utils::filterByType<TensorView>(fusion_->inputs())) {
if (input_tv->uses().empty()) {
Expand All @@ -372,8 +381,8 @@ bool DomainMap::isValidReference(TensorView* tv) const {
}
}
// The check on outputs are optional, transpose scheduler might propose a
// secondary reference that only applies to a subset of IO tensors. Ideally we
// should have a more robust check and consider the IO groups instead of
// secondary reference that only applies to a subset of IO tensors. Ideally
// we should have a more robust check and consider the IO groups instead of
// blindly skip outputs.
for (auto output_tv :
ir_utils::filterByType<TensorView>(fusion_->outputs())) {
Expand Down
40 changes: 40 additions & 0 deletions tests/cpp/test_pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1242,4 +1242,44 @@ TEST_F(PointwiseTest, DomainMapSlice1) {
testValidate(fusion, cg_results.outputs, aten_inputs, __LINE__, __FILE__);
}

TEST_F(NVFuserTest, DomainMapBroadcastIssue3653) {
auto fusion_ptr = std::make_unique<Fusion>();
FusionGuard fg(fusion_ptr.get());
Fusion& fusion = *fusion_ptr;

auto tv0 = makeConcreteTensor({2, 4, 8});
fusion.addInput(tv0);
auto tv1 = makeConcreteTensor({2});
fusion.addInput(tv1);

auto tv2 = reshape(tv0, {2, 4, 8}, {2, 32});
auto tv3 = broadcast(tv1, {false, true});
auto tv4 = add(tv2, tv3);

// tv4 covers source IDs {2, 4, 8}.
fusion.addOutput(tv4);
// meanwhile, tv3's broadcast ID map through permissive to `32`, which is not
// directly contained by tv4's source IDs. This test ensures that we project
// the mapped ID back to its source IDs and correctly schedule this fusion as
// a single kernel.
fusion.addOutput(tv3);

DomainMapUnitTest domain_map(fusion_ptr.get());
EXPECT_TRUE(domain_map.isValidReference(tv4));

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto t0 = at::randn({2, 4, 8}, options);
auto t1 = at::randn({2}, options);
std::vector<c10::IValue> inputs({t0, t1});

FusionExecutorCache executor_cache(std::move(fusion_ptr));
auto out_tensors = executor_cache.runFusionWithInputs(inputs);

FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
NVF_CHECK(!runtime->isSegmented());

testValidate(
executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__);
}

} // namespace nvfuser
Loading