feat(mlx_lm)!: batch_generate #948
base: main
Kind of interesting: for quantized models, the throughput doesn't go up much between small batch sizes (bs=1,2,3,4), but then it starts to go up a lot at higher bs, which is the opposite of what I expected intuitively. For unquantized models the throughput does go up between small bs. I observe the same on @willccbb's original repo.
The `prompt` argument can now be either a `str` or `list[str]`. The change to `generate()` is backwards-compatible. The changes to `generate_step()`, `top_p_sampling()`, and `min_p_sampling()` are backwards-incompatible in order to unify shapes; this could be changed by adding a few if-statements, if preferred.
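As a rough illustration of the `str` or `list[str]` handling described above, here is a minimal sketch in plain Python. The helper names `as_batch` and `left_pad` and the `pad_id` default are hypothetical and are not taken from this PR; the sketch only assumes that prompts of different lengths are left-padded to a common length, as discussed later in this thread.

```python
from typing import List, Tuple, Union


def as_batch(prompt: Union[str, List[str]]) -> Tuple[List[str], bool]:
    """Normalize `prompt` to a list of strings.

    Also returns whether the caller passed a single string, so the
    result can be unwrapped again for backwards compatibility.
    """
    if isinstance(prompt, str):
        return [prompt], True
    return list(prompt), False


def left_pad(token_ids: List[List[int]], pad_id: int = 0) -> List[List[int]]:
    """Left-pad every token sequence to the length of the longest one."""
    max_len = max(len(ids) for ids in token_ids)
    return [[pad_id] * (max_len - len(ids)) + ids for ids in token_ids]
```

With this shape, a single-string call can go through the same batched path with `bs == 1` and have its result unwrapped at the end.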
I think it makes sense to minimize the complexity to the [...]

Also, maybe more tricky is the fact that, for this to be correct, I think the causal masks need to consider the left padding in the input (please correct me if I'm wrong about that). This has two implications:

Let me know what you think about the above.
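To make the left-padding concern above concrete, here is a minimal sketch of a causal mask that also hides padded key positions. The function name and the boolean nested-list layout are assumptions for illustration only; this is not the mask construction used in the PR or in MLX.

```python
from typing import List


def left_padded_causal_mask(lengths: List[int], seq_len: int) -> List[List[List[bool]]]:
    """Build one (seq_len, seq_len) boolean mask per sequence in the batch.

    mask[b][i][j] is True when query position i may attend to key position j.
    Each sequence b has `lengths[b]` real tokens, right-aligned in `seq_len`
    slots, so the first `seq_len - lengths[b]` slots are padding.
    """
    masks = []
    for n in lengths:
        pad = seq_len - n
        masks.append(
            [
                # causal (j <= i) and not a padding key (j >= pad)
                [(j <= i) and (j >= pad) for j in range(seq_len)]
                for i in range(seq_len)
            ]
        )
    return masks
```

Note that query rows belonging to padding positions end up fully masked here. Their outputs are meant to be ignored, but a real implementation would need to make sure a softmax over an all-masked row does not produce NaNs.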
Makes sense to me, will implement.
Yes, this sounds straightforward enough.
I'll think a bit about whether there's an easy way to handle this; otherwise I'll remove that parameter in [...]

Will update when these changes are ready!
@llllvvuu are you coming back to this?
Hey @awni, sorry for the delay, I'd been job hunting this month. I should be able to get back to this in ~a week.
No worries, just checking. I'll follow up in a week or so.
Just realised the attention mask has been mentioned in this PR, which is the reason I raised issue #1044.
TODO: Re-implement `batch_generate`
TODO: Update all `generate_step` callsites
NOTE: `generate_step` taking `(bs, seq_len)` instead of `(seq_len,)` is a breaking change. In particular, `sampler` and `logits_processors` will need to handle logits of shape `(bs, vocab_size)` instead of `(vocab_size,)`.
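To show what the shape change in the note above means for a custom `sampler`, here is a minimal sketch using nested Python lists in place of arrays. Both function names are hypothetical, and the real callables in `mlx_lm` operate on MLX arrays rather than lists, so treat this only as an illustration of the `(bs, vocab_size)` contract.

```python
from typing import Callable, List, Sequence


def greedy_sampler(logits: Sequence[Sequence[float]]) -> List[int]:
    """Batched greedy sampling: logits of shape (bs, vocab_size) -> bs token ids."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


def lift_to_batch(
    single_sampler: Callable[[Sequence[float]], int]
) -> Callable[[Sequence[Sequence[float]]], List[int]]:
    """Wrap a sampler written for (vocab_size,) logits so it accepts (bs, vocab_size)."""

    def batched(logits: Sequence[Sequence[float]]) -> List[int]:
        # Apply the old single-sequence sampler to each row independently.
        return [single_sampler(row) for row in logits]

    return batched
```

A wrapper like `lift_to_batch` is one way existing single-sequence samplers could be adapted, at the cost of a per-row Python loop.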
Sorry for the delay, @awni. I took advantage of #1173 to update this PR. It is pending a versioned release of ml-explore/mlx#1726 for the mask dtype.

I noticed one other potential issue: for absolute/rotary positional encodings, the position IDs of padded prompts won't start from 0 (this becomes trickier if a padded prompt cache is added, as then the position IDs should become non-contiguous, IIUC). I'm not sure what the priority of this is or if it requires any change to [...]
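The position-ID issue can be illustrated with a small sketch. It assumes left-padded prompts and computes, per row, position IDs that start at 0 on the first real token. The function name is hypothetical, padding slots simply receive 0 as a placeholder, and this is only one possible approach, not necessarily what the PR ends up doing.

```python
from typing import List


def position_ids(lengths: List[int], seq_len: int) -> List[List[int]]:
    """Per-row position IDs for left-padded sequences.

    A row with `n` real tokens has `seq_len - n` padding slots on the left.
    Real tokens get positions 0..n-1; padding slots get 0 as a placeholder.
    """
    ids = []
    for n in lengths:
        pad = seq_len - n
        ids.append([max(0, i - pad) for i in range(seq_len)])
    return ids
```

Without such an offset, every row would use positions `0..seq_len-1`, so the first real token of a padded prompt would sit at position `pad` instead of 0.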
Any update? We do need this to support parallel generation.
Will get to this soon. Sorry for the delay.
This is based on @willccbb's implementation at https://github.com/willccbb/mlx_parallm.
BREAKING CHANGE: `generate_step` takes `(bs, seq_len)` instead of `(seq_len,)`. In particular, `sampler` and `logits_processors` will need to handle logits of shape `(bs, vocab_size)` instead of `(vocab_size,)`.