4 changes: 3 additions & 1 deletion paddleformers/trainer/training_args.py
@@ -305,7 +305,8 @@ class TrainingArguments:
         refined_recompute (`str`, *optional*, defaults to `""`):
             The refined recompute parameter is designed to optimize the balance between GPU memory usage and computational speed.
             An example configuration could be: `attention_column_ln:-1,attention_row_ln:-1,flash_attn:-1,mlp_column_ln:5,mlp_row_ln:-1`.
-            The supported parameters for refining recompute are `attention_column_ln`, `attention_row_ln`, `flash_attn`, `mlp_column_ln`, and `mlp_row_ln`.
+            The supported parameters for refining recompute are `attention_column_ln`, `attention_row_ln`, `flash_attn`, `mlp_column_ln`, `mlp_row_ln`, and `global`.
+                - `global`: global configuration that applies to ALL operators.
             The associated number, `skip_num`, determines how many times to bypass recomputation for the specified operation.
             A `skip_num` of `-1` indicates no recomputation across all stages, maximizing memory usage;
             A `skip_num` of `0` enforces recomputation at every stage, minimizing memory usage.
@@ -2029,6 +2030,7 @@ def is_context_parallel_supported():
                 "attention_column_ln": 0,
                 "mlp_column_ln": 0,
                 "flash_attn": 0,
+                "global": 0,
             }
             ops = self.refined_recompute.split(",")
             enable_rr = False
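The docstring above describes the `refined_recompute` string format (`op:skip_num` pairs joined by commas), and the second hunk shows PaddleFormers splitting that string on `,`. As a minimal sketch of how such a string could be parsed into per-operator skip counts, here is a hypothetical standalone parser; the function name `parse_refined_recompute` and the `DEFAULT_SKIP` table are assumptions for illustration, not PaddleFormers' actual implementation:

```python
# Hypothetical sketch of parsing a refined_recompute spec string such as
# "attention_column_ln:-1,flash_attn:-1,mlp_column_ln:5" into skip counts.
# Not the PaddleFormers implementation; names here are illustrative.

# Supported operators per the docstring, defaulting to skip_num = 0
# (recompute at every stage, minimizing memory usage).
DEFAULT_SKIP = {
    "attention_column_ln": 0,
    "attention_row_ln": 0,
    "mlp_column_ln": 0,
    "mlp_row_ln": 0,
    "flash_attn": 0,
    "global": 0,  # applies to ALL operators, per the added docstring line
}

def parse_refined_recompute(spec: str) -> dict:
    """Parse 'op:skip_num,op:skip_num,...' into a dict of skip counts.

    skip_num == -1 : never recompute the op (maximizes memory usage)
    skip_num ==  0 : recompute the op at every stage (minimizes memory usage)
    """
    skip = dict(DEFAULT_SKIP)
    if not spec:
        return skip
    for item in spec.split(","):
        name, _, num = item.partition(":")
        name = name.strip()
        if name not in skip:
            raise ValueError(f"unsupported refined_recompute op: {name!r}")
        skip[name] = int(num)
    return skip

cfg = parse_refined_recompute("attention_column_ln:-1,flash_attn:-1,mlp_column_ln:5")
# cfg["mlp_column_ln"] is 5; operators not listed keep their default of 0.
```

An unknown operator name raises early rather than being silently ignored, which mirrors the defensive validation a training-arguments parser typically wants.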