
[WIP] Resize scheduler update #3657

Draft: wants to merge 59 commits into main
1a370e7  Exclusiveness analysis (naoyam, Dec 18, 2024)
ac5a1bc  cleanup (naoyam, Dec 18, 2024)
8b8c708  cleanup (naoyam, Dec 19, 2024)
d364442  PR feedback (naoyam, Dec 19, 2024)
7380a40  Resolve conflicts by recomputation (naoyam, Dec 19, 2024)
eb9fffa  Merge remote-tracking branch 'origin/main' into resize_scheduler_reco… (naoyam, Dec 20, 2024)
9631958  test fix (naoyam, Dec 20, 2024)
76dbab9  fix (naoyam, Dec 20, 2024)
75338a4  cleanup (naoyam, Dec 20, 2024)
fd0cd06  Add a c++ test version of a rope module (naoyam, Dec 20, 2024)
bb193f7  WIP (naoyam, Dec 20, 2024)
6d83336  cleanup (naoyam, Dec 20, 2024)
8d0f05f  Add pre-mma seciton (naoyam, Dec 20, 2024)
a2c0652  bug fix (naoyam, Dec 21, 2024)
dd3a2d9  vectorize hack (naoyam, Dec 21, 2024)
1069c6f  vec (naoyam, Dec 21, 2024)
e5ec7d8  vec tuning (naoyam, Dec 21, 2024)
907ae7a  WIP: segmenter opt (naoyam, Dec 21, 2024)
523fe84  WIP (naoyam, Dec 22, 2024)
80bfc84  mistral bwd (naoyam, Dec 23, 2024)
d6711a7  fix (naoyam, Dec 24, 2024)
221a323  Merge remote-tracking branch 'origin/main' into resize_scheduler_reco… (naoyam, Dec 24, 2024)
e48a2f6  Recomputation needs to be done in a topological order (naoyam, Dec 24, 2024)
9a96e99  Translate repeat (naoyam, Dec 24, 2024)
cab9cd8  Avoid transpose (naoyam, Dec 24, 2024)
a73dab6  cleanup (naoyam, Dec 24, 2024)
e25a464  Translate the repetition pattern with expand and reshape (naoyam, Dec 24, 2024)
330a62f  cleanup (naoyam, Dec 25, 2024)
4c1abda  Drop support of addition-based concat as it isn't immediately necessary (naoyam, Dec 25, 2024)
e5fcf14  Merge remote-tracking branch 'origin/main' into translate_repeat_pattern (naoyam, Dec 25, 2024)
baebd7b  remove (naoyam, Dec 25, 2024)
a2f7df1  Merge remote-tracking branch 'origin/translate_repeat_pattern' into r… (naoyam, Dec 25, 2024)
795dd49  manual seg (naoyam, Dec 25, 2024)
1a5bcae  forward other single-input ops (naoyam, Dec 25, 2024)
2348d55  Skip transpose by default (naoyam, Dec 25, 2024)
197b0e7  Mistral bwd adjustment (naoyam, Dec 26, 2024)
8051a56  resize heuristic param with split gdimx (naoyam, Dec 26, 2024)
53cc163  reshape cancelation and vec (naoyam, Dec 27, 2024)
1527c84  bug fix (naoyam, Dec 27, 2024)
e4619e2  scheduling update (naoyam, Dec 27, 2024)
86fba69  scheduler update (naoyam, Dec 30, 2024)
cd39203  fix (naoyam, Dec 30, 2024)
5205d1d  war (naoyam, Dec 30, 2024)
410ee87  vec fix (naoyam, Dec 30, 2024)
4a45a99  Mistral and Litgpt benchmarks (naoyam, Dec 31, 2024)
75ec122  cleanup (naoyam, Dec 31, 2024)
b497d8e  Merge branch 'main' into resize_scheduler_opt (naoyam, Dec 31, 2024)
52f3353  Merge branch 'rope_benchmark_cpp_test' into resize_scheduler_opt (naoyam, Dec 31, 2024)
f9a2d37  Merge branch 'main' into resize_scheduler_recomputation (naoyam, Dec 31, 2024)
2653e13  remove a file added by accident (naoyam, Dec 31, 2024)
d66a67d  Merge branch 'main' into resize_scheduler_recomputation (naoyam, Dec 31, 2024)
87b4713  cleanup (naoyam, Dec 31, 2024)
0b48837  Merge branch 'main' into resize_scheduler_opt (naoyam, Dec 31, 2024)
30a5dc1  Merge branch 'resize_scheduler_recomputation' into resize_scheduler_opt (naoyam, Dec 31, 2024)
350af8d  Merge branch 'main' into resize_scheduler_opt (naoyam, Dec 31, 2024)
70b8920  remove file added accidentally (naoyam, Dec 31, 2024)
1d42538  Merge branch 'main' into resize_scheduler_opt (naoyam, Dec 31, 2024)
f577705  Merge branch 'main' into resize_scheduler_opt (naoyam, Dec 31, 2024)
e06cc3b  cleanup (naoyam, Jan 1, 2025)
69 changes: 69 additions & 0 deletions csrc/bfs.h
@@ -549,6 +549,75 @@ class BFS {
Direction allowed_direction_ = Direction::Undefined;
};

template <
typename ExprT,
typename ValT,
typename DefinitionT,
typename UsesT,
typename InputsT,
typename OutputsT>
class BFSWithPermissiveDependence
: public BFS<ExprT, ValT, DefinitionT, UsesT, InputsT, OutputsT> {
public:
using NodeType =
typename BFS<ExprT, ValT, DefinitionT, UsesT, InputsT, OutputsT>::
NodeType;

BFSWithPermissiveDependence(
DefinitionT definition,
UsesT uses,
InputsT inputs,
OutputsT outputs,
std::vector<NodeType> from,
std::vector<NodeType> to,
bool require_all_to_visited = true,
Direction allowed_direction = Direction::Undefined)
: BFS<ExprT, ValT, DefinitionT, UsesT, InputsT, OutputsT>(
definition,
uses,
inputs,
outputs,
std::move(from),
std::move(to),
require_all_to_visited,
allowed_direction) {}

std::optional<std::pair<Direction, std::vector<NodeType>>> isReady(
const ExprT& expr) const override {
// Either any inputs or any outputs must have been visited
decltype(auto) inputs = this->inputs_(expr);
if (!inputs.empty() && this->allowed_direction_ != Direction::Backward &&
std::any_of(
inputs.begin(), inputs.end(), [&](const ValT& input) -> bool {
return this->isDependencySatisfied(input);
})) {
std::vector<NodeType> prev_nodes;
std::copy_if(
inputs.begin(),
inputs.end(),
std::back_inserter(prev_nodes),
[&](const ValT& input) -> bool { return this->isVisited(input); });
return std::make_pair(Direction::Forward, prev_nodes);
}

decltype(auto) outputs = this->outputs_(expr);
if (!outputs.empty() && this->allowed_direction_ != Direction::Forward &&
std::any_of(
outputs.begin(), outputs.end(), [&](const ValT& output) -> bool {
return this->isDependencySatisfied(output);
})) {
std::vector<NodeType> prev_nodes;
std::copy_if(
outputs.begin(),
outputs.end(),
std::back_inserter(prev_nodes),
[&](const ValT& output) -> bool { return this->isVisited(output); });
return std::make_pair(Direction::Backward, prev_nodes);
}
return std::nullopt;
}
};

// Find the shortest path from the from vals to the to
// vals. Dependency between vals and exprs must be satisfied.
// It is an error if no valid path is found unless
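The permissive variant above relaxes the base-class readiness rule: an expression becomes traversable as soon as any one of its inputs (forward) or outputs (backward) has been visited, rather than all of them. A minimal standalone sketch of the two readiness rules, using toy stand-in types rather than nvFuser's:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Toy expression node: an op reading input value ids and writing output
// value ids. Illustrative stand-ins, not nvFuser's ExprT/ValT.
struct ToyExpr {
  std::vector<int> inputs;
  std::vector<int> outputs;
};

// Strict readiness (the base BFS): every input must already be visited
// before the expr can be traversed forward.
bool strictReady(const ToyExpr& e, const std::vector<bool>& visited) {
  return std::all_of(e.inputs.begin(), e.inputs.end(), [&](int v) {
    return visited.at(v);
  });
}

// Permissive readiness (as in BFSWithPermissiveDependence::isReady):
// a single visited input (forward) or output (backward) is enough.
bool permissiveReady(const ToyExpr& e, const std::vector<bool>& visited) {
  auto anyVisited = [&](const std::vector<int>& vals) {
    return !vals.empty() &&
        std::any_of(vals.begin(), vals.end(), [&](int v) {
          return visited.at(v);
        });
  };
  return anyVisited(e.inputs) || anyVisited(e.outputs);
}
```

With only one of two inputs visited, `strictReady` holds the expr back while `permissiveReady` lets the traversal proceed, which is what allows paths through partially resolved regions of the graph.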
112 changes: 96 additions & 16 deletions csrc/fusion_segmenter.cpp
@@ -2198,6 +2198,7 @@ SegmentedGroup* SegmentCandidateFinder::mergeNodes() {
group2->exprs_.end());

auto producer_edges = getMergedProducerEdges(group1, group2);

// Connect joined group to resulting neighbors
for (auto edge : producer_edges) {
auto from = edge->from;
@@ -3621,13 +3622,23 @@ class PreferredMergeCandidatePicker {
PreferredMergeCandidatePicker(const std::vector<SegmentedGroup*>& groups)
: groups_(groups) {
for (auto& group : groups_) {
if (all_candidates_.count(group)) {
continue;
}
// Currently there's only one preference for select-like
// ops. Additional preferences can be added similarly.
auto neighbor_to_merge = mergeSelectLikeOpsWithProducers(group);
if (!neighbor_to_merge.has_value()) {
continue;
if (auto neighbor_to_merge = mergeSelectLikeOpsWithProducers(group);
neighbor_to_merge.has_value()) {
candidates_.emplace_back(group, *neighbor_to_merge);
all_candidates_.insert(group);
all_candidates_.insert(neighbor_to_merge->group);
} else if (auto neighbor_to_merge =
mergeCastToHigherPrecisionWithConsumers(group);
neighbor_to_merge.has_value()) {
candidates_.emplace_back(group, *neighbor_to_merge);
all_candidates_.insert(group);
all_candidates_.insert(neighbor_to_merge->group);
}
candidates_.emplace_back(group, *neighbor_to_merge);
}
}

@@ -3650,8 +3661,12 @@ class PreferredMergeCandidatePicker {
std::optional<SegmentedGroup::NeighborGroup> mergeSelectLikeOpsWithProducers(
SegmentedGroup* group) const;

std::optional<SegmentedGroup::NeighborGroup>
mergeCastToHigherPrecisionWithConsumers(SegmentedGroup* group) const;

private:
const std::vector<SegmentedGroup*>& groups_;
std::unordered_set<SegmentedGroup*> all_candidates_;
std::vector<std::pair<SegmentedGroup*, SegmentedGroup::NeighborGroup>>
candidates_;
};
@@ -3707,10 +3722,66 @@ std::optional<SegmentedGroup::NeighborGroup> PreferredMergeCandidatePicker::
return std::nullopt;
}

if (all_candidates_.count((*producer_edge_it)->from)) {
return std::nullopt;
}

return SegmentedGroup::NeighborGroup(
(*producer_edge_it)->from, *producer_edge_it);
}

std::optional<SegmentedGroup::NeighborGroup> PreferredMergeCandidatePicker::
mergeCastToHigherPrecisionWithConsumers(SegmentedGroup* group) const {
if (!getenv("CAST_SEGMENT")) {
return std::nullopt;
}

const auto& exprs = group->exprs();

if (exprs.size() != 1) {
return std::nullopt;
}

auto uop = dynamic_cast<UnaryOp*>(exprs.at(0));

if (uop == nullptr || uop->getUnaryOpType() != UnaryOpType::Cast) {
return std::nullopt;
}

auto inp_tv = ir_utils::getTvInput(uop);
auto out_tv = ir_utils::getTvOutput(uop);
if (inp_tv == nullptr || out_tv == nullptr) {
return std::nullopt;
}

auto inp_dtype = inp_tv->dtype().type;
auto out_dtype = out_tv->dtype().type;
auto inp_prim_type = std::get_if<PrimDataType>(&inp_dtype);
auto out_prim_type = std::get_if<PrimDataType>(&out_dtype);

if (inp_prim_type == nullptr || out_prim_type == nullptr) {
return std::nullopt;
}

if (primDataTypeSize(*inp_prim_type) >= primDataTypeSize(*out_prim_type)) {
return std::nullopt;
}

// For simplicity, limit this optimization only when there's only
// one consumer
if (group->consumer_edges.size() != 1) {
return std::nullopt;
}

auto edge = group->consumer_edges.front();

if (all_candidates_.count(edge->to)) {
return std::nullopt;
}

return SegmentedGroup::NeighborGroup(edge->to, edge);
}

} // namespace

bool SegmentCandidateFinder::codeGenSupportedMerge(
@@ -4019,7 +4090,7 @@ void SegmentCandidateFinder::findSegments() {
// fusion. Currently, we forward an input only when its single use is a UnaryOp.
// Therefore, this function returns `v`'s single unary use or nullptr if it
// decides not to forward.
UnaryOp* shouldForward(Val* v) {
Expr* shouldForward(Val* v) {
const std::vector<Expr*>& uses = v->uses();
// Just allow stripping out input with single use.
// Stripping out multi-used inputs can lead to:
@@ -4029,23 +4100,25 @@ UnaryOp* shouldForward(Val* v) {
return nullptr;
}

auto* unary_use = dynamic_cast<UnaryOp*>(uses.front());
if (unary_use == nullptr) {
auto* unary_use = uses.front();
if (!unary_use->isOneOf<UnaryOp, BroadcastOp, ExpandOp>()) {
return nullptr;
}

auto unary_use_out = unary_use->output(0);

// Don't forward an input to an output yet. Doing that would lead to an empty
// group that ought to work in theory but doesn't work in practice with the
// downstream logic. See #1813 for an example.
if (unary_use->out()->isFusionOutput()) {
if (unary_use_out->isFusionOutput()) {
return nullptr;
}

// prevent forward to a SegmenterSet, which could cause unary op forward to a
// no-op segment. See issue: https://github.com/NVIDIA/Fuser/issues/2658
if (std::any_of(
unary_use->out()->uses().begin(),
unary_use->out()->uses().end(),
unary_use_out->uses().begin(),
unary_use_out->uses().end(),
[](const Expr* next_use) {
if (const LoadStoreOp* use =
dynamic_cast<const LoadStoreOp*>(next_use)) {
@@ -4069,23 +4142,23 @@ void SegmentCandidateFinder::forwardInputs() {
// treated as complete fusion inputs.
VectorOfUniqueEntries<Val*> forwarded_inputs;
{
std::deque<UnaryOp*> to_visit;
std::deque<Expr*> to_visit;
for (Val* inp : completeFusion()->inputs()) {
if (UnaryOp* unary_use = shouldForward(inp)) {
if (Expr* unary_use = shouldForward(inp)) {
to_visit.push_back(unary_use);
}
}

while (!to_visit.empty()) {
UnaryOp* uop = to_visit.front();
Expr* uop = to_visit.front();
to_visit.pop_front();

if (UnaryOp* unary_use = shouldForward(uop->out())) {
if (Expr* unary_use = shouldForward(uop->output(0))) {
to_visit.push_back(unary_use);
} else {
// We cannot extend the chain of unary ops, so we finalize this chain by
// saving its output as a forwarded input.
forwarded_inputs.pushBack(uop->out());
forwarded_inputs.pushBack(uop->output(0));
}
// Either way, `uop` is excluded from merging until
// `resolveNonscalarForwardedInput` adds it back to one of the segments.
@@ -4346,7 +4419,14 @@ void SegmentCandidateFinder::resolveScalarsInGroup(SegmentedGroup* group) {

SegmentedGroup* SegmentCandidateFinder::createInputGroup(Val* forwarded_input) {
SegmentedGroup* group = segmented_fusion_->newGroup();
group->input_vals = IterVisitor::getInputsTo({forwarded_input});
// group->input_vals = IterVisitor::getInputsTo({forwarded_input});
auto inputs = IterVisitor::getInputsTo({forwarded_input});
for (auto inp : inputs) {
if (inp->isScalar()) {
continue;
}
group->input_vals.push_back(inp);
}
group->exprs_ = StmtSort::getExprsTo({forwarded_input});
return group;
}
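The new `mergeCastToHigherPrecisionWithConsumers` preference only fires for a group that holds a single cast `UnaryOp` whose output dtype is strictly wider than its input, and that has exactly one consumer edge. The core of the heuristic is the size comparison; here is a standalone sketch of it using a toy dtype enum in place of nvFuser's `PrimDataType`/`primDataTypeSize`:

```cpp
#include <cassert>
#include <cstddef>

// Toy stand-in for nvFuser's PrimDataType; the byte sizes are the
// usual CUDA widths for these types.
enum class ToyDType { Half, BFloat16, Float, Double };

std::size_t toySize(ToyDType t) {
  switch (t) {
    case ToyDType::Half:
    case ToyDType::BFloat16:
      return 2;
    case ToyDType::Float:
      return 4;
    case ToyDType::Double:
      return 8;
  }
  return 0;
}

// A lone cast is a merge candidate only when it widens the type.
// Narrowing and same-size casts (e.g. Half <-> BFloat16) are rejected,
// mirroring the primDataTypeSize() comparison in the pass above.
bool isWideningCast(ToyDType in, ToyDType out) {
  return toySize(in) < toySize(out);
}
```

Merging a widening cast into its consumer keeps the narrow tensor as the segment boundary, so the data crossing segments stays in the smaller dtype.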
48 changes: 47 additions & 1 deletion csrc/id_model/indexing_traversal.cpp
@@ -7,8 +7,11 @@
// clang-format on
#include <id_model/id_model.h>
#include <id_model/indexing_traversal.h>
#include <ir/graphviz.h>
#include <ir/utils.h>

#include <fstream>

namespace nvfuser {

IndexingTraversal::IndexingTraversal(
@@ -260,10 +263,53 @@ std::optional<IndexingTraversal::ExprPath> IndexingTraversal::
local_graph,
{from_groups.vector().begin(), from_groups.vector().end()},
{to_groups.vector().begin(), to_groups.vector().end()},
/*require_all_to_visited=*/true);
/*require_all_to_visited=*/false);
traversal.traverse();
auto [path, all_visited] = traversal.getShortestExprPath();

if (!all_visited) {
auto reachable_vals = ValGroups(getReachableValsFrom<ValGraphBFS>(
{from_groups.vector().begin(), from_groups.vector().end()},
{to_groups.vector().begin(), to_groups.vector().end()},
Direction::Undefined,
local_graph));
for (const auto& to_group : to_groups) {
if (reachable_vals.has(to_group)) {
continue;
}

// Broadcast groups may not be reachable, which should be
// fine. For example, Index::getConsumerPerDimLogicalIndex may
// try to get an index of a broadcast logical ID. However, the
// loop domain of the tensor may not use the broadcast logical
// ID, which is completely fine.
if (to_group->front()->as<IterDomain>()->isBroadcast()) {
continue;
}

// Otherwise, this is an error. Need to understand why this
// happens.
// Dump the graph for debugging
std::ofstream ofs("local_graph.dot", std::ofstream::trunc);
auto dot_string = local_graph.toGraphvizDotGraph();
ofs << dot_string;
ofs.close();
std::stringstream ss;
ss << "Resize war indexing path failed. Expr: " << expr->toString()
<< "From IDs: " << toDelimitedString(from_ids) << "\n"
<< "From groups: " << nvfuser::toString(from_groups) << "\n"
<< "To IDs: " << toDelimitedString(to_ids) << "\n"
<< "To groups: " << nvfuser::toString(to_groups) << "\n"
<< "Reachable IDs: "
<< nvfuser::toString(getReachableValsFrom<ValGraphBFS>(
{from_groups.vector().begin(), from_groups.vector().end()},
{to_groups.vector().begin(), to_groups.vector().end()},
Direction::Undefined,
local_graph));
NVF_THROW(ss.str());
}
}

for (const auto& [g, d] : path) {
if (g->front()->isA<Resize>()) {
return path;
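With `require_all_to_visited` switched to `false`, the traversal above now tolerates unreachable targets, but only when every unreached to-group is a broadcast domain; any other miss dumps the graph and throws. A minimal sketch of that post-traversal check, with a toy iteration-domain type standing in for nvFuser's `ValGroup`s:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Toy iteration-domain stand-in: an id plus a broadcast flag.
struct ToyId {
  int id;
  bool is_broadcast;
};

// After a permissive traversal, an unreached target is acceptable only
// if it is a broadcast domain. Returns false when a non-broadcast
// target was missed (the point at which the real code NVF_THROWs).
bool unreachedTargetsAreAllBroadcast(
    const std::vector<ToyId>& to_groups,
    const std::vector<int>& reachable_ids) {
  return std::all_of(
      to_groups.begin(), to_groups.end(), [&](const ToyId& to) {
        bool reached =
            std::find(reachable_ids.begin(), reachable_ids.end(), to.id) !=
            reachable_ids.end();
        return reached || to.is_broadcast;
      });
}
```

This matches the rationale in the diff: a broadcast logical ID may legitimately be absent from a tensor's loop domain, so failing to reach it is not an indexing error.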
7 changes: 5 additions & 2 deletions csrc/preseg_passes/pre_segmenter.cpp
@@ -58,8 +58,11 @@ namespace nvfuser::preseg_passes {
// avoid moving pad operations around, which could disturb the analysis
// from MarkAliasPrepare
// 2. after MoveSplitCat
// to avoid this pass moving PadOp around to break the MoveSplitCat.
OptimizationPass<MovePadPass>::runPass(fusion);
// to avoid this pass moving PadOp around to break the
// MoveSplitCat.
if (!isOptionEnabled(EnableOption::ResizeScheduler)) {
OptimizationPass<MovePadPass>::runPass(fusion);
}
// NOTE vvv this doesn't really work, since our type promotion to higher
// precision for Add cannot be canceled out with previous cast to lower
precision. Since it's not a no-op and it has a quantization effect. I'll