-
Notifications
You must be signed in to change notification settings - Fork 523
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix GuideLogitsProcessor for MPS device #1306
base: main
Are you sure you want to change the base?
Conversation
I think we'd want to see some benchmarks + tests on this one for sure -- I'm worried that this might have some unanticipated consequences. I'm curious if it's also possible to specify this by a keyword argument somewhere. |
As I cant't do a benchmark comparison on my MPS machine, would you mind doing that? Which tests in addition to the existing ones do you imagine?
Do you suggest here to run different code depending on the device in use or by a user setting? |
I believe this was introduced before MLX was integrated. Good find. Could you see what this does to the benchmarks on your Apple Silicon machine? |
This also fixes #1316 |
@lapp0 I obviously can't compare blocking vs non_blocking but with the fix, this is what I get:
I don't know what to make of the overall significant difference between patterns and especially the huge slowdown for torch_mps for pattern Z*. The latter could be related to MPS performing worse if the number of dims is small (see pytorch/pytorch#77799). |
Another idea: Wouldn't it be even better to remove the necessity of device synchronization by creating the with torch.device(mask.device):
for i, guide_state in enumerate(sequence_states):
allowed_tokens = self.guide.get_next_instruction(guide_state).tokens
allowed_tokens_batch.append(allowed_tokens)
batch_indices.append(
torch.full_like(allowed_tokens, i)
) # Store batch index for each allowed token This seems to be the recommended approach. |
The following seems to perform even better. Concat the indexing tensors in cpu and move only these to target device: allowed_tokens = self.guide.get_next_instruction(guide_state).tokens
....
allowed_tokens_concat = torch.cat(allowed_tokens_batch).to(logits.device)
batch_indices_concat = torch.cat(batch_indices).to(logits.device)
This also removes the need to use non-blocking syncs. Will push that. |
While debugging #1282, I found the the issue to be caused by
non_blocking=True
on a Mac with an MPS device.The usage of
non_blocking=True
is only safe for CPU->GPU, as far as I understand from this guide.For other directions, especially CPU->MPS in my case, this results in a all-zero vector instead of the real token ids which generates wrong tokens and results in errors like described in #1282.
I also tried to use
torch.mps.synchronize()
, but it doesn't help.I'm haven't benchmarked the difference in speed, but I suspect it to be neglectable because the created vector is accessed directly afterwards.
Fixes #1282