diff --git a/paddle/cinn/operator_fusion/graph_transformer/operation.h b/paddle/cinn/operator_fusion/graph_transformer/operation.h index 8e0c8514c7cc6e..5209e7a4050c2b 100644 --- a/paddle/cinn/operator_fusion/graph_transformer/operation.h +++ b/paddle/cinn/operator_fusion/graph_transformer/operation.h @@ -28,11 +28,12 @@ struct MergeTrivialPatternOperation { "The trivial pattern wait for sinking should has " "at least 1 downstream , but got %d.", upstream->downstream().size())); - + VLOG(4) << "Sink trivial pattern: \nupstream: " << upstream->DebugStr(); std::vector fusion_candidate = upstream->downstream(); upstream->ClearDownstream(); - for (const auto& downstream : fusion_candidate) { + for (int i = 0; i < fusion_candidate.size(); ++i) { + const auto& downstream = fusion_candidate[i]; bool can_fuse = std::holds_alternative(downstream->stmt_pattern()) || std::holds_alternative(downstream->stmt_pattern()) || @@ -45,15 +46,13 @@ struct MergeTrivialPatternOperation { std::holds_alternative(downstream->stmt_pattern()); if (can_fuse) { + VLOG(4) << "\ndownstream [" << i << "]: " << downstream->DebugStr(); auto merged_node = graph->MergeNode(upstream, downstream, MergePattern); merged_node->set_fusion_iters( graph->iters_fusion_policy()->SingleDownstreamItersFusion( upstream, downstream)); graph->RemoveNode(downstream); - VLOG(4) << "Splitting trivial pattern: \nupstream " - << upstream->DebugStr() << "\ndownstream " - << downstream->DebugStr() << "\nmerged " - << merged_node->DebugStr(); + VLOG(4) << "\nmerged [" << i << "] " << merged_node->DebugStr(); merged_node->AppendInstr(std::make_shared( upstream->id(), downstream->id(), merged_node->id())); } else { @@ -98,6 +97,8 @@ struct MergeReduceTreeAndTrivialOperation { "The downstream of the ReduceTree node should be 1, but got %d.", node->downstream().size())); auto downstream = node->downstream().at(0); + VLOG(4) << "MergeReduceTreeAndTrivialOperation: \nupstream " + << node->DebugStr() << "\ndownstream " << downstream->DebugStr(); auto fake_reduce_iter_idx = graph->policy_manager() .template GetPolicy() ->GetFakeReduceIterIdx(node, downstream); @@ -132,9 +133,7 @@ struct MergeReduceTreeAndTrivialOperation { graph->RemoveNode(downstream); graph->RemoveNode(node); - VLOG(4) << "MergeReduceTreeAndTrivialOperation: \nupstream " - << node->DebugStr() << "\ndownstream " << downstream->DebugStr() - << "\nmerged " << merged_node->DebugStr(); + VLOG(4) << "merged " << merged_node->DebugStr(); merged_node->UpdateTracker(); return merged_node; } diff --git a/paddle/cinn/operator_fusion/pattern_fuser.h b/paddle/cinn/operator_fusion/pattern_fuser.h index 61f9f0c434b698..203479f1820080 100644 --- a/paddle/cinn/operator_fusion/pattern_fuser.h +++ b/paddle/cinn/operator_fusion/pattern_fuser.h @@ -74,7 +74,7 @@ static StmtPattern MergePatternImpl(const TrivialPattern& first, second.sink_op(), std::make_shared(first.tracker_, second.tracker_)); result.set_loop_mapping( - LoopMappingMerge(first.loop_mapping(), second.loop_mapping(), false)); + TrivialSinkLoopMappingMerge(first.loop_mapping(), second.loop_mapping())); return result; } @@ -86,7 +86,7 @@ static StmtPattern MergePatternImpl(const TrivialPattern& first, contents, std::make_shared(first.tracker_, second.tracker_)); result.set_loop_mapping( - LoopMappingMerge(first.loop_mapping(), second.loop_mapping(), false)); + TrivialSinkLoopMappingMerge(first.loop_mapping(), second.loop_mapping())); return result; } @@ -116,7 +116,7 @@ static StmtPattern MergePatternImpl(const TrivialPattern& first, FusePatternIfConnected(first, second.GetRootPattern(), connect_ops), std::make_shared(first.tracker_, second.tracker_)); result.set_loop_mapping( - LoopMappingMerge(first.loop_mapping(), second.loop_mapping(), false)); + TrivialSinkLoopMappingMerge(first.loop_mapping(), second.loop_mapping())); return result; } @@ -129,7 +129,7 @@ static StmtPattern MergePatternImpl( std::make_shared(first.tracker_, second.tracker_)); result.fake_reduce_iter_idx = second.fake_reduce_iter_idx; result.set_loop_mapping( - LoopMappingMerge(first.loop_mapping(), second.loop_mapping(), false)); + TrivialSinkLoopMappingMerge(first.loop_mapping(), second.loop_mapping())); return result; } @@ -140,7 +140,7 @@ static StmtPattern MergePatternImpl(const TrivialPattern& first, std::make_shared(first.tracker_, second.tracker_), second.loop_dims()); result.set_loop_mapping( - LoopMappingMerge(first.loop_mapping(), second.loop_mapping(), false)); + TrivialSinkLoopMappingMerge(first.loop_mapping(), second.loop_mapping())); return result; } @@ -149,7 +149,7 @@ static StmtPattern MergePatternImpl(const TrivialPattern& first, return AnchorPattern( UniqueConcatVector(GetOpsInPattern(first), GetOpsInPattern(second)), std::make_shared(first.tracker_, second.tracker_), - LoopMappingMerge(first.loop_mapping(), second.loop_mapping(), false)); + TrivialSinkLoopMappingMerge(first.loop_mapping(), second.loop_mapping())); } // RR & RT diff --git a/paddle/cinn/operator_fusion/pattern_graph.cc b/paddle/cinn/operator_fusion/pattern_graph.cc index acba56aeab660b..196a05ff803be7 100644 --- a/paddle/cinn/operator_fusion/pattern_graph.cc +++ b/paddle/cinn/operator_fusion/pattern_graph.cc @@ -155,11 +155,10 @@ std::vector PatternGraph::SortByReverseTopoOrder() const { } void PatternGraph::SinkTrivialPattern() { - GraphTransformer, - OnlyOneDownstreamMatcher, - Not>, - MergeTrivialPatternOperation>(this); + GraphTransformer< + NodePattern, + And, OnlyOneDownstreamMatcher>, + MergeTrivialPatternOperation>(this); // TODO(huangjiyi): remove sink multi downstream transpose after // supporting transpose plus reduce anchor fusion diff --git a/paddle/cinn/operator_fusion/pir_graph_analyzing/loop_axis_mapping.cc b/paddle/cinn/operator_fusion/pir_graph_analyzing/loop_axis_mapping.cc index 4616204a2795cc..1f3b381132a9b7 100644 --- a/paddle/cinn/operator_fusion/pir_graph_analyzing/loop_axis_mapping.cc +++ b/paddle/cinn/operator_fusion/pir_graph_analyzing/loop_axis_mapping.cc @@ -444,6 +444,22 @@ LoopAxisMapping LoopMappingMerge(const LoopAxisMapping& upstream, return result; } +LoopAxisMapping TrivialSinkLoopMappingMerge(const LoopAxisMapping& upstream, + const LoopAxisMapping& downstream) { + auto result = LoopMappingMergeImpl(upstream, downstream, false); + auto upstream_out_value = upstream.output_values[0]; + auto indices = FindPosInVector(result.output_values, upstream_out_value); + if (!indices.empty()) { + auto idx = indices.front(); + result.output_values.erase(result.output_values.begin() + idx); + result.loop2output.erase(result.loop2output.begin() + idx); + result.outputs_use_count.erase(upstream_out_value); + } + result.SimplifyForwardMapping(); + result.SetReverseMapping(); + return result; +} + std::vector GetFakeReduceAxisIdx(const std::vector& loop, const AxisTransformRoute& route, int reduce_axis_num) { diff --git a/paddle/cinn/operator_fusion/pir_graph_analyzing/loop_axis_mapping.h b/paddle/cinn/operator_fusion/pir_graph_analyzing/loop_axis_mapping.h index ff0e8deccc3926..2e2968e60aa143 100644 --- a/paddle/cinn/operator_fusion/pir_graph_analyzing/loop_axis_mapping.h +++ b/paddle/cinn/operator_fusion/pir_graph_analyzing/loop_axis_mapping.h @@ -150,6 +150,8 @@ LoopAxisMapping CreateLoopMapping(pir::Operation* op); LoopAxisMapping LoopMappingMerge(const LoopAxisMapping& upstream, const LoopAxisMapping& downstream, bool upstream_is_anchor); +LoopAxisMapping TrivialSinkLoopMappingMerge(const LoopAxisMapping& upstream, + const LoopAxisMapping& downstream); LoopAxisMapping ReducePlusTrivialLoopMappingMerge( const LoopAxisMapping& upstream, const LoopAxisMapping& downstream);