Tiny refactor and cleanup #75

Open
wants to merge 36 commits into base: main
Changes from 1 commit
Commit: more
fzyzcjy committed Jan 2, 2025
commit 45fd26c0fd4d40159c5bb1181a307b9670dcf352
18 changes: 9 additions & 9 deletions verl/workers/fsdp_workers.py
@@ -80,12 +80,12 @@ def __init__(self, config: DictConfig, role: str):

         # normalize config
         if self._is_actor:
-            self.config.actor.ppo_mini_batch_size //= self.device_mesh.shape[0]
-            self.config.actor.ppo_micro_batch_size //= self.device_mesh.shape[0]
+            config_normalize_batch_size(self.config.actor, 'ppo_mini_batch_size', self.device_mesh.shape[0])
+            config_normalize_batch_size(self.config.actor, 'ppo_micro_batch_size', self.device_mesh.shape[0])
         if self._is_rollout:
-            self.config.rollout.log_prob_micro_batch_size //= self.device_mesh.shape[0]
+            config_normalize_batch_size(self.config.rollout, 'log_prob_micro_batch_size', self.device_mesh.shape[0])
         if self._is_ref:
-            self.config.ref.log_prob_micro_batch_size //= self.device_mesh.shape[0]
+            config_normalize_batch_size(self.config.ref, 'log_prob_micro_batch_size', self.device_mesh.shape[0])

     def _build_model_optimizer(self,
                                model_path,
@@ -446,9 +446,9 @@ def __init__(self, config):
         self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload

         # normalize config
-        self.config.ppo_mini_batch_size //= torch.distributed.get_world_size()
-        self.config.ppo_micro_batch_size //= torch.distributed.get_world_size()
-        self.config.forward_micro_batch_size //= torch.distributed.get_world_size()
+        config_normalize_batch_size(self.config, 'ppo_mini_batch_size', torch.distributed.get_world_size())
+        config_normalize_batch_size(self.config, 'ppo_micro_batch_size', torch.distributed.get_world_size())
+        config_normalize_batch_size(self.config, 'forward_micro_batch_size', torch.distributed.get_world_size())

     def _build_critic_model_optimizer(self, config):
         # the following line is necessary
@@ -576,7 +576,7 @@ def compute_values(self, data: DataProto):
             load_fsdp_param_and_grad(module=self.critic_module,
                                      device_id=torch.cuda.current_device(),
                                      load_grad=self._is_offload_grad)
-        micro_batch_size = self.config.forward_micro_batch_size
+        micro_batch_size = self.config.forward_micro_batch_size_normalized
         data.meta_info['micro_batch_size'] = micro_batch_size
         values = self.critic.compute_values(data=data)
         output = DataProto.from_dict(tensors={'values': values})
@@ -651,7 +651,7 @@ def __init__(self, config):
             torch.distributed.init_process_group(backend="nccl")
         self.config = config

-        self.config.micro_batch_size //= torch.distributed.get_world_size()
+        config_normalize_batch_size(self.config, 'micro_batch_size', torch.distributed.get_world_size())

     def _build_model(self, config):
         # the following line is necessary
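
Note for readers of this diff: `config_normalize_batch_size` is imported from `verl.utils.config` (see the megatron_workers.py hunk below), but its body is not part of this commit. Judging from the call sites, which pass `(config_node, key, divisor)`, and from `compute_values` switching to read `forward_micro_batch_size_normalized`, a minimal sketch of what the helper could look like is given here; the exact behavior (e.g. whether the original key is also overwritten) is an assumption, not the author's implementation.

```python
# Hypothetical sketch of verl/utils/config.py::config_normalize_batch_size.
# Not part of this commit's diff; behavior inferred from the call sites above.
from omegaconf import DictConfig, open_dict


def config_normalize_batch_size(config: DictConfig, key: str, divisor: int) -> None:
    """Divide config[key] by `divisor` and store the result under "<key>_normalized".

    Assumption: the raw (global) value is left untouched so it stays visible in
    logs and checkpoints, and callers read the derived field instead, as
    `compute_values` does with `forward_micro_batch_size_normalized` above.
    """
    value = config[key]
    assert value % divisor == 0, f"{key}={value} must be divisible by {divisor}"
    with open_dict(config):  # allow adding a new key even if the node is in struct mode
        config[f"{key}_normalized"] = value // divisor
```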
9 changes: 5 additions & 4 deletions verl/workers/megatron_workers.py
@@ -23,6 +23,7 @@
 import torch.nn as nn
 from omegaconf import DictConfig
 from verl.single_controller.base.megatron.worker import MegatronWorker
+from verl.utils.config import config_normalize_batch_size
 from verl.workers.actor.megatron_actor import MegatronPPOActor
 from verl.workers.critic.megatron_critic import MegatronPPOCritic
 from verl.workers.hybrid_engine import AllGatherPPModel
@@ -111,14 +112,14 @@ def __init__(self, config: DictConfig, role: str):

         # normalize config
         if self._is_actor and self._is_rollout:
-            self.config.actor.ppo_mini_batch_size //= mpu.get_data_parallel_world_size()
-            self.config.actor.ppo_micro_batch_size //= mpu.get_data_parallel_world_size()
-            self.config.rollout.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size()
+            config_normalize_batch_size(self.config.actor, 'ppo_mini_batch_size', mpu.get_data_parallel_world_size())
+            config_normalize_batch_size(self.config.actor, 'ppo_micro_batch_size', mpu.get_data_parallel_world_size())
+            config_normalize_batch_size(self.config.rollout, 'log_prob_micro_batch_size', mpu.get_data_parallel_world_size())
             self._is_offload_param = self.config.actor.get('param_offload', False)
             self._is_offload_grad = self.config.actor.get('grad_offload', False)
             self._is_offload_optimizer = self.config.actor.get('optimizer_offload', False)
         elif self._is_ref:
-            self.config.ref.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size()
+            config_normalize_batch_size(self.config.ref, 'log_prob_micro_batch_size', mpu.get_data_parallel_world_size())
             self._is_offload_param = self.config.ref.get('param_offload', False)

     def _build_model_optimizer(self,
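
For completeness, a toy usage example under the same assumption as the sketch above; only the import path is taken from this diff, the numbers are made up for illustration.

```python
# Toy demo of the assumed helper: the raw value is preserved and a per-rank
# "_normalized" value is added, here pretending there are 8 data-parallel ranks.
from omegaconf import OmegaConf
from verl.utils.config import config_normalize_batch_size

cfg = OmegaConf.create({'micro_batch_size': 64})
config_normalize_batch_size(cfg, 'micro_batch_size', 8)
print(cfg.micro_batch_size)             # 64, raw value assumed to stay intact
print(cfg.micro_batch_size_normalized)  # 8, per-rank value the workers read
```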