Skip to content

Commit

Permalink
[Recompute] Fix recompute full_int_array bug (#70243)
Browse files Browse the repository at this point in the history
* fix recompute bug

* update

* rerun CI

* Rerun CI
  • Loading branch information
chen2016013 authored Dec 18, 2024
1 parent c7d49ec commit 7e9cca6
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions python/paddle/decomposition/recompute.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,12 @@ def _is_materialized(value_node, placeholder_value_nodes):
def _get_node_weight(value_node, placeholder_value_nodes):
mem_sz = cal_value_node_size(value_node)

if (
value_node.get_defining_op().name() in tending_to_recompute_ops
and mem_sz == 0
):
return 0.1

# Heuristic to bias towards nodes closer to the backwards pass
mem_sz = int(
mem_sz * (1.1 ** max(min(dist_from_bw[value_node], 100), 1))
Expand Down

0 comments on commit 7e9cca6

Please sign in to comment.