From 7e9cca6bf03ad7c5b344dfbddf29d87d40f9db50 Mon Sep 17 00:00:00 2001 From: chen2016013 <111894720+chen2016013@users.noreply.github.com> Date: Wed, 18 Dec 2024 11:33:10 +0800 Subject: [PATCH] [Recompute] Fix recompute full_int_array bug (#70243) * fix recompute bug * update * rerun CI * Rerun CI --- python/paddle/decomposition/recompute.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/paddle/decomposition/recompute.py b/python/paddle/decomposition/recompute.py index 30893b00d059c..aa735c38d1dc0 100644 --- a/python/paddle/decomposition/recompute.py +++ b/python/paddle/decomposition/recompute.py @@ -461,6 +461,12 @@ def _is_materialized(value_node, placeholder_value_nodes): def _get_node_weight(value_node, placeholder_value_nodes): mem_sz = cal_value_node_size(value_node) + if ( + value_node.get_defining_op().name() in tending_to_recompute_ops + and mem_sz == 0 + ): + return 0.1 + # Heuristic to bias towards nodes closer to the backwards pass mem_sz = int( mem_sz * (1.1 ** max(min(dist_from_bw[value_node], 100), 1))