[CB] add min batch size of 2 in decode #182
Conversation
LGTM (just a minor comment)
nice changes 👍
```python
# free the blocks used for padding to minimum decode batch size of 2
if self.dummy_req_ids2blocks:
    for freed_block in self.dummy_req_ids2blocks:
        self.free_blocks.appendleft(freed_block)
```
I think `appendleft` won't do exactly what you want: it will append the values from `dummy_req_ids2blocks` to the left, but in reverse order, e.g. `[6, 7]` --> appendleft --> `[7, 6, 8, 9, 10, ...]`. I would suggest simply using `append`, as you did in line 620, or iterating `dummy_req_ids2blocks` in reverse order (and doing the same in line 620).
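For illustration, a minimal standalone sketch of the `deque` behavior being described (block numbers are hypothetical):

```python
from collections import deque

free_blocks = deque([8, 9, 10])
dummy_blocks = [6, 7]

# appendleft pushes each block to the front, so the freed blocks
# end up in reverse order relative to dummy_blocks:
for b in dummy_blocks:
    free_blocks.appendleft(b)
print(free_blocks)  # deque([7, 6, 8, 9, 10])

# append keeps them in their original order at the tail instead:
free_blocks = deque([8, 9, 10])
for b in dummy_blocks:
    free_blocks.append(b)
print(free_blocks)  # deque([8, 9, 10, 6, 7])
```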
```python
d = self.BLOCK_SIZE
num_blocks = (n + d - 1) // d
for i in range(num_blocks - len(self.dummy_req_ids2blocks)):
    self.dummy_req_ids2blocks.append(self.free_blocks.popleft())
```
Is this popping free blocks for every decode iteration of a batch with a single request? It looks like `len(input_tokens) == 1` will always be true when there's only a single request running, since `cached_requests` comes from the scheduler and won't have our dummy request in it.
Ahhh, I see: `num_blocks - len(self.dummy_req_ids2blocks)` only pops new blocks as the tkv increases 👍
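A small standalone sketch of that behavior (the `BLOCK_SIZE` and block-pool values are assumed for illustration, not the actual vllm-spyre constants): only one new block is popped per `BLOCK_SIZE` tokens of tkv growth, not one per decode step.

```python
from collections import deque

BLOCK_SIZE = 64                      # assumed value for illustration
free_blocks = deque(range(32))
dummy_blocks: list[int] = []

for tkv in range(1, 200):            # simulate decode steps
    n = tkv + 1
    num_blocks = (n + BLOCK_SIZE - 1) // BLOCK_SIZE  # ceil(n / BLOCK_SIZE)
    # only the difference to the blocks already held is popped
    for _ in range(num_blocks - len(dummy_blocks)):
        dummy_blocks.append(free_blocks.popleft())

print(dummy_blocks)  # [0, 1, 2, 3]: one block per BLOCK_SIZE tokens
```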
```diff
@@ -583,6 +583,7 @@ def __init__(
         self.req_ids2left_pads: dict[str, int] = {}
         self.tkv = 0
         self.free_blocks = deque([i for i in range(NUM_BLOCKS)])
+        self.dummy_req_ids2blocks: list[int] = []
```
Couldn't we simply use `self.req_ids2blocks` for this? Just add an entry `self.req_ids2blocks['dummy_req']` when needed and clean it up again afterwards. This would not only eliminate the additional variable, but potentially you could also reuse some of the code for freeing up the blocks.
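A minimal sketch of that suggestion (the class, method names, and the `"dummy_req"` sentinel key are hypothetical, not the actual vllm-spyre code):

```python
from collections import deque

class BlockTableSketch:
    DUMMY_REQ_ID = "dummy_req"  # hypothetical sentinel key

    def __init__(self, num_blocks: int):
        self.free_blocks = deque(range(num_blocks))
        self.req_ids2blocks: dict[str, list[int]] = {}

    def reserve_dummy_blocks(self, count: int) -> None:
        # padding blocks live under the reserved key in the existing mapping
        blocks = self.req_ids2blocks.setdefault(self.DUMMY_REQ_ID, [])
        for _ in range(count):
            blocks.append(self.free_blocks.popleft())

    def free_request(self, req_id: str) -> None:
        # the same freeing path then works for real and dummy requests
        for block in self.req_ids2blocks.pop(req_id, []):
            self.free_blocks.append(block)
```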
```python
# free the blocks used for padding to minimum decode batch size of 2
if self.dummy_req_ids2blocks:
    for freed_block in self.dummy_req_ids2blocks:
        self.free_blocks.appendleft(freed_block)
    self.dummy_req_ids2blocks = []
```
Can you explain why we need this cleanup here as well? `self._update_states()` is always called before `self.prepare_model_input()` (which invokes `_prepare_prompt()`). Would it be possible to do the cleanup in only one place?
I guess this would do the cleanup when the last request finishes?
```python
input_tokens.append([0])
input_positions.append([self.tkv])
left_padded_prompt_mask.append(0)
```
Not sure whether these have to make sense (align in some way) for the fms/AIU Spyre compiler, but to be on the safe side, we could simply re-use the values from the real request here.
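For illustration, a minimal sketch of that alternative (the concrete values are hypothetical): pad the decode batch to size 2 by mirroring the real request's entries instead of hard-coding zeros.

```python
# inputs for the single real request (hypothetical values)
input_tokens = [[17]]
input_positions = [[42]]
left_padded_prompt_mask = [3]

# pad the batch to the minimum decode batch size of 2 by duplicating
# the real request's entries for the dummy slot
input_tokens.append(list(input_tokens[0]))
input_positions.append(list(input_positions[0]))
left_padded_prompt_mask.append(left_padded_prompt_mask[0])
```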
I am not completely sure if this won't automatically be handled correctly in the subsequent `_prepare_decode`, but probably in `reduce_left_padding` we should also free the dummy blocks.
```python
n = self.tkv + 1
d = self.BLOCK_SIZE
num_blocks = (n + d - 1) // d
for i in range(num_blocks - len(self.dummy_req_ids2blocks)):
```
```diff
-for i in range(num_blocks - len(self.dummy_req_ids2blocks)):
+for _ in range(num_blocks - len(self.dummy_req_ids2blocks)):
```
Good point! This will be solved when doing the removal of the left padding in every step (decode and prefill). I am taking care of this on this branch: https://github.com/vllm-project/vllm-spyre/tree/ysc-refactor-left-padding-stripping
This PR adds support for a minimum batch size of 2 in decode steps.