[ISSUE] The Pull Request at https://github.com/FasterDecoding/Medusa/pull/97 from Narsil/medusa2 needs to be rolled back. #112

Open
super-ahn opened this issue Jul 11, 2024 · 0 comments

super-ahn commented Jul 11, 2024

Hello.

After fine-tuning a Medusa head, I discovered an issue affecting inference performance and would like to share my findings. When the head is trained correctly, serving the model with TGI should reduce inference latency.

I followed the guide at https://github.com/huggingface/text-generation-inference/blob/main/docs/source/basic_tutorials/train_medusa.md for training. I used the same dataset as the guide, specifically ShareGPT_V4.3_unfiltered_cleaned_split.json, for self-distillation and then trained the Medusa head using the resulting dataset.

However, when I served the trained Medusa head with TGI (v2.0.4), I did not observe any reduction in inference latency compared to the original model. When I then examined the Mistral-7B-Instruct-v0.2-medusa model published on Hugging Face at https://huggingface.co/text-generation-inference/, I noticed that the size of its medusa_lm_head.safetensors file and the contents of its config.json differ from those produced by my training run.
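For anyone who wants to reproduce the comparison, something like the following (file paths are placeholders) can diff the tensor layouts of the two medusa_lm_head.safetensors files; tensor names or shapes that differ point at an architecture mismatch:

```python
# Sketch for comparing two medusa_lm_head.safetensors files, e.g. a locally
# trained head vs. the checkpoint published under huggingface.co/text-generation-inference.
from safetensors import safe_open

def dump_tensor_shapes(path: str) -> dict:
    """Return a mapping of tensor name -> shape for a safetensors file."""
    shapes = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for name in f.keys():
            shapes[name] = tuple(f.get_tensor(name).shape)
    return shapes

local = dump_tensor_shapes("my-medusa/medusa_lm_head.safetensors")       # placeholder path
reference = dump_tensor_shapes("ref-medusa/medusa_lm_head.safetensors")  # placeholder path

for name in sorted(set(local) | set(reference)):
    if local.get(name) != reference.get(name):
        print(name, local.get(name), reference.get(name))
```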

Upon reviewing the Medusa code, I found that the architecture was changed when the Medusa head was updated to v2 in pull request #97 by @Narsil.

When I reverted the architecture to its original form, which uses an additional linear layer per head, and re-trained it on the same dataset, serving it through TGI produced the expected latency reduction.
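For clarity, here is a rough sketch of the head layout I reverted to: each head is a residual block (linear + SiLU with a skip connection) followed by an additional linear projection to the vocabulary. The dimensions and head count below are illustrative, not copied from the repository:

```python
# Rough sketch of the original (pre-#97) Medusa head layout I reverted to.
# hidden_size, vocab_size and num_heads are illustrative values only.
import torch
import torch.nn as nn

class ResBlock(nn.Module):
    """Linear + SiLU with a residual (skip) connection over the hidden state."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.act(self.linear(x))

class MedusaHeads(nn.Module):
    """One residual block plus an extra linear projection to the vocabulary per head."""
    def __init__(self, hidden_size: int = 4096, vocab_size: int = 32000, num_heads: int = 4):
        super().__init__()
        self.heads = nn.ModuleList(
            nn.Sequential(ResBlock(hidden_size),
                          nn.Linear(hidden_size, vocab_size, bias=False))
            for _ in range(num_heads)
        )

    def forward(self, hidden_states: torch.Tensor) -> list[torch.Tensor]:
        # Returns one logits tensor per speculative head.
        return [head(hidden_states) for head in self.heads]
```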

I cannot say exactly why latency does not improve with the v2 head, since TGI does not expose any metrics on speculative decoding efficiency, but the current Medusa head v2 architecture clearly has an inference performance problem.

The change should be rolled back on the main branch.

@super-ahn super-ahn reopened this Jul 11, 2024
@super-ahn super-ahn changed the title [ISSUE] The Pull Request at https://github.com/FasterDecoding/Medusa/pull/97 from Narsil/medusa2 should be rolled back. [ISSUE] The Pull Request at https://github.com/FasterDecoding/Medusa/pull/97 from Narsil/medusa2 needs to be rolled back. Jul 11, 2024