
[CB] add min batch size of 2 in decode #182


Merged — 8 commits merged on Jun 5, 2025

Conversation

nikolaospapandreou (Collaborator)

This PR adds support for a minimum batch size of 2 in decode steps.

Signed-off-by: Nikolaos Papandreou <[email protected]>

👋 Hi! Thank you for contributing to vLLM support on Spyre.
Just a reminder: Make sure that your code passes all the linting checks, otherwise your PR won't be able to be merged. To do so, first install the linting requirements, then run format.sh and commit the changes. This can be done with uv directly:

uv sync --frozen --group lint --active --inexact

Or this can be done with pip:

uv pip compile --group lint > requirements-lint.txt
pip install -r requirements-lint.txt
bash format.sh

Now you are good to go 🚀

@sducouedic (Collaborator) left a comment

LGTM (just a minor comment)
nice changes 👍

# free the blocks used for padding to minimum decode batch size of 2
if self.dummy_req_ids2blocks:
    for freed_block in self.dummy_req_ids2blocks:
        self.free_blocks.appendleft(freed_block)

I think appendleft won't do exactly what you want: it appends the values from dummy_req_ids2blocks to the left, but in reverse order, e.g. [6, 7] --> appendleft --> [7, 6, 8, 9, 10, ...]. I would suggest simply using append, just as you did on line 620, or iterating dummy_req_ids2blocks in reverse order (and doing the same on line 620).
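A quick standalone sketch of the reversal described above, using illustrative block IDs (not values from the PR):

```python
from collections import deque

dummy_blocks = [6, 7]

# appendleft inserts one element at a time at the head,
# so the batch lands at the front in reverse order:
left = deque([8, 9, 10])
for b in dummy_blocks:
    left.appendleft(b)
# left is now deque([7, 6, 8, 9, 10])

# append keeps the original order, placing the batch at the tail:
right = deque([8, 9, 10])
for b in dummy_blocks:
    right.append(b)
# right is now deque([8, 9, 10, 6, 7])
```

For a free-block pool the ordering usually doesn't matter for correctness, only for which block IDs get reused first, which is presumably why either fix works.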

d = self.BLOCK_SIZE
num_blocks = (n + d - 1) // d
for i in range(num_blocks - len(self.dummy_req_ids2blocks)):
    self.dummy_req_ids2blocks.append(self.free_blocks.popleft())

Is this popping free blocks on every decode iteration of a batch with a single request? It looks like len(input_tokens) == 1 will always be true when only a single request is running, since cached_requests comes from the scheduler and won't include our dummy request


ahhh I see — num_blocks - len(self.dummy_req_ids2blocks) only pops new blocks as the tkv increases 👍
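To illustrate why the loop pops a block only when tkv crosses a block boundary, here is a self-contained sketch using ceiling division, with an assumed BLOCK_SIZE of 64 (the real value comes from the model config):

```python
from collections import deque

BLOCK_SIZE = 64  # assumed for illustration only


def blocks_needed(tkv: int) -> int:
    # ceiling division: blocks required to cover tkv + 1 token positions
    n = tkv + 1
    return (n + BLOCK_SIZE - 1) // BLOCK_SIZE


free_blocks = deque(range(100))
dummy_blocks: list[int] = []
pops_per_step = []
for tkv in range(130):
    popped = 0
    # mirrors: range(num_blocks - len(self.dummy_req_ids2blocks))
    for _ in range(blocks_needed(tkv) - len(dummy_blocks)):
        dummy_blocks.append(free_blocks.popleft())
        popped += 1
    pops_per_step.append(popped)
# a block is popped only at tkv = 0, 64, 128; every other step pops nothing
```

Between boundaries, blocks_needed(tkv) equals len(dummy_blocks), so the range is empty and no free blocks are consumed.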

@yannicks1 yannicks1 self-requested a review June 3, 2025 10:01
@@ -583,6 +583,7 @@ def __init__(
     self.req_ids2left_pads: dict[str, int] = {}
     self.tkv = 0
     self.free_blocks = deque([i for i in range(NUM_BLOCKS)])
+    self.dummy_req_ids2blocks: list[int] = []

Couldn't we simply use self.req_ids2blocks for this?
Just add an entry self.req_ids2blocks['dummy_req'] to it when needed and clean it up again afterwards. This would not only eliminate this additional variable, but potentially you could also reuse some of the code for freeing up the blocks again.
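A minimal sketch of that suggestion; the 'dummy_req' sentinel key and surrounding names are hypothetical stand-ins, not taken from the actual code:

```python
from collections import deque

DUMMY_REQ_ID = "dummy_req"  # hypothetical sentinel key

free_blocks = deque(range(10))
req_ids2blocks: dict[str, list[int]] = {}

# allocate a padding block under the ordinary per-request bookkeeping
req_ids2blocks.setdefault(DUMMY_REQ_ID, []).append(free_blocks.popleft())

# cleanup can then reuse the same path that frees real requests
for block in req_ids2blocks.pop(DUMMY_REQ_ID, []):
    free_blocks.append(block)
```

Using dict.pop with a default makes the cleanup a no-op when no padding was allocated, which would replace the explicit `if self.dummy_req_ids2blocks:` guard.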

Comment on lines 630 to 635
# free the blocks used for padding to minimum decode batch size of 2
if self.dummy_req_ids2blocks:
    for freed_block in self.dummy_req_ids2blocks:
        self.free_blocks.appendleft(freed_block)
    self.dummy_req_ids2blocks = []


Can you explain why we need this cleanup here as well? self._update_states() is always called before self.prepare_model_input() (which invokes _prepare_prompt()). Would doing the cleanup in just one place be possible?


I guess this would do the cleanup when the last requests finishes?

Comment on lines +789 to +791
input_tokens.append([0])
input_positions.append([self.tkv])
left_padded_prompt_mask.append(0)

Not sure if these have to make sense (align in some way) for the fms/AIU Spyre compiler, but to be on the safe side we could simply re-use the values from the real request here.

@sducouedic (Collaborator) left a comment

I am not completely sure if this won't automatically be handled correctly in the subsequent _prepare_decode, but probably in reduce_left_padding we should also free the dummy blocks

n = self.tkv + 1
d = self.BLOCK_SIZE
num_blocks = (n + d - 1) // d
for i in range(num_blocks - len(self.dummy_req_ids2blocks)):

Suggested change
-for i in range(num_blocks - len(self.dummy_req_ids2blocks)):
+for _ in range(num_blocks - len(self.dummy_req_ids2blocks)):

@yannicks1 (Collaborator)

I am not completely sure if this won't automatically be handled correctly in the subsequent _prepare_decode, but probably in reduce_left_padding we should also free the dummy blocks

good point! This will be solved when the left padding is removed in every step (decode and prefill). I am taking care of this on this branch: https://github.com/vllm-project/vllm-spyre/tree/ysc-refactor-left-padding-stripping

@yannicks1 yannicks1 enabled auto-merge (squash) June 5, 2025 17:17
@github-actions github-actions bot added the ready label Jun 5, 2025
@yannicks1 yannicks1 disabled auto-merge June 5, 2025 17:40
@yannicks1 yannicks1 enabled auto-merge (squash) June 5, 2025 17:47
@yannicks1 yannicks1 merged commit 0d42959 into main Jun 5, 2025
21 checks passed
@yannicks1 yannicks1 deleted the npo-min-decode-batch-size-2 branch June 5, 2025 19:55
4 participants