forked from rwitten/HighPerfLLMs2024
AfterSessionExercises.txt
(1) During class, we implemented one specific element of attention manually. However, a better exercise would be to compute the output for the i-th sequence position for a specific head, as pointed out by Zac.
To do this manually (ignoring BATCH and HEADS):
(a) dot the i-th position's q with all the k's. The result should be SEQUENCE in length.
(b) take a softmax -- the result is still SEQUENCE in length, but it now sums to 1 and is all non-negative.
(c) take the weighted average over the V's. This should be a single vector of length HEAD_DIM.
(d) to compare, validate this against the reference implementation, either the one we computed in class or `pallas_attention.mha_reference`.
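
Steps (a) through (d) could be sketched roughly as follows. This is only one possible way to do it, not the version from class: the sizes, the random inputs, and the helper names `attention_row` and `full_attention` are made up for illustration, and `full_attention` is a stand-in for whichever reference you compare against. No scale factor is applied here; if your reference scales the logits (for example by 1/sqrt(HEAD_DIM)), apply the same scale in step (a).

```python
import jax
import jax.numpy as jnp

SEQUENCE, HEAD_DIM = 8, 4  # small illustrative sizes (assumption, not from class)

def attention_row(q, k, v, i):
    """Attention output for query position i; q, k, v are [SEQUENCE, HEAD_DIM]."""
    scores = k @ q[i]                 # (a) one score per key: [SEQUENCE]
    weights = jax.nn.softmax(scores)  # (b) non-negative and sums to 1
    return weights @ v                # (c) weighted average of the V rows: [HEAD_DIM]

def full_attention(q, k, v):
    """Hypothetical stand-in reference: unscaled, unmasked attention for all positions."""
    weights = jax.nn.softmax(q @ k.T, axis=-1)  # [SEQUENCE, SEQUENCE]
    return weights @ v                          # [SEQUENCE, HEAD_DIM]

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (SEQUENCE, HEAD_DIM))
k = jax.random.normal(kk, (SEQUENCE, HEAD_DIM))
v = jax.random.normal(kv, (SEQUENCE, HEAD_DIM))

i = 3
row = attention_row(q, k, v, i)
# (d) the manual result should match row i of the full computation
print(jnp.allclose(row, full_attention(q, k, v)[i], atol=1e-4))
```

To compare against a real reference with BATCH and HEADS dimensions, slice out one batch element and one head first, then index position i of the output.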
(2) A very important optimization to attention is to support multiple input sequences that don't communicate. This is achieved by passing in a 1D input of length SEQUENCE called "segment_ids". Imagine attention operates as we learned, except that tokens with the same segment_id go through separately as their own sequence, experiencing their own attention without interaction from the rest. This can be computationally attractive because it avoids "wasting" tokens, for example if we're running with a batch of size 1024 and the examples are of length 300 on average.
How would you implement this? It is similar to the "causal" attention we saw, except now we need to zero out the weights that correspond to cross-sequence communication: the weight at row i, col j should be zeroed out if segment_id[i] != segment_id[j]. Using `pallas_attention.mha_reference` as your reference implementation to check for correctness, implement this! (And if you get stuck, you can see how they implement it in mha_reference.)
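
If you want something to check your attempt against, here is a minimal sketch of the masking idea, ignoring BATCH and HEADS. It is an assumption-laden illustration, not the `mha_reference` code: the function names, sizes, and inputs are made up. One design choice worth noting: it sets the cross-segment logits to -inf before the softmax rather than zeroing weights afterwards, so the masked weights come out as 0 and each row still sums to 1.

```python
import jax
import jax.numpy as jnp

def segment_attention(q, k, v, segment_ids):
    """q, k, v: [SEQUENCE, HEAD_DIM]; segment_ids: [SEQUENCE] integers."""
    logits = q @ k.T                                     # [SEQUENCE, SEQUENCE]
    # same[i, j] is True when positions i and j belong to the same segment
    same = segment_ids[:, None] == segment_ids[None, :]
    # cross-segment logits become -inf, so their softmax weight is exactly 0
    logits = jnp.where(same, logits, -jnp.inf)
    weights = jax.nn.softmax(logits, axis=-1)
    return weights @ v                                   # [SEQUENCE, HEAD_DIM]

def plain_attention(q, k, v):
    """Hypothetical unmasked, unscaled attention used only for the check below."""
    return jax.nn.softmax(q @ k.T, axis=-1) @ v

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (6, 4))
k = jax.random.normal(kk, (6, 4))
v = jax.random.normal(kv, (6, 4))
segment_ids = jnp.array([0, 0, 0, 1, 1, 1])  # two packed examples of length 3

packed = segment_attention(q, k, v, segment_ids)
# Each segment should match attention run on that segment by itself
first = plain_attention(q[:3], k[:3], v[:3])
second = plain_attention(q[3:], k[3:], v[3:])
print(jnp.allclose(packed[:3], first, atol=1e-4),
      jnp.allclose(packed[3:], second, atol=1e-4))
```

Because every position shares a segment with itself, no row is fully masked, so the softmax never sees a row of all -inf.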