Comparison with SWA in Mistral #24

Open
casper-hansen opened this issue Oct 6, 2023 · 12 comments
@casper-hansen

Hi @Guangxuan-Xiao, do you have any comparison with sliding window attention from Mistral? The paper only describes SWA with re-computation, which is not how it works in the new models.

Sliding Window with Re-computation rebuilds the KV states from the L recent tokens for each new token.

Basically, this is not what they do in the Mistral model. They do not rebuild the KV states; they evict the oldest entries from the cache in favor of the newest ones.
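
For illustration, a minimal sketch of that rolling eviction (PyTorch, hypothetical names and shapes, not Mistral's actual code): the new token's K/V is appended and the oldest entries are dropped once the cache exceeds the window, with no re-computation of past states.

```python
import torch

WINDOW = 4096  # Mistral's sliding_window config value

def evict_oldest(k_cache, v_cache, k_new, v_new, window=WINDOW):
    """Append the new token's K/V and keep only the most recent `window`
    entries; past states are reused as-is, never rebuilt."""
    k_cache = torch.cat([k_cache, k_new], dim=1)[:, -window:]
    v_cache = torch.cat([v_cache, v_new], dim=1)[:, -window:]
    return k_cache, v_cache

# Toy usage: one decoding step with a full cache of shape (heads, seq, dim).
heads, dim = 8, 64
k_cache, v_cache = torch.randn(heads, WINDOW, dim), torch.randn(heads, WINDOW, dim)
k_new, v_new = torch.randn(heads, 1, dim), torch.randn(heads, 1, dim)
k_cache, v_cache = evict_oldest(k_cache, v_cache, k_new, v_new)
assert k_cache.shape[1] == WINDOW  # the oldest entry was evicted
```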

@Guangxuan-Xiao
Collaborator

Hi, please check my explanation at #33 (comment), and let me know if you have any further questions!

@verlocks

Hi @Guangxuan-Xiao, thanks for your explanation! However, it seems you didn't mention SWA in the Mistral model. Mistral uses Sliding Window Attention at inference time and, as far as I can tell, it does not recompute anything during inference. I am wondering how it can achieve this, because in your paper the model's performance degenerates when using window attention.

I am currently thinking it may be because the Mistral model was trained with Sliding Window Attention, and as a result it avoids the attention sink phenomenon. (This was asked in one of their issues but has not been answered yet.)

@tomaarsen
Contributor

For reference, the Mistral model degrades in performance over time just like dense attention methods:
[plot: performance over increasing input length for attention_sinks, transformers, and windowed]
Here, attention_sinks refers to the StreamingLLM approach, transformers is their model used via the transformers library, and windowed is simple window attention with position ID shifting.

Furthermore, when giving it subsequent prompts (160 prompts in a row):
[plot: results across 160 subsequent prompts for attention_sinks, transformers, and windowed]

Note

The automatic detection of fluency losses is very naive: it tries to count the number of real words in the response, but that can result in false positives if e.g. the prompt is to generate some German text. See demo/streaming_logs for the full logs to get a better picture of the real generative performance.

E.g. Mistral for transformers and attention_sinks - it's a big difference after like 250 lines.
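
To make the comparison concrete, here is a toy sketch (not the actual attention_sinks code) of which KV entries each strategy retains once the sequence exceeds the cache size, and how position IDs are shifted to stay within the cache:

```python
def keep_indices(seq_len, cache_size, num_sinks, strategy):
    """Return the original token positions whose K/V stay in the cache."""
    if seq_len <= cache_size:
        return list(range(seq_len))
    if strategy == "windowed":
        # Plain window attention: only the most recent tokens survive.
        return list(range(seq_len - cache_size, seq_len))
    if strategy == "attention_sinks":
        # StreamingLLM: always keep the first few "sink" tokens, then fill
        # the rest of the cache with the most recent tokens.
        recent = cache_size - num_sinks
        return list(range(num_sinks)) + list(range(seq_len - recent, seq_len))
    raise ValueError(strategy)

def shifted_position_ids(kept):
    # Position ID shifting: positions are assigned by a token's slot in the
    # cache rather than by its original position in the text.
    return list(range(len(kept)))

print(keep_indices(10, cache_size=6, num_sinks=2, strategy="windowed"))
# -> [4, 5, 6, 7, 8, 9]
print(keep_indices(10, cache_size=6, num_sinks=2, strategy="attention_sinks"))
# -> [0, 1, 6, 7, 8, 9]
```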

@hmzo

hmzo commented Oct 13, 2023

In my opinion, the "sliding window attention" mentioned in Mistral is equivalent to the "window attention" mentioned in attention_sinks.

@casper-hansen
Author

casper-hansen commented Oct 13, 2023

@tomaarsen I see your point here. My point was more about the latency reported in the paper.

It would also be more interesting to see a comparison of vLLM/TGI with and without attention sinks, since nobody uses the raw Hugging Face generate methods in production.

I wish the author of the paper had compared against sliding window attention as it is actually used, because it has none of the recomputation overhead presented in the paper.

@dengxiaotian123

dengxiaotian123 commented Dec 18, 2023

> @verlocks: Mistral uses Sliding Window Attention at inference time and does not recompute anything during inference. I am wondering how it can achieve this, because in your paper the model's performance degenerates when using window attention. I am currently thinking it may be because the Mistral model was trained with Sliding Window Attention, and as a result it avoids the attention sink phenomenon.

Hello @verlocks, I want to ask a question. In Mistral's one_file_ref.py script, it seems that sliding_window is used during training, but not during inference (because input_ids.shape[-1] should be 1 during inference).
Is the above understanding correct?
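
For context, a rough sketch of the rotating-buffer idea (hypothetical names, not the actual one_file_ref.py code): writing each token's K/V into slot position % window means the window is enforced by the cache itself at decode time, even when input_ids.shape[-1] == 1.

```python
import torch

class RotatingKVCache:
    """Toy rotating buffer: slot = position % window, so the oldest entry is
    overwritten once the window is full and the sliding window is enforced
    by the cache itself, even when decoding one token at a time."""
    def __init__(self, window, heads, dim):
        self.window = window
        self.k = torch.zeros(window, heads, dim)
        self.v = torch.zeros(window, heads, dim)

    def update(self, pos, k_new, v_new):
        slot = pos % self.window  # oldest entry gets overwritten
        self.k[slot] = k_new
        self.v[slot] = v_new

cache = RotatingKVCache(window=4096, heads=8, dim=64)
for pos in range(5000):  # decode well past the window size
    cache.update(pos, torch.randn(8, 64), torch.randn(8, 64))
# The cache still only holds the K/V of the most recent 4096 positions.
```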

@ehuaa

ehuaa commented Feb 28, 2024

> @tomaarsen: For reference, the Mistral model degrades in performance over time just like dense attention methods [see the plots and note above].

Hi @tomaarsen, it's a bit weird that in the official transformers API docs (https://huggingface.co/docs/transformers/en/model_doc/mistral), Mistral has a maximum input length of almost 128k:

> Mistral's sliding window attention allows sequences of up to 4096*32 tokens.

but in your test it fails once the input length grows to 8k. Is this right?

@tomaarsen
Contributor

> when the input length grows to 8k, it failed. Is this right?

That's right. Although the model doesn't crash until 128k, it doesn't perform well once it has exceeded the pretraining size of 8k tokens.

@ehuaa

ehuaa commented Feb 28, 2024

> That's right. Although the model doesn't crash until 128k, it doesn't perform well once it has exceeded the pretraining size of 8k tokens.

Thanks for your quick reply. So for industrial use, inputs exceeding the pretraining size of 8k will not work for the Mistral model.

@tomaarsen
Contributor

Correct, not for mistralai/Mistral-7B-v0.1, at least. There are some Mistral-based models that work on longer sequence lengths, e.g.: https://huggingface.co/NousResearch/Yarn-Mistral-7b-128k

@ehuaa

ehuaa commented Feb 28, 2024

> Correct, not for mistralai/Mistral-7B-v0.1, at least. There are some Mistral-based models that work on longer sequence lengths, e.g.: https://huggingface.co/NousResearch/Yarn-Mistral-7b-128k

Thanks Tom, I'll check the URL later!

@ehuaa

ehuaa commented Mar 1, 2024

> Correct, not for mistralai/Mistral-7B-v0.1, at least. There are some Mistral-based models that work on longer sequence lengths, e.g.: https://huggingface.co/NousResearch/Yarn-Mistral-7b-128k

Hi @tomaarsen, I have another question. In your test above, with sliding_window set to 4096 in Mistral's config, the model still has a reasonable perplexity when the input length grows to 8k.
But the attention sink paper says "Window attention collapses once the input length exceeds the cache size, i.e., the initial tokens are evicted". In Mistral, however, the model doesn't suddenly fail when the input length exceeds 4096. Is there something new in how Mistral was fine-tuned with the sliding window?
Can you help me figure this out? Thanks!
