Tiny refactor and cleanup #75

Open · wants to merge 36 commits into base: main
8 changes: 8 additions & 0 deletions docs/examples/config.rst
@@ -246,6 +246,10 @@ Critic Model

Most parameters for the Critic are similar to those of the Actor Model.

- ``critic.forward_micro_batch_size``: Micro batch size for forward-only computations.
  This can differ from the forward-backward batch size, since a forward-only pass
  usually consumes less memory.

Reward Model
~~~~~~~~~~~~

@@ -317,6 +321,7 @@ Trainer
critic_warmup: 0
default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name} # hdfs checkpoint path
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} # local checkpoint path
val_before_train: True

- ``trainer.total_epochs``: Number of epochs in training.
- ``trainer.project_name``: The project name used for wandb logging.
@@ -329,3 +334,6 @@ Trainer
- ``trainer.test_freq``: The validation frequency (by iteration).
- ``trainer.critic_warmup``: The number of iterations to train the critic
  model before actual policy learning begins.
- ``trainer.default_hdfs_dir``: Default HDFS directory to use.
- ``trainer.default_local_dir``: Default local directory to use.
- ``trainer.val_before_train``: Whether to validate once before training starts.
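
A quick sketch of how these new trainer keys read at runtime (illustrative only, not part of this diff; ``run_validation`` is a hypothetical stand-in for the trainer's validation pass):

```python
# Illustrative sketch: reading the new keys from the ppo_trainer.yaml layout in this PR.
from omegaconf import OmegaConf

def run_validation():
    """Hypothetical stand-in for the trainer's pre-training validation pass."""
    print('validating before training...')

config = OmegaConf.load('verl/trainer/config/ppo_trainer.yaml')
if config.trainer.val_before_train:  # new key, defaults to True
    run_validation()
```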
2 changes: 1 addition & 1 deletion docs/start/quickstart.rst
@@ -102,7 +102,7 @@ Set the ``data.train_files`` ,\ ``data.val_files``, ``actor_rollout_ref.model.pa
critic.ppo_micro_batch_size=4 \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.logger=['console'] \
+trainer.val_before_train=False \
trainer.val_before_train=False \
trainer.default_hdfs_dir=null \
trainer.n_gpus_per_node=1 \
trainer.nnodes=1 \
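Note on this one-character change: the leading ``+`` is dropped because ``val_before_train`` now exists in the base ``ppo_trainer.yaml`` (added below in this PR). In Hydra's override grammar, ``+key=value`` appends a key that is absent from the config, while a bare ``key=value`` overrides an existing one.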
1 change: 1 addition & 0 deletions verl/trainer/config/ppo_megatron_trainer.yaml
@@ -143,3 +143,4 @@ trainer:
critic_warmup: 0
default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name}
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
val_before_train: True
1 change: 1 addition & 0 deletions verl/trainer/config/ppo_trainer.yaml
@@ -130,3 +130,4 @@ trainer:
critic_warmup: 0
default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name}
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
val_before_train: True
21 changes: 10 additions & 11 deletions verl/trainer/fsdp_sft_trainer.py
@@ -20,6 +20,8 @@

import os

from verl.utils.config import config_normalize_batch_size
Collaborator: Can you move this import above L35, which would be tidier?

Contributor (author): (will do that in the next batch after discussions of other things)

os.environ['NCCL_DEBUG'] = 'WARN'
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

@@ -29,7 +31,7 @@
import torch.distributed
from torch import nn, optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy, CPUOffload
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, AutoConfig
from transformers import AutoModelForCausalLM, PreTrainedModel, AutoConfig
from verl.utils.torch_functional import get_cosine_schedule_with_warmup
from tensordict import TensorDict
from torch.utils.data import DataLoader, DistributedSampler
@@ -83,11 +85,8 @@ def _normalize_config_bsz(self):
if self.device_mesh.get_rank() == 0:
print(f'Normalize batch size by dp {dp_size}')

assert self.config.data.train_batch_size % dp_size == 0
assert self.config.data.micro_batch_size % dp_size == 0

self.config.data.train_batch_size //= dp_size
self.config.data.micro_batch_size //= dp_size
config_normalize_batch_size(self.config.data, 'train_batch_size', dp_size)
config_normalize_batch_size(self.config.data, 'micro_batch_size', dp_size)

def _build_dataloader(self):
config = self.config
@@ -118,7 +117,7 @@ def _build_dataloader(self):
rank=rank,
drop_last=True)
self.train_dataloader = DataLoader(dataset=self.train_dataset,
batch_size=config.data.train_batch_size,
batch_size=config.data.train_batch_size_normalized,
sampler=self.train_sampler,
drop_last=True)

@@ -128,7 +127,7 @@ def _build_dataloader(self):
rank=rank,
drop_last=True)
self.val_dataloader = DataLoader(dataset=self.val_dataset,
batch_size=config.data.micro_batch_size,
batch_size=config.data.micro_batch_size_normalized,
sampler=self.val_sampler,
drop_last=True)

@@ -254,7 +253,7 @@ def training_step(self, batch: TensorDict):

log_gpu_memory_usage('After optimizer zero_grad', logger=logger)

micro_batches = batch.split(self.config.data.micro_batch_size)
micro_batches = batch.split(self.config.data.micro_batch_size_normalized)
n_micro_batches = len(micro_batches)
for micro_batch in micro_batches:
loss = self._compute_loss(batch=micro_batch) / n_micro_batches
@@ -319,7 +318,7 @@ def fit(self):
for epoch in range(self.config.trainer.total_epochs):
self.train_sampler.set_epoch(epoch=epoch)
for data in self.train_dataloader:
data = TensorDict(data, batch_size=self.config.data.train_batch_size).cuda()
data = TensorDict(data, batch_size=self.config.data.train_batch_size_normalized).cuda()
metric = self.training_step(data)
if rank == 0:
tracking.log(data=metric, step=global_step)
@@ -328,7 +327,7 @@
# validation
val_losses = []
for data in self.val_dataloader:
data = TensorDict(data, batch_size=self.config.data.micro_batch_size).cuda()
data = TensorDict(data, batch_size=self.config.data.micro_batch_size_normalized).cuda()
val_loss = self.validation_step(data)
val_losses.append(val_loss)
if rank == 0:
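The renamed ``*_normalized`` values feed straight into ``TensorDict.split`` for gradient accumulation, as in ``training_step`` above. A minimal sketch of that pattern, assuming hypothetical shapes:

```python
import torch
from tensordict import TensorDict

# Hypothetical per-rank batch of 32 samples after normalization.
batch = TensorDict({'input_ids': torch.zeros(32, 128, dtype=torch.long)}, batch_size=[32])

micro_batch_size_normalized = 4   # hypothetical per-rank micro batch size
micro_batches = batch.split(micro_batch_size_normalized)
assert len(micro_batches) == 8    # 8 micro-steps accumulated per optimizer step
```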
13 changes: 12 additions & 1 deletion verl/utils/config.py
@@ -14,10 +14,21 @@

from typing import Dict

from omegaconf import DictConfig
from omegaconf import DictConfig, open_dict


def update_dict_with_config(dictionary: Dict, config: DictConfig):
for key in dictionary:
if hasattr(config, key):
dictionary[key] = getattr(config, key)


def config_normalize_batch_size(config, key: str, divider: int):
value_raw = config[key]
assert value_raw % divider == 0
value_normalized = value_raw // divider

with open_dict(config):
del config[key]
config[f'{key}_raw'] = value_raw
config[f'{key}_normalized'] = value_normalized
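
For reviewers, a small usage sketch of the helper above, with hypothetical values (behavior follows the implementation in this diff):

```python
from omegaconf import OmegaConf
from verl.utils.config import config_normalize_batch_size

cfg = OmegaConf.create({'train_batch_size': 256})
config_normalize_batch_size(cfg, 'train_batch_size', 8)  # e.g. dp_size = 8

print(cfg.train_batch_size_raw)         # 256, the original user-facing value
print(cfg.train_batch_size_normalized)  # 32, the per-rank value
# The plain 'train_batch_size' key is deleted, so any stale read of the old
# name now fails loudly instead of silently using a half-normalized value.
```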
8 changes: 4 additions & 4 deletions verl/workers/actor/dp_actor.py
@@ -60,7 +60,7 @@ def _make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
"""
select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'old_log_probs', 'advantages']
data = data.select(batch_keys=select_keys)
return data.make_iterator(mini_batch_size=self.config.ppo_mini_batch_size,
return data.make_iterator(mini_batch_size=self.config.ppo_mini_batch_size_normalized,
epochs=self.config.ppo_epochs,
dataloader_kwargs={'shuffle': self.config.shuffle})

@@ -113,16 +113,16 @@ def update_policy(self, data: DataProto):
# make sure we are in training mode
self.actor_module.train()

assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size == 0
self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size
assert self.config.ppo_mini_batch_size_normalized % self.config.ppo_micro_batch_size_normalized == 0
self.gradient_accumulation = self.config.ppo_mini_batch_size_normalized // self.config.ppo_micro_batch_size_normalized
temperature = data.meta_info['temperature'] # temperature must be in the data.meta_info to avoid silent errors

dataloader = self._make_minibatch_iterator(data=data)

metrics = {}
for batch_idx, data in enumerate(dataloader):
# split batch into micro_batches
micro_batches = data.batch.split(self.config.ppo_micro_batch_size)
micro_batches = data.batch.split(self.config.ppo_micro_batch_size_normalized)

self.actor_optimizer.zero_grad()

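A worked example of the normalized gradient-accumulation arithmetic above (hypothetical sizes):

```python
# Hypothetical: 8 GPUs, user config ppo_mini_batch_size=256, ppo_micro_batch_size=32.
world_size = 8
ppo_mini_batch_size_normalized = 256 // world_size   # 32 samples per rank per mini-batch
ppo_micro_batch_size_normalized = 32 // world_size   # 4 samples per rank per micro-batch

assert ppo_mini_batch_size_normalized % ppo_micro_batch_size_normalized == 0
gradient_accumulation = ppo_mini_batch_size_normalized // ppo_micro_batch_size_normalized
print(gradient_accumulation)  # 8 micro-steps accumulated per optimizer step
```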
4 changes: 2 additions & 2 deletions verl/workers/actor/megatron_actor.py
@@ -211,7 +211,7 @@ def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
"""
select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'old_log_probs', 'advantages']
data = data.select(batch_keys=select_keys)
return data.make_iterator(mini_batch_size=self.config.ppo_mini_batch_size,
return data.make_iterator(mini_batch_size=self.config.ppo_mini_batch_size_normalized,
epochs=self.config.ppo_epochs,
dataloader_kwargs={'shuffle': self.config.shuffle})

@@ -232,7 +232,7 @@ def forward_backward_batch(self, data: DataProto, forward_only=False, post_proce
if data.meta_info.get('micro_batch_size', None) is not None:
batch_size = data.meta_info['micro_batch_size']
else:
batch_size = self.config.ppo_micro_batch_size
batch_size = self.config.ppo_micro_batch_size_normalized
batches = split_dict_tensor_into_batches(data.batch, batch_size=batch_size)
# compute input shapes for pp stages
input_shapes = compute_transformers_input_shapes(
8 changes: 4 additions & 4 deletions verl/workers/critic/dp_critic.py
@@ -39,8 +39,8 @@ def __init__(self, config, critic_module: nn.Module, critic_optimizer: optim.Opt
self.critic_module = critic_module
self.critic_optimizer = critic_optimizer

assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size == 0
self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size
assert self.config.ppo_mini_batch_size_normalized % self.config.ppo_micro_batch_size_normalized == 0
self.gradient_accumulation = self.config.ppo_mini_batch_size_normalized // self.config.ppo_micro_batch_size_normalized

def _forward_micro_batch(self, micro_batch):
response_length = micro_batch['responses'].size(-1)
@@ -56,7 +56,7 @@ def _forward_micro_batch(self, micro_batch):
def _make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
select_keys = ['input_ids', 'responses', 'attention_mask', 'position_ids', 'values', 'returns']
data = data.select(batch_keys=select_keys)
return data.make_iterator(mini_batch_size=self.config.ppo_mini_batch_size,
return data.make_iterator(mini_batch_size=self.config.ppo_mini_batch_size_normalized,
epochs=self.config.ppo_epochs,
dataloader_kwargs={'shuffle': self.config.shuffle})

@@ -97,7 +97,7 @@ def update_critic(self, data: DataProto):

for batch_idx, data in enumerate(dataloader):
# split batch into micro_batches
micro_batches = data.batch.split(self.config.ppo_micro_batch_size)
micro_batches = data.batch.split(self.config.ppo_micro_batch_size_normalized)
self.critic_optimizer.zero_grad()

for data in micro_batches:
8 changes: 4 additions & 4 deletions verl/workers/critic/megatron_critic.py
@@ -106,7 +106,7 @@ def compute_values(self, data: DataProto) -> DataProto:
def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
select_keys = ['input_ids', 'responses', 'attention_mask', 'position_ids', 'values', 'returns']
data = data.select(batch_keys=select_keys)
return data.make_iterator(mini_batch_size=self.config.ppo_mini_batch_size,
return data.make_iterator(mini_batch_size=self.config.ppo_mini_batch_size_normalized,
epochs=self.config.ppo_epochs,
dataloader_kwargs={'shuffle': self.config.shuffle})

@@ -118,7 +118,7 @@ def forward_backward_batch(self, data: DataProto, forward_only=False):
group=mpu.get_pipeline_model_parallel_group())
# split into micro-batches
data.batch['attention_mask'] = data.batch['attention_mask'].to(bool)
batches = split_dict_tensor_into_batches(data.batch, batch_size=self.config.ppo_micro_batch_size)
batches = split_dict_tensor_into_batches(data.batch, batch_size=self.config.ppo_micro_batch_size_normalized)
n_micro_batch = len(batches)
seq_len = batches[0]['input_ids'].shape[1]

@@ -182,7 +182,7 @@ def forward_step(batch_iter, model):
model=self.critic_module,
num_microbatches=n_micro_batch,
input_shapes=input_shapes, # must set for flash-attn sequence packing
seq_length=self.config.ppo_micro_batch_size * seq_len, # no use when input_shapes was set
seq_length=self.config.ppo_micro_batch_size_normalized * seq_len, # no use when input_shapes was set
hidden_size=self.model_config.hidden_size, # no use when input_shapes was set
micro_batch_size=1, # no use when input_shapes was set
forward_only=forward_only,
@@ -193,7 +193,7 @@ def forward_step(batch_iter, model):
data_iterator=batch_generator,
model=self.critic_module,
num_microbatches=n_micro_batch,
seq_length=self.config.ppo_micro_batch_size * seq_len, # in use for pp = 1
seq_length=self.config.ppo_micro_batch_size_normalized * seq_len, # in use for pp = 1
hidden_size=self.model_config.hidden_size, # in use for pp = 1
micro_batch_size=1, # in use for pp = 1
forward_only=forward_only,
23 changes: 12 additions & 11 deletions verl/workers/fsdp_workers.py
@@ -28,6 +28,7 @@
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import register, Dispatch
from verl.utils import hf_tokenizer
from verl.utils.config import config_normalize_batch_size
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.fsdp_utils import get_fsdp_wrap_policy, offload_fsdp_grad, init_fn, get_init_weight_context_manager
@@ -79,12 +80,12 @@ def __init__(self, config: DictConfig, role: str):

# normalize config
if self._is_actor:
self.config.actor.ppo_mini_batch_size //= self.device_mesh.shape[0]
self.config.actor.ppo_micro_batch_size //= self.device_mesh.shape[0]
config_normalize_batch_size(self.config.actor, 'ppo_mini_batch_size', self.device_mesh.shape[0])
config_normalize_batch_size(self.config.actor, 'ppo_micro_batch_size', self.device_mesh.shape[0])
if self._is_rollout:
self.config.rollout.log_prob_micro_batch_size //= self.device_mesh.shape[0]
config_normalize_batch_size(self.config.rollout, 'log_prob_micro_batch_size', self.device_mesh.shape[0])
if self._is_ref:
self.config.ref.log_prob_micro_batch_size //= self.device_mesh.shape[0]
config_normalize_batch_size(self.config.ref, 'log_prob_micro_batch_size', self.device_mesh.shape[0])

def _build_model_optimizer(self,
model_path,
@@ -363,7 +364,7 @@ def generate_sequences(self, prompts: DataProto):

if self._is_actor and recompute_log_prob:
# we should always recompute old_log_probs when it is HybridEngine
output.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size
output.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size_normalized
output.meta_info['temperature'] = self.config.rollout.temperature
old_log_probs = self.actor.compute_log_prob(data=output)
output.batch['old_log_probs'] = old_log_probs
@@ -389,7 +390,7 @@ def compute_ref_log_prob(self, data: DataProto):
device_id=torch.cuda.current_device(),
load_grad=self._is_offload_grad)

micro_batch_size = self.config.ref.log_prob_micro_batch_size
micro_batch_size = self.config.ref.log_prob_micro_batch_size_normalized
data.meta_info['micro_batch_size'] = micro_batch_size
data.meta_info['temperature'] = self.config.rollout.temperature
output = self.ref_policy.compute_log_prob(data=data)
@@ -445,9 +446,9 @@ def __init__(self, config):
self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload

# normalize config
self.config.ppo_mini_batch_size //= torch.distributed.get_world_size()
self.config.ppo_micro_batch_size //= torch.distributed.get_world_size()
self.config.forward_micro_batch_size //= torch.distributed.get_world_size()
config_normalize_batch_size(self.config, 'ppo_mini_batch_size', torch.distributed.get_world_size())
config_normalize_batch_size(self.config, 'ppo_micro_batch_size', torch.distributed.get_world_size())
config_normalize_batch_size(self.config, 'forward_micro_batch_size', torch.distributed.get_world_size())

def _build_critic_model_optimizer(self, config):
# the following line is necessary
@@ -575,7 +576,7 @@ def compute_values(self, data: DataProto):
load_fsdp_param_and_grad(module=self.critic_module,
device_id=torch.cuda.current_device(),
load_grad=self._is_offload_grad)
micro_batch_size = self.config.forward_micro_batch_size
micro_batch_size = self.config.forward_micro_batch_size_normalized
data.meta_info['micro_batch_size'] = micro_batch_size
values = self.critic.compute_values(data=data)
output = DataProto.from_dict(tensors={'values': values})
@@ -650,7 +651,7 @@ def __init__(self, config):
torch.distributed.init_process_group(backend="nccl")
self.config = config

self.config.micro_batch_size //= torch.distributed.get_world_size()
config_normalize_batch_size(self.config, 'micro_batch_size', torch.distributed.get_world_size())

def _build_model(self, config):
# the following line is necessary
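One property of the new helper worth noting: it guards against double normalization, which the old in-place ``//=`` pattern could not. A sketch of the failure mode, with hypothetical values:

```python
from omegaconf import OmegaConf
from verl.utils.config import config_normalize_batch_size

world_size = 8                                        # hypothetical
cfg = OmegaConf.create({'ppo_micro_batch_size': 64})

# Old pattern: an accidental second normalization divides again, silently.
#   cfg.ppo_micro_batch_size //= world_size           # 64 -> 8 (intended)
#   cfg.ppo_micro_batch_size //= world_size           # 8 -> 1, no warning

# New pattern: the first call deletes the raw key, so a repeated call fails
# loudly when it tries to read the now-missing 'ppo_micro_batch_size'.
config_normalize_batch_size(cfg, 'ppo_micro_batch_size', world_size)
# config_normalize_batch_size(cfg, 'ppo_micro_batch_size', world_size)  # would raise
```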