validate use_remove_padding when applying sequence parallelism #153

Merged
merged 7 commits into volcengine:main on Jan 31, 2025

Conversation

chujiezheng
Contributor

This check is needed because ulysses_sp is activated only when use_remove_padding is enabled:

if self.use_remove_padding:
    input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
                                               attention_mask)  # input_ids_rmpad (total_nnz, ...)
    input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

    # unpad the position_ids to align the rotary
    position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
                                          indices).transpose(0, 1)

    # for compute the log_prob
    input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)  # (1, total_nnz)

    # pad and slice the inputs if sp > 1
    if self.use_ulysses_sp:
        input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
            input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size)
        input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(input_ids_rmpad_rolled, None,
                                                                    self.ulysses_sequence_parallel_size)

    input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)  # ((total_nnz / sp) + pad)

    # only pass input_ids and position_ids to enable flash_attn_varlen
    output = self.actor_module(input_ids=input_ids_rmpad,
                               attention_mask=None,
                               position_ids=position_ids_rmpad,
                               use_cache=False)  # prevent model thinks we are generating
    logits_rmpad = output.logits.squeeze(0)  # (total_nnz, vocab_size)

    logits_rmpad.div_(temperature)

    # compute entropy
    entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad)  # ((total_nnz / sp) + pad)

    # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen)
    log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)

    # gather log_prob if sp > 1
    if self.use_ulysses_sp:
        # gather and unpad for the ulysses sp
        log_probs = gather_outpus_and_unpad(log_probs, gather_dim=0, unpad_dim=0, padding_size=pad_size)
        entropy_rmpad = gather_outpus_and_unpad(entropy_rmpad,
                                                gather_dim=0,
                                                unpad_dim=0,
                                                padding_size=pad_size)

    # pad back to (bsz, seqlen)
    full_entropy = pad_input(hidden_states=entropy_rmpad.unsqueeze(-1),
                             indices=indices,
                             batch=batch_size,
                             seqlen=seqlen)
    full_log_probs = pad_input(hidden_states=log_probs.unsqueeze(-1),
                               indices=indices,
                               batch=batch_size,
                               seqlen=seqlen)

    # only return response part:
    entropy = full_entropy.squeeze(-1)[:, -response_length - 1:-1]  # (bsz, response_length)
    log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1:-1]  # (bsz, response_length)

Without this check, users may encounter OOM issues when they set sp_size > 1 but use_remove_padding is mistakenly disabled.
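The memory impact is easy to see from the snippet above: the pad-and-slice step under if self.use_ulysses_sp is nested inside if self.use_remove_padding, so with remove padding disabled each rank still processes the full padded sequence and full-size logits, no matter how large sp_size is. A rough back-of-envelope illustration (all sizes below are hypothetical, chosen only to show the order of magnitude):

# Illustrative only: compare per-rank logits memory with and without the
# remove-padding + Ulysses slicing path. All numbers below are made up.
batch, seqlen, vocab, sp_size = 8, 4096, 152064, 4   # hypothetical shapes
bytes_per_elem = 2                                    # bf16 logits

# Padded path: every rank holds logits for the full (batch, seqlen, vocab) tensor.
padded_bytes = batch * seqlen * vocab * bytes_per_elem

# Remove-padding + sp path: only real tokens, sliced across sp ranks.
total_nnz = int(batch * seqlen * 0.5)                 # assume ~50% of tokens are non-pad
rmpad_sp_bytes = (total_nnz // sp_size) * vocab * bytes_per_elem

print(f"padded logits per rank:          {padded_bytes / 2**30:.1f} GiB")
print(f"rmpad + ulysses logits per rank: {rmpad_sp_bytes / 2**30:.1f} GiB")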

@vermouth1992
Collaborator

Could you format the code by running bash script/format.sh?

@chujiezheng
Contributor Author

Done!

@@ -86,4 +86,13 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
     assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0
     assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus
 
+    # Check if use_remove_padding is enabled when using sequence parallelism
+    if config.actor_rollout_ref.actor.ulysses_sequence_parallel_size > 1:

Collaborator
It seems that the correct keys are config.actor_rollout_ref.model.use_remove_padding, critic.model.use_remove_padding, and reward_model.model.use_remove_padding.

Contributor Author
Fixed now!

Collaborator
Unfortunately, the key is still not correct :(

Collaborator
@chujiezheng The key should be actor_rollout_ref.model.use_remove_padding. Separate actor and ref keys are not necessary, as they share the same model type.

Contributor Author
@vermouth1992 @PeterSH6 Fixed now 🥹
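
For readers landing on this thread later: combining the hunk above with the key names settled in the review, the final validation plausibly looks something like the sketch below. This is not the exact merged diff; in particular it assumes critic and reward_model expose ulysses_sequence_parallel_size analogously to the actor, and that the reward-model check only matters when a reward model is enabled.

# Sketch of the intended check (not the exact merged code).
if config.actor_rollout_ref.actor.ulysses_sequence_parallel_size > 1:
    assert config.actor_rollout_ref.model.use_remove_padding, \
        "actor/ref: set actor_rollout_ref.model.use_remove_padding=True to use sequence parallelism"

if config.critic.ulysses_sequence_parallel_size > 1:  # assumed key, analogous to the actor
    assert config.critic.model.use_remove_padding, \
        "critic: set critic.model.use_remove_padding=True to use sequence parallelism"

if config.reward_model.ulysses_sequence_parallel_size > 1:  # assumed key, analogous to the actor
    assert config.reward_model.model.use_remove_padding, \
        "reward_model: set reward_model.model.use_remove_padding=True to use sequence parallelism"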

@vermouth1992 merged commit fb3793a into volcengine:main on Jan 31, 2025
10 checks passed