
Commit 41f645d

[misc] feat: support mfu calculation (#117)
1 parent 1ec5eb5 commit 41f645d

File tree

verl/trainer/config/ppo_trainer.yaml
verl/trainer/ppo/ray_trainer.py
verl/utils/flops_counter.py
verl/workers/fsdp_workers.py

4 files changed: +155 −2 lines changed

verl/trainer/config/ppo_trainer.yaml

Lines changed: 3 additions & 0 deletions
@@ -41,12 +41,14 @@ actor_rollout_ref:
       param_offload: False
       grad_offload: False
       optimizer_offload: False
+      fsdp_size: -1
   ref:
     fsdp_config:
       param_offload: False
       wrap_policy:
         # transformer_layer_cls_to_wrap: None
         min_num_params: 0
+      fsdp_size: -1
     log_prob_micro_batch_size: 128
     ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
   rollout:
@@ -94,6 +96,7 @@ critic:
       wrap_policy:
         # transformer_layer_cls_to_wrap: None
         min_num_params: 0
+      fsdp_size: -1
   ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
   ppo_micro_batch_size: 64
   forward_micro_batch_size: ${critic.ppo_micro_batch_size}

verl/trainer/ppo/ray_trainer.py

Lines changed: 3 additions & 0 deletions
@@ -559,6 +559,9 @@ def fit(self):
                 batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                 batch = batch.union(gen_batch_output)

+                # compute global_valid tokens
+                batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist()
+
                 if self.use_reference_policy:
                     # compute reference log_prob
                     with _timer('ref', timing_raw):

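As a quick illustration (not part of the commit), global_token_num ends up holding the per-sample count of valid (non-padding) tokens, computed from the attention mask:

import torch

# each row is one sample's attention mask: 1 = valid token, 0 = padding
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

# same reduction as in the diff above
global_token_num = torch.sum(attention_mask, dim=-1).tolist()
print(global_token_num)  # [3, 5]
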
verl/utils/flops_counter.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
1+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import torch
16+
from transformers import PretrainedConfig, Qwen2Config, LlamaConfig
17+
18+
VALID_CONFIG_TYPE = (Qwen2Config, LlamaConfig)
19+
20+
21+
def get_device_flops(unit="T"):
22+
23+
def unit_convert(number, level):
24+
units = ["B", "K", "M", "G", "T", "P"]
25+
if number <= 0:
26+
return number
27+
ptr = 0
28+
while ptr < len(units) and units[ptr] != level:
29+
number /= 1000
30+
ptr += 1
31+
return number
32+
33+
device_name = torch.cuda.get_device_name()
34+
flops = float("inf") # INF flops for unkown gpu type
35+
if "H100" in device_name or "H800" in device_name:
36+
flops = 989e12
37+
elif "A100" in device_name or "A800" in device_name:
38+
flops = 312e12
39+
elif "L40" in device_name:
40+
flops = 181.05e12
41+
elif "L20" in device_name:
42+
flops = 119.5e12
43+
elif "H20" in device_name:
44+
flops = 148e12
45+
elif "910B" in device_name:
46+
flops = 354e12
47+
flops_unit = unit_convert(flops, unit)
48+
return flops_unit
49+
50+
51+
class FlopsCounter:
52+
"""
53+
Used to count mfu during training loop
54+
55+
Example:
56+
flops_counter = FlopsCounter(config)
57+
flops_achieved, flops_promised = flops_counter.estimate_flops(tokens_list, delta_time)
58+
59+
"""
60+
61+
def __init__(self, config: PretrainedConfig):
62+
if not isinstance(config, VALID_CONFIG_TYPE):
63+
print(f"Only support config type of {VALID_CONFIG_TYPE}, but got {type(config)}. "
64+
f"MFU will always be zero.")
65+
66+
self.estimate_func = {"qwen2": self._estimate_qwen2_flops, 'llama': self._estimate_qwen2_flops}
67+
self.config = config
68+
69+
def _estimate_unknown_flops(self, tokens_sum, batch_seqlens, delta_time):
70+
return 0
71+
72+
def _estimate_qwen2_flops(self, tokens_sum, batch_seqlens, delta_time):
73+
assert isinstance(self.config, (Qwen2Config, LlamaConfig))
74+
hidden_size = self.config.hidden_size
75+
vocab_size = self.config.vocab_size
76+
num_hidden_layers = self.config.num_hidden_layers
77+
num_key_value_heads = self.config.num_key_value_heads
78+
num_attention_heads = self.config.num_attention_heads
79+
intermediate_size = self.config.intermediate_size
80+
81+
head_dim = hidden_size // num_attention_heads
82+
q_size = num_attention_heads * head_dim
83+
k_size = num_key_value_heads * head_dim
84+
v_size = num_key_value_heads * head_dim
85+
86+
# non-attn per layer parm
87+
# Qwen2/LLama use SwiGelu, gate, having up and down linear layer in mlp
88+
mlp_N = hidden_size * intermediate_size * 3
89+
attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)
90+
emd_and_lm_head_N = vocab_size * hidden_size * 2
91+
# non-attn all_layer parm
92+
dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N
93+
# non-attn all_layer & all_token fwd & bwd flops
94+
dense_N_flops = 6 * dense_N * tokens_sum
95+
96+
# attn all_layer & all_token fwd & bwd flops
97+
seqlen_square_sum = 0
98+
for seqlen in batch_seqlens:
99+
seqlen_square_sum += seqlen * seqlen
100+
attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
101+
102+
# all_layer & all_token fwd & bwd flops
103+
flops_all_token = dense_N_flops + attn_qkv_flops
104+
flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
105+
return flops_achieved
106+
107+
def estimate_flops(self, batch_seqlens, delta_time):
108+
"""
109+
Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken.
110+
111+
Args:
112+
batch_seqlens (List[int]): A list where each element represents the number of valid tokens in the current batch.
113+
delta_time (float): The time taken to process the batch, in seconds.
114+
115+
Returns:
116+
estimated_flops (float): The estimated FLOPS based on the input tokens and time.
117+
promised_flops (float): The expected FLOPS of the current device.
118+
"""
119+
tokens_sum = sum(batch_seqlens)
120+
func = self.estimate_func.get(self.config.model_type, self._estimate_unknown_flops)
121+
estimated_flops = func(tokens_sum, batch_seqlens, delta_time)
122+
promised_flops = get_device_flops()
123+
return estimated_flops, promised_flops

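For context, a minimal standalone sketch of how the counter can be used (not part of the commit; the config, sequence lengths, and timing below are made-up example values, and get_device_flops() calls torch.cuda.get_device_name(), so a CUDA device is required):

from transformers import Qwen2Config
from verl.utils.flops_counter import FlopsCounter

config = Qwen2Config()                # default Qwen2 architecture hyperparameters
flops_counter = FlopsCounter(config)

batch_seqlens = [512, 1024, 2048]     # valid tokens per sample in one training batch
delta_time = 3.2                      # seconds spent in the corresponding update step

# achieved TFLOPS for this step vs. the device's peak TFLOPS
estimated_flops, promised_flops = flops_counter.estimate_flops(batch_seqlens, delta_time)
print(f"MFU on a single GPU for one epoch: {estimated_flops / promised_flops:.2%}")
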
verl/workers/fsdp_workers.py

Lines changed: 26 additions & 2 deletions
@@ -35,8 +35,11 @@
     load_fsdp_param_and_grad
 from verl.utils.import_utils import import_external_libs
 from verl.utils.model import compute_position_id_with_mask
+from verl.utils.flops_counter import FlopsCounter
 from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

+from codetiming import Timer
+
 logger = logging.getLogger(__file__)
 logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN'))

@@ -341,6 +344,9 @@ def init_model(self):
             self.config.ref.use_remove_padding = use_remove_padding
             self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp)

+        if self._is_actor:
+            self.flops_counter = FlopsCounter(self.actor_model_config)
+
         torch.cuda.empty_cache()

     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
@@ -362,7 +368,13 @@ def update_actor(self, data: DataProto):
         with self.ulysses_sharding_manager:
             data = self.ulysses_sharding_manager.preprocess_data(data=data)
             # perform training
-            metrics = self.actor.update_policy(data=data)
+            with Timer(name='update_policy', logger=None) as timer:
+                metrics = self.actor.update_policy(data=data)
+            delta_time = timer.last
+            global_num_tokens = data.meta_info['global_token_num']
+            estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
+            metrics['mfu/actor'] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size
+
             self.actor_lr_scheduler.step()
             lr = self.actor_lr_scheduler.get_last_lr()[0]
             metrics['actor/lr'] = lr
@@ -580,6 +592,8 @@ def _build_critic_model_optimizer(self, config):
         if self.rank == 0:
             print_model_size(critic_module)

+        self.critic_model_config = critic_model_config
+
         fsdp_config = self.config.model.fsdp_config
         mixed_precision_config = fsdp_config.get('mixed_precision', None)
         if mixed_precision_config is not None:
@@ -643,6 +657,9 @@ def init_model(self):
         self.critic = DataParallelPPOCritic(config=self.config,
                                             critic_module=self.critic_module,
                                             critic_optimizer=self.critic_optimizer)
+
+        self.flops_counter = FlopsCounter(self.critic_model_config)
+
         torch.cuda.empty_cache()

     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
@@ -681,7 +698,14 @@ def update_critic(self, data: DataProto):
         # perform forward computation
         with self.ulysses_sharding_manager:
             data = self.ulysses_sharding_manager.preprocess_data(data=data)
-            metrics = self.critic.update_critic(data=data)
+
+            with Timer(name='update_critic', logger=None) as timer:
+                metrics = self.critic.update_critic(data=data)
+            delta_time = timer.last
+
+            global_num_tokens = data.meta_info['global_token_num']
+            estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
+            metrics['mfu/critic'] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size

             self.critic_lr_scheduler.step()
             lr = self.critic_lr_scheduler.get_last_lr()[0]

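Taken together, the two new metrics compute MFU as the FLOPS achieved during one update step, scaled by the number of PPO epochs (update_policy and update_critic presumably sweep the batch ppo_epochs times, while estimate_flops counts a single pass), divided by the aggregate peak FLOPS of the participating GPUs. A plain sketch of the arithmetic, using the names from the diff above:

    mfu = estimated_flops * ppo_epochs / (promised_flops * world_size)

where estimated_flops and promised_flops are the TFLOPS values returned by FlopsCounter.estimate_flops and world_size is the number of ranks in the worker group.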