Skip to content

Commit

Permalink
Revert "Auxiliary commit to revert individual files from f9f21c4"
Browse files — browse the repository at this point in the history
This reverts commit 1de0f3153db6072e9e5d4698dcf9dbcff000026f, reversing
changes made to c4fa929.
  • Loading branch information
fzyzcjy committed Jan 2, 2025
1 parent f9f21c4 commit 9f03572
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions verl/workers/fsdp_workers.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -446,9 +446,9 @@ def __init__(self, config):
self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload

# normalize config
self.config.ppo_mini_batch_size //= torch.distributed.get_world_size()
self.config.ppo_micro_batch_size //= torch.distributed.get_world_size()
self.config.forward_micro_batch_size //= torch.distributed.get_world_size()
config_normalize_batch_size(self.config, 'ppo_mini_batch_size', torch.distributed.get_world_size())
config_normalize_batch_size(self.config, 'ppo_micro_batch_size', torch.distributed.get_world_size())
config_normalize_batch_size(self.config, 'forward_micro_batch_size', torch.distributed.get_world_size())

def _build_critic_model_optimizer(self, config):
# the following line is necessary
Expand Down Expand Up @@ -576,7 +576,7 @@ def compute_values(self, data: DataProto):
load_fsdp_param_and_grad(module=self.critic_module,
device_id=torch.cuda.current_device(),
load_grad=self._is_offload_grad)
micro_batch_size = self.config.forward_micro_batch_size
micro_batch_size = self.config.forward_micro_batch_size_normalized
data.meta_info['micro_batch_size'] = micro_batch_size
values = self.critic.compute_values(data=data)
output = DataProto.from_dict(tensors={'values': values})
Expand Down

0 comments on commit 9f03572

Please sign in to comment.