-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[example] add a split placement tutorial (#43)
* [example] add a split placement tutorial * lint
- Loading branch information
Showing
6 changed files
with
595 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Split Placement Example | ||
Here we introduce how to run the naive implementation of the split placement of PPO algorithm. | ||
We will release the complete version of flexible placement in the near future. | ||
|
||
For quickstart, you can only follow Step 2 to modify the code and then follow Step 4 to execute the split placement example. | ||
|
||
### Step 1: Placing the models to different GPUs | ||
Specify the placement and resource allocation. In the example, we place the actor and reference in the first half of the GPUs while map the critic and reward model (if any) to the second half of the GPUs. | ||
```python | ||
actor_rollout_ref_pool_id = 'actor_rollout_ref_pool' | ||
critic_pool_id = 'critic_pool' | ||
if config.trainer.nnodes // 2 == 0 and config.trainer.n_gpus_per_node // 2 > 0: | ||
resource_pool_spec = { | ||
actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes, | ||
critic_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes, | ||
} | ||
else: | ||
resource_pool_spec = { | ||
actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2), | ||
critic_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2), | ||
} | ||
print(f'resource_pool_spec: {resource_pool_spec}') | ||
mapping = { | ||
Role.ActorRollout: actor_rollout_ref_pool_id, | ||
Role.Critic: critic_pool_id, | ||
Role.RefPolicy: actor_rollout_ref_pool_id, | ||
} | ||
mapping[Role.RewardModel] = critic_pool_id | ||
``` | ||
|
||
### Step 2: Make the models executed asynchronously | ||
Based on the model placement, we need to make the models executed asynchronously. | ||
|
||
To do so, you need to turn off the `blocking` flag (i.e., `blocking=False`) in our decorator of some model operations. | ||
For example, we hope the actor update and critic update can be executed in parallel, then we need to make the following modification in `fsdp_workers.py` | ||
|
||
``` | ||
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False) | ||
def update_actor(self, data: DataProto): | ||
... | ||
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False) | ||
def update_critic(self, data: DataProto): | ||
... | ||
``` | ||
|
||
We can also parallelize the computation of `ref_log_prob` and `values` and `rewards` in the split placement. For simplicity of the tutorial, we | ||
|
||
### Step 3: Execute these operation in parallel in the single controller process | ||
To implement the parallel execution of the actor and critic update, the only thing we need to modify in the `ray_trainer.py` is to `get` the concurrent `futures` on the single controller process. | ||
|
||
```python | ||
critic_output = critic_output.get() | ||
actor_output = actor_output.get() | ||
``` | ||
|
||
### Step 4: Run the split placement example | ||
|
||
``` | ||
bash run_deepseek7b_llm.sh | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
data: | ||
tokenizer: null | ||
train_files: ~/data/rlhf/gsm8k/train.parquet | ||
val_files: ~/data/rlhf/gsm8k/test.parquet | ||
prompt_key: prompt | ||
max_prompt_length: 512 | ||
max_response_length: 512 | ||
train_batch_size: 1024 | ||
val_batch_size: 1312 | ||
return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs | ||
return_raw_chat: False | ||
|
||
actor_rollout_ref: | ||
hybrid_engine: True | ||
model: | ||
path: ~/models/deepseek-llm-7b-chat | ||
external_lib: null | ||
override_config: {} | ||
enable_gradient_checkpointing: False | ||
actor: | ||
strategy: fsdp # This is for backward-compatibility | ||
ppo_mini_batch_size: 256 | ||
ppo_micro_batch_size: 64 | ||
grad_clip: 1.0 | ||
clip_ratio: 0.2 | ||
entropy_coeff: 0.001 | ||
ppo_epochs: 1 | ||
shuffle: True | ||
optim: | ||
lr: 1e-6 | ||
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime | ||
min_lr_ratio: null # only useful for warmup with cosine | ||
warmup_style: constant # select from constant/cosine | ||
total_training_steps: -1 # must be override by program | ||
fsdp_config: | ||
wrap_policy: | ||
# transformer_layer_cls_to_wrap: None | ||
min_num_params: 0 | ||
param_offload: False | ||
grad_offload: False | ||
optimizer_offload: False | ||
ref: | ||
fsdp_config: | ||
param_offload: False | ||
wrap_policy: | ||
# transformer_layer_cls_to_wrap: None | ||
min_num_params: 0 | ||
log_prob_micro_batch_size: 128 | ||
rollout: | ||
name: vllm | ||
temperature: 1.0 | ||
top_k: -1 # 0 for hf rollout, -1 for vllm rollout | ||
top_p: 1 | ||
prompt_length: ${data.max_prompt_length} # not use for opensource | ||
response_length: ${data.max_response_length} | ||
# for vllm rollout | ||
dtype: bfloat16 # should align with FSDP | ||
gpu_memory_utilization: 0.5 | ||
ignore_eos: False | ||
enforce_eager: True | ||
free_cache_engine: True | ||
load_format: dummy_dtensor | ||
tensor_model_parallel_size: 2 | ||
max_num_batched_tokens: 8192 | ||
max_num_seqs: 1024 | ||
log_prob_micro_batch_size: 128 | ||
# for hf rollout | ||
do_sample: True | ||
|
||
critic: | ||
strategy: fsdp | ||
optim: | ||
lr: 1e-5 | ||
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime | ||
min_lr_ratio: null # only useful for warmup with cosine | ||
warmup_style: constant # select from constant/cosine | ||
total_training_steps: -1 # must be override by program | ||
model: | ||
path: ~/models/deepseek-llm-7b-chat | ||
tokenizer_path: ${actor_rollout_ref.model.path} | ||
override_config: {} | ||
external_lib: ${actor_rollout_ref.model.external_lib} | ||
enable_gradient_checkpointing: False | ||
fsdp_config: | ||
param_offload: False | ||
grad_offload: False | ||
optimizer_offload: False | ||
wrap_policy: | ||
# transformer_layer_cls_to_wrap: None | ||
min_num_params: 0 | ||
ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} | ||
ppo_micro_batch_size: 64 | ||
ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} | ||
shuffle: ${actor_rollout_ref.actor.shuffle} | ||
grad_clip: 1.0 | ||
cliprange_value: 0.5 | ||
|
||
reward_model: | ||
enable: False | ||
strategy: fsdp | ||
model: | ||
input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical | ||
path: ~/models/FsfairX-LLaMA3-RM-v0.1 | ||
external_lib: ${actor_rollout_ref.model.external_lib} | ||
fsdp_config: | ||
min_num_params: 0 | ||
param_offload: False | ||
micro_batch_size: 64 | ||
max_length: null | ||
|
||
algorithm: | ||
gamma: 1.0 | ||
lam: 1.0 | ||
adv_estimator: gae | ||
kl_penalty: kl # how to estimate kl divergence | ||
kl_ctrl: | ||
type: fixed | ||
kl_coef: 0.001 | ||
|
||
trainer: | ||
total_epochs: 30 | ||
project_name: verl_examples | ||
experiment_name: gsm8k | ||
logger: ['console', 'tracking'] | ||
nnodes: 1 | ||
n_gpus_per_node: 8 | ||
save_freq: -1 | ||
test_freq: 2 | ||
critic_warmup: 0 | ||
default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name} | ||
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
# Copyright 2024 Bytedance Ltd. and/or its affiliates | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
""" | ||
Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. | ||
""" | ||
|
||
from verl import DataProto | ||
import torch | ||
from verl.utils.reward_score import gsm8k, math | ||
from verl.trainer.ppo.ray_trainer import RayPPOTrainer | ||
|
||
|
||
def _select_rm_score_fn(data_source): | ||
if data_source == 'openai/gsm8k': | ||
return gsm8k.compute_score | ||
elif data_source == 'lighteval/MATH': | ||
return math.compute_score | ||
else: | ||
raise NotImplementedError | ||
|
||
|
||
class RewardManager(): | ||
|
||
def __init__(self, tokenizer, num_examine) -> None: | ||
self.tokenizer = tokenizer | ||
self.num_examine = num_examine # the number of batches of decoded responses to print to the console | ||
|
||
def __call__(self, data: DataProto): | ||
"""We will expand this function gradually based on the available datasets""" | ||
|
||
# If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn | ||
if 'rm_scores' in data.batch.keys(): | ||
return data.batch['rm_scores'] | ||
|
||
reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32) | ||
|
||
already_print_data_sources = {} | ||
|
||
for i in range(len(data)): | ||
data_item = data[i] # DataProtoItem | ||
|
||
prompt_ids = data_item.batch['prompts'] | ||
|
||
prompt_length = prompt_ids.shape[-1] | ||
|
||
valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum() | ||
valid_prompt_ids = prompt_ids[-valid_prompt_length:] | ||
|
||
response_ids = data_item.batch['responses'] | ||
valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum() | ||
valid_response_ids = response_ids[:valid_response_length] | ||
|
||
# decode | ||
sequences = torch.cat((valid_prompt_ids, valid_response_ids)) | ||
sequences_str = self.tokenizer.decode(sequences) | ||
|
||
ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth'] | ||
|
||
# select rm_score | ||
data_source = data_item.non_tensor_batch['data_source'] | ||
compute_score_fn = _select_rm_score_fn(data_source) | ||
|
||
score = compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth) | ||
reward_tensor[i, valid_response_length - 1] = score | ||
|
||
if data_source not in already_print_data_sources: | ||
already_print_data_sources[data_source] = 0 | ||
|
||
if already_print_data_sources[data_source] < self.num_examine: | ||
already_print_data_sources[data_source] += 1 | ||
print(sequences_str) | ||
|
||
return reward_tensor | ||
|
||
|
||
import ray | ||
import hydra | ||
from split_monkey_patch import fit | ||
|
||
|
||
@hydra.main(config_path='config', config_name='ppo_trainer_split', version_base=None) | ||
def main(config): | ||
if not ray.is_initialized(): | ||
# this is for local ray cluster | ||
ray.init(runtime_env={'env_vars': {'TOKENIZERS_PARALLELISM': 'true', 'NCCL_DEBUG': 'WARN'}}) | ||
|
||
ray.get(main_task.remote(config)) | ||
|
||
|
||
@ray.remote | ||
def main_task(config): | ||
from verl.utils.fs import copy_local_path_from_hdfs | ||
from transformers import AutoTokenizer | ||
|
||
# print initial config | ||
from pprint import pprint | ||
from omegaconf import OmegaConf | ||
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values | ||
OmegaConf.resolve(config) | ||
|
||
# download the checkpoint from hdfs | ||
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path) | ||
|
||
# instantiate tokenizer | ||
tokenizer = AutoTokenizer.from_pretrained(local_path) | ||
from verl.utils import set_pad_token_id | ||
set_pad_token_id(tokenizer) | ||
|
||
# define worker classes | ||
if config.actor_rollout_ref.actor.strategy == 'fsdp': | ||
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy | ||
from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker | ||
from single_controller.ray import RayWorkerGroup | ||
ray_worker_group_cls = RayWorkerGroup | ||
|
||
elif config.actor_rollout_ref.actor.strategy == 'megatron': | ||
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy | ||
from verl.trainer.ppo.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker | ||
from single_controller.ray.megatron import NVMegatronRayWorkerGroup | ||
ray_worker_group_cls = NVMegatronRayWorkerGroup | ||
|
||
else: | ||
raise NotImplementedError | ||
|
||
from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role | ||
|
||
role_worker_mapping = { | ||
Role.ActorRollout: ActorRolloutRefWorker, | ||
Role.Critic: CriticWorker, | ||
Role.RefPolicy: ActorRolloutRefWorker | ||
} | ||
|
||
# NOTE: initialze two resource pool | ||
actor_rollout_ref_pool_id = 'actor_rollout_ref_pool' | ||
critic_pool_id = 'critic_pool' | ||
if config.trainer.nnodes // 2 == 0 and config.trainer.n_gpus_per_node // 2 > 0: | ||
resource_pool_spec = { | ||
actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes, | ||
critic_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes, | ||
} | ||
else: | ||
resource_pool_spec = { | ||
actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2), | ||
critic_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2), | ||
} | ||
print(f'resource_pool_spec: {resource_pool_spec}') | ||
mapping = { | ||
Role.ActorRollout: actor_rollout_ref_pool_id, | ||
Role.Critic: critic_pool_id, | ||
Role.RefPolicy: actor_rollout_ref_pool_id, | ||
} | ||
|
||
# we should adopt a multi-source reward function here | ||
# - for rule-based rm, we directly call a reward score | ||
# - for model-based rm, we call a model | ||
# - for code related prompt, we send to a sandbox if there are test cases | ||
# - finally, we combine all the rewards together | ||
# - The reward type depends on the tag of the data | ||
if config.reward_model.enable: | ||
if config.reward_model.strategy == 'fsdp': | ||
from verl.trainer.ppo.workers.fsdp_workers import RewardModelWorker | ||
elif config.reward_model.strategy == 'megatron': | ||
from verl.trainer.ppo.workers.megatron_workers import RewardModelWorker | ||
else: | ||
raise NotImplementedError | ||
role_worker_mapping[Role.RewardModel] = RewardModelWorker | ||
mapping[Role.RewardModel] = critic_pool_id | ||
|
||
reward_fn = RewardManager(tokenizer=tokenizer, num_examine=0) | ||
|
||
# Note that we always use function-based RM for validation | ||
val_reward_fn = RewardManager(tokenizer=tokenizer, num_examine=1) | ||
|
||
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) | ||
|
||
RayPPOTrainer.fit = fit | ||
trainer = RayPPOTrainer(config=config, | ||
tokenizer=tokenizer, | ||
role_worker_mapping=role_worker_mapping, | ||
resource_pool_manager=resource_pool_manager, | ||
ray_worker_group_cls=ray_worker_group_cls, | ||
reward_fn=reward_fn, | ||
val_reward_fn=val_reward_fn) | ||
trainer.init_workers() | ||
trainer.fit() | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
Oops, something went wrong.