Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(mlx_lm)!: batch_generate #948

Open
wants to merge 13 commits into
base: main
Choose a base branch
from

Conversation

llllvvuu
Copy link
Contributor

@llllvvuu llllvvuu commented Aug 21, 2024

This is based on @willccbb's implementation at https://github.com/willccbb/mlx_parallm.

BREAKING CHANGE: generate_step takes (bs, seq_len) instead of (seq_len,). In particular, sampler and logits_processors will need to handle logits of shape (bs, vocab_size) instead of (vocab_size,).

@llllvvuu llllvvuu changed the title feat: support batch input in generate() feat(mlx_lm): support batch input in generate() Aug 21, 2024
@llllvvuu llllvvuu force-pushed the feat/batch_generate branch from 7332759 to 332a713 Compare August 21, 2024 05:20
@llllvvuu llllvvuu marked this pull request as draft August 21, 2024 05:22
@llllvvuu llllvvuu force-pushed the feat/batch_generate branch from 332a713 to 12c6066 Compare August 21, 2024 05:25
@llllvvuu llllvvuu marked this pull request as ready for review August 21, 2024 05:25
@llllvvuu llllvvuu force-pushed the feat/batch_generate branch from 12c6066 to ef92993 Compare August 21, 2024 05:45
@llllvvuu
Copy link
Contributor Author

llllvvuu commented Aug 26, 2024

Kind of interesting: for quantized models, the throughput is doesn't go up a lot between small bs (bs=1,2,3,4), but then it starts to go up a lot at higher bs, which is the opposite of what I expected intuitively. For unquantized models the throughput does goes up between small bs. I observe the same on @willccbb's original repo.

The `prompt` argument can now be either a `str` or `list[str]`.

The change to `generate()` is backwards-compatible.

The changes to `generate_step()`, `top_p_sampling()`, and
`min_p_sampling()` are backwards-incompatible in order to unify shapes;
this could be changed by adding a few if-statements, if preferred.
@llllvvuu llllvvuu force-pushed the feat/batch_generate branch from 5105b31 to 2caa832 Compare August 29, 2024 12:15
llms/mlx_lm/utils.py Outdated Show resolved Hide resolved
@awni
Copy link
Member

awni commented Aug 29, 2024

I think it makes sense to minimize the complexity to the generate function (which is becoming a bit spaghetti) to split out the batched generation into a separate function called batch_generate. I would simplify that function to have fewer arguments (like no formatter, no printing during generation, verbose only prints the timings (e.g. as you have it now).

Also maybe more tricky is the fact that I think for this to be correct, the causal masks need to consider the left padding in the input (please correct me if I'm wrong about that). This has two implications:

  1. Probably we'd need to add a mask parameter to the model __call__ functions and provide an appropriately constructed mask for the batch case.
  2. The Rotating KV cache will be broken in this case (it keeps the initial tokens which would be the padded tokens) and when rotates the mask would need to be updated to consider the padding (which is a bit complicated/tedious). In this case I may suggest disabling this option entirely..

Let me know what you think about the above.

@llllvvuu
Copy link
Contributor Author

I think it makes sense to minimize the complexity to the generate function (which is becoming a bit spaghetti) to split out the batched generation into a separate function called batch_generate. I would simplify that function to have fewer arguments (like no formatter, no printing during generation, verbose only prints the timings (e.g. as you have it now).

Makes sense to me, will implement.

Also maybe more tricky is the fact that I think for this to be correct, the causal masks need to consider the left padding in the input (please correct me if I'm wrong about that). This has two implications:

  1. Probably we'd need to add a mask parameter to the model __call__ functions and provide an appropriately constructed mask for the batch case.

Yes, this sounds straightforward enough.

  1. The Rotating KV cache will be broken in this case (it keeps the initial tokens which would be the padded tokens) and when rotates the mask would need to be updated to consider the padding (which is a bit complicated/tedious). In this case I may suggest disabling this option entirely..

I'll do a bit of thinking if there's an easy way to handle this, otherwise I'll remove that parameter in batch_generate.

Will update when these changes are ready!

@awni
Copy link
Member

awni commented Sep 27, 2024

@llllvvuu are you coming back to this?

@llllvvuu
Copy link
Contributor Author

hey @awni , sorry for the delay, I'd been job hunting this month. I should be able to get back to this in ~a week

@awni
Copy link
Member

awni commented Sep 28, 2024

No worries, just checking. I'll follow up in a week or so.

@llllvvuu llllvvuu force-pushed the feat/batch_generate branch from bea0c4b to 8fb82fe Compare October 9, 2024 19:13
@nath1295
Copy link

Just realised the attention mask has been mentioned in this PR, which is the reason I raised this issue #1044

TODO: Re-implement `batch_generate`
TODO: Update all `generate_step` callsites

NOTE: `generate_step` taking `(bs, seq_len)` instead of `(seq_len,)` is
a breaking change. In particular, `sampler` and `logits_processors` will
need to handle logits of shape `(bs, vocab_size)` instead of `(vocab_size,)`.
@llllvvuu llllvvuu marked this pull request as draft December 27, 2024 09:17
@llllvvuu llllvvuu changed the title feat(mlx_lm): support batch input in generate() feat(mlx_lm)!: batch_generate Dec 27, 2024
@llllvvuu llllvvuu marked this pull request as ready for review December 27, 2024 23:53
@llllvvuu
Copy link
Contributor Author

llllvvuu commented Dec 27, 2024

Sorry for the delay @awni . I took advantage of #1173 to update this PR. It is pending versioned release of ml-explore/mlx#1726 for the mask dtype.

I noticed one other potential issue: For absolute/rotary positional encodings, the position IDs of padded prompts won't start from 0 (this becomes more tricky if a padded prompt cache is added as then the position IDs should become non-contiguous IIUC). I'm not sure what the priority of this is or if it requires any change to mx.fast.rope.

@qinxuye
Copy link

qinxuye commented Jan 13, 2025

Any update? we do need this to support parallel generation.

@awni
Copy link
Member

awni commented Jan 13, 2025

Will get to this soon. Sorry for the delay.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants