Skip to content

Commit

Permalink
fix trivial sink LoopMapping merge
Browse files Browse the repository at this point in the history
  • Loading branch information
huangjiyi committed Feb 18, 2025
1 parent de6560b commit 18b4b1b
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 20 deletions.
17 changes: 8 additions & 9 deletions paddle/cinn/operator_fusion/graph_transformer/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ struct MergeTrivialPatternOperation {
"The trivial pattern wait for sinking should has "
"at least 1 downstream , but got %d.",
upstream->downstream().size()));

VLOG(4) << "Sink trivial pattern: \nupstream: " << upstream->DebugStr();
std::vector<PatternNodePtr> fusion_candidate = upstream->downstream();
upstream->ClearDownstream();

for (const auto& downstream : fusion_candidate) {
for (int i = 0; i < fusion_candidate.size(); ++i) {
const auto& downstream = fusion_candidate[i];
bool can_fuse =
std::holds_alternative<ReducePattern>(downstream->stmt_pattern()) ||
std::holds_alternative<TrivialPattern>(downstream->stmt_pattern()) ||
Expand All @@ -45,15 +46,13 @@ struct MergeTrivialPatternOperation {
std::holds_alternative<AnchorPattern>(downstream->stmt_pattern());

if (can_fuse) {
VLOG(4) << "\ndownstream [" << i << "]: " << downstream->DebugStr();
auto merged_node = graph->MergeNode(upstream, downstream, MergePattern);
merged_node->set_fusion_iters(
graph->iters_fusion_policy()->SingleDownstreamItersFusion(
upstream, downstream));
graph->RemoveNode(downstream);
VLOG(4) << "Splitting trivial pattern: \nupstream "
<< upstream->DebugStr() << "\ndownstream "
<< downstream->DebugStr() << "\nmerged "
<< merged_node->DebugStr();
VLOG(4) << "\nmerged [" << i << "] " << merged_node->DebugStr();
merged_node->AppendInstr(std::make_shared<TrivialInlineInstr>(
upstream->id(), downstream->id(), merged_node->id()));
} else {
Expand Down Expand Up @@ -98,6 +97,8 @@ struct MergeReduceTreeAndTrivialOperation {
"The downstream of the ReduceTree node should be 1, but got %d.",
node->downstream().size()));
auto downstream = node->downstream().at(0);
VLOG(4) << "MergeReduceTreeAndTrivialOperation: \nupstream "
<< node->DebugStr() << "\ndownstream " << downstream->DebugStr();
auto fake_reduce_iter_idx = graph->policy_manager()
.template GetPolicy<RelativeJudgePolicy>()
->GetFakeReduceIterIdx(node, downstream);
Expand Down Expand Up @@ -132,9 +133,7 @@ struct MergeReduceTreeAndTrivialOperation {

graph->RemoveNode(downstream);
graph->RemoveNode(node);
VLOG(4) << "MergeReduceTreeAndTrivialOperation: \nupstream "
<< node->DebugStr() << "\ndownstream " << downstream->DebugStr()
<< "\nmerged " << merged_node->DebugStr();
VLOG(4) << "merged " << merged_node->DebugStr();
merged_node->UpdateTracker();
return merged_node;
}
Expand Down
12 changes: 6 additions & 6 deletions paddle/cinn/operator_fusion/pattern_fuser.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ static StmtPattern MergePatternImpl(const TrivialPattern& first,
second.sink_op(),
std::make_shared<FusionTracker>(first.tracker_, second.tracker_));
result.set_loop_mapping(
LoopMappingMerge(first.loop_mapping(), second.loop_mapping(), false));
TrivialSinkLoopMappingMerge(first.loop_mapping(), second.loop_mapping()));
return result;
}

Expand All @@ -86,7 +86,7 @@ static StmtPattern MergePatternImpl(const TrivialPattern& first,
contents,
std::make_shared<FusionTracker>(first.tracker_, second.tracker_));
result.set_loop_mapping(
LoopMappingMerge(first.loop_mapping(), second.loop_mapping(), false));
TrivialSinkLoopMappingMerge(first.loop_mapping(), second.loop_mapping()));
return result;
}

Expand Down Expand Up @@ -116,7 +116,7 @@ static StmtPattern MergePatternImpl(const TrivialPattern& first,
FusePatternIfConnected(first, second.GetRootPattern(), connect_ops),
std::make_shared<FusionTracker>(first.tracker_, second.tracker_));
result.set_loop_mapping(
LoopMappingMerge(first.loop_mapping(), second.loop_mapping(), false));
TrivialSinkLoopMappingMerge(first.loop_mapping(), second.loop_mapping()));
return result;
}

Expand All @@ -129,7 +129,7 @@ static StmtPattern MergePatternImpl(
std::make_shared<FusionTracker>(first.tracker_, second.tracker_));
result.fake_reduce_iter_idx = second.fake_reduce_iter_idx;
result.set_loop_mapping(
LoopMappingMerge(first.loop_mapping(), second.loop_mapping(), false));
TrivialSinkLoopMappingMerge(first.loop_mapping(), second.loop_mapping()));
return result;
}

Expand All @@ -140,7 +140,7 @@ static StmtPattern MergePatternImpl(const TrivialPattern& first,
std::make_shared<FusionTracker>(first.tracker_, second.tracker_),
second.loop_dims());
result.set_loop_mapping(
LoopMappingMerge(first.loop_mapping(), second.loop_mapping(), false));
TrivialSinkLoopMappingMerge(first.loop_mapping(), second.loop_mapping()));
return result;
}

Expand All @@ -149,7 +149,7 @@ static StmtPattern MergePatternImpl(const TrivialPattern& first,
return AnchorPattern(
UniqueConcatVector(GetOpsInPattern(first), GetOpsInPattern(second)),
std::make_shared<FusionTracker>(first.tracker_, second.tracker_),
LoopMappingMerge(first.loop_mapping(), second.loop_mapping(), false));
TrivialSinkLoopMappingMerge(first.loop_mapping(), second.loop_mapping()));
}

// RR & RT
Expand Down
9 changes: 4 additions & 5 deletions paddle/cinn/operator_fusion/pattern_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,10 @@ std::vector<PatternNodePtr> PatternGraph::SortByReverseTopoOrder() const {
}

void PatternGraph::SinkTrivialPattern() {
GraphTransformer<NodePattern,
And<StmtPatternGraphMatcher<TrivialPattern>,
OnlyOneDownstreamMatcher,
Not<ReshapeConnectionMatcher>>,
MergeTrivialPatternOperation>(this);
GraphTransformer<
NodePattern,
And<StmtPatternGraphMatcher<TrivialPattern>, OnlyOneDownstreamMatcher>,
MergeTrivialPatternOperation>(this);

// TODO(huangjiyi): remove sink multi downstream transpose after
// supporting transpose plus reduce anchor fusion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,22 @@ LoopAxisMapping LoopMappingMerge(const LoopAxisMapping& upstream,
return result;
}

LoopAxisMapping TrivialSinkLoopMappingMerge(const LoopAxisMapping& upstream,
const LoopAxisMapping& downstream) {
auto result = LoopMappingMergeImpl(upstream, downstream, false);
auto upstream_out_value = upstream.output_values[0];
auto indices = FindPosInVector(result.output_values, upstream_out_value);
if (!indices.empty()) {
auto idx = indices.front();
result.output_values.erase(result.output_values.begin() + idx);
result.loop2output.erase(result.loop2output.begin() + idx);
result.outputs_use_count.erase(upstream_out_value);
}
result.SimplifyForwardMapping();
result.SetReverseMapping();
return result;
}

std::vector<int> GetFakeReduceAxisIdx(const std::vector<symbol::DimExpr>& loop,
const AxisTransformRoute& route,
int reduce_axis_num) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ LoopAxisMapping CreateLoopMapping(pir::Operation* op);
LoopAxisMapping LoopMappingMerge(const LoopAxisMapping& upstream,
const LoopAxisMapping& downstream,
bool upstream_is_anchor);
LoopAxisMapping TrivialSinkLoopMappingMerge(const LoopAxisMapping& upstream,
const LoopAxisMapping& downstream);
LoopAxisMapping ReducePlusTrivialLoopMappingMerge(
const LoopAxisMapping& upstream, const LoopAxisMapping& downstream);

Expand Down

0 comments on commit 18b4b1b

Please sign in to comment.