
Commit ab525bc
[misc]fix: pad dataproto when pad size is larger than len(dataproto) (#150)

- As titled: with a single slice `data[:pad_size]`, the pad falls short whenever `pad_size` exceeds `len(data)`; the pad is now built by cycling through the data until `pad_size` rows are collected (see the sketch below).
- Solves: #149

Waiting for testing from @chujiezheng
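
For reference, a minimal sketch of the failure mode being fixed, using a plain Python list as a hypothetical stand-in for `DataProto` (which supports `len()` and slicing similarly):

```python
# Hypothetical stand-in: a plain Python list in place of DataProto.
data = [1, 2, 3]
size_divisor = 7
pad_size = size_divisor - len(data) % size_divisor  # 4, which exceeds len(data)

# Previous behavior: a single slice cannot yield more rows than the data has.
old_pad = data[:pad_size]            # [1, 2, 3] -- only 3 of the 4 rows needed
assert len(data + old_pad) == 6      # still not divisible by size_divisor
```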

---------

Co-authored-by: Chi Zhang <[email protected]>
PeterSH6 and vermouth1992 authored Jan 28, 2025
1 parent 9fca71d commit ab525bc
Showing 2 changed files with 21 additions and 1 deletion.
14 changes: 14 additions & 0 deletions tests/utility/test_tensor_dict_utilities.py
@@ -206,6 +206,20 @@ def test_dataproto_pad_unpad():
     assert (unpadd_data.non_tensor_batch['labels'] == labels).all()
     assert unpadd_data.meta_info == {'info': 'test_info'}
 
+    padded_data, pad_size = pad_dataproto_to_divisor(data, size_divisor=7)
+    assert pad_size == 4
+
+    expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6], [1, 2]])
+    expected_labels = ['a', 'b', 'c', 'a', 'b', 'c', 'a']
+    assert torch.all(torch.eq(padded_data.batch['obs'], expected_obs))
+    assert (padded_data.non_tensor_batch['labels'] == expected_labels).all()
+    assert padded_data.meta_info == {'info': 'test_info'}
+
+    unpadd_data = unpad_dataproto(padded_data, pad_size=pad_size)
+    assert torch.all(torch.eq(unpadd_data.batch['obs'], obs))
+    assert (unpadd_data.non_tensor_batch['labels'] == labels).all()
+    assert unpadd_data.meta_info == {'info': 'test_info'}
 
 
 def test_dataproto_fold_unfold():
     from verl.protocol import fold_batch_dim, unfold_batch_dim, DataProto
8 changes: 7 additions & 1 deletion verl/protocol.py
@@ -51,7 +51,13 @@ def pad_dataproto_to_divisor(data: 'DataProto', size_divisor: int):
     assert isinstance(data, DataProto), 'data must be a DataProto'
     if len(data) % size_divisor != 0:
         pad_size = size_divisor - len(data) % size_divisor
-        data_padded = DataProto.concat([data, data[:pad_size]])
+        padding_protos = []
+        remaining_pad = pad_size
+        while remaining_pad > 0:
+            take_size = min(remaining_pad, len(data))
+            padding_protos.append(data[:take_size])
+            remaining_pad -= take_size
+        data_padded = DataProto.concat([data] + padding_protos)
     else:
         pad_size = 0
         data_padded = data
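
For illustration, a self-contained sketch of the repeat-padding loop above, using a plain Python list as a hypothetical stand-in for `DataProto` (which supports `len()`, slicing, and `DataProto.concat` analogously); the function name is invented for this example:

```python
# Minimal sketch of the fixed padding logic, on a plain Python list.
def pad_to_divisor(data, size_divisor):
    if len(data) % size_divisor == 0:
        return data, 0
    pad_size = size_divisor - len(data) % size_divisor
    padding = []
    remaining_pad = pad_size
    while remaining_pad > 0:
        # Cycle through the data until the pad is filled, so that
        # pad_size > len(data) is handled correctly.
        take_size = min(remaining_pad, len(data))
        padding.extend(data[:take_size])
        remaining_pad -= take_size
    return data + padding, pad_size

# Mirrors the updated test: 3 items padded up to a multiple of 7.
padded, pad_size = pad_to_divisor([1, 2, 3], size_divisor=7)
assert pad_size == 4
assert padded == [1, 2, 3, 1, 2, 3, 1]
```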
