diff --git a/paddleformers/trainer/training_args.py b/paddleformers/trainer/training_args.py
index 44bac91563..5f0dd96388 100644
--- a/paddleformers/trainer/training_args.py
+++ b/paddleformers/trainer/training_args.py
@@ -305,7 +305,8 @@ class TrainingArguments:
         refined_recompute (`str`, *optional*, defaults to `""`):
             The refined recompute parameter is designed to optimize the balance between GPU memory usage and computational speed.
             An example configuration could be: `attention_column_ln:-1,attention_row_ln:-1,flash_attn:-1,mlp_column_ln:5,mlp_row_ln:-1`.
-            The supported parameters for refining recompute are `attention_column_ln`, `attention_row_ln`, `flash_attn`, `mlp_column_ln`, and `mlp_row_ln`.
+            The supported parameters for refining recompute are `attention_column_ln`, `attention_row_ln`, `flash_attn`, `mlp_column_ln`, `mlp_row_ln`, and `global`.
+            - `global`: global configuration that applies to ALL operators.
             The associated number, `skip_num`, determines how many times to bypass recomputation for the specified operation.
             A `skip_num` of `-1` indicates no recomputation across all stages, maximizing memory usage;
             A `skip_num` of `0` enforces recomputation at every stage, minimizing memory usage.
@@ -2029,6 +2030,7 @@ def is_context_parallel_supported():
             "attention_column_ln": 0,
             "mlp_column_ln": 0,
             "flash_attn": 0,
+            "global": 0,
         }
         ops = self.refined_recompute.split(",")
         enable_rr = False
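
To make the intent of the new `global` key concrete, below is a minimal, self-contained sketch of how a comma-separated `refined_recompute` string could be parsed into the per-operator skip-count dict shown in the second hunk. The helper name `parse_refined_recompute` and its error handling are illustrative assumptions, not the actual logic in `training_args.py`, and how `global` ultimately overrides the per-operator entries is left to the real implementation.

```python
def parse_refined_recompute(refined_recompute: str) -> dict:
    """Parse 'op:skip_num,op:skip_num,...' into a per-operator skip-count dict.

    Hypothetical helper for illustration only; the real parsing lives in
    paddleformers/trainer/training_args.py and may differ.
    """
    # Defaults mirror the dict in the diff: skip_num == 0 for every supported op,
    # i.e. recompute at every stage (minimum memory usage).
    skip_config = {
        "mlp_row_ln": 0,
        "attention_row_ln": 0,
        "attention_column_ln": 0,
        "mlp_column_ln": 0,
        "flash_attn": 0,
        "global": 0,  # new key: a configuration intended to apply to ALL operators
    }
    if not refined_recompute:
        return skip_config
    for item in refined_recompute.split(","):
        if ":" not in item:
            raise ValueError(f"expected 'op:skip_num', got {item!r}")
        op, skip_num = item.split(":", 1)
        op = op.strip()
        if op not in skip_config:
            raise ValueError(f"unsupported refined_recompute op: {op!r}")
        # skip_num == -1: never recompute this op (maximum memory usage);
        # skip_num == 0: always recompute it (minimum memory usage).
        skip_config[op] = int(skip_num)
    return skip_config


# Skip flash_attn recomputation everywhere; skip mlp_column_ln recomputation 5 times.
print(parse_refined_recompute("flash_attn:-1,mlp_column_ln:5"))

# The new global key: a single skip_num meant to cover all operators at once.
print(parse_refined_recompute("global:-1"))
```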