
Anisha seq parallel #1147

Draft · wants to merge 28 commits into base: main

Changes from 1 commit (28 commits total)
05ef9c7
Add sequence parallelism sharding q
raymondzouu Nov 5, 2024
c87ecf9
Add segmentid implementation
raymondzouu Nov 5, 2024
a965a13
add debug statements
A9isha Jan 3, 2025
61d5ef8
Add segmentid implementation
raymondzouu Nov 5, 2024
d01af2e
temp commit to fix splash attention kernel call
A9isha Jan 4, 2025
a0e093e
fix seq sharding parameter in make_splash_mha
A9isha Jan 6, 2025
ce2c42f
update
A9isha Jan 7, 2025
ece53a5
fix the -1 value for context parallelism
A9isha Jan 7, 2025
54e9a55
add load_balancing in causal mask
A9isha Jan 8, 2025
4241570
reorder mask first commit
A9isha Jan 8, 2025
567233d
try to make static
A9isha Jan 8, 2025
b70e50b
static argnames in progress
A9isha Jan 9, 2025
106ae91
made multi_head_mask np.ndarray
A9isha Jan 9, 2025
442c27a
wrap ndarray
A9isha Jan 10, 2025
0ca559a
fix permuted mask
A9isha Jan 10, 2025
c283c4f
fix permutation
A9isha Jan 10, 2025
4d52390
clean up
A9isha Jan 10, 2025
7feeb07
fix non load balanced case
A9isha Jan 10, 2025
bf2d21a
clean up
A9isha Jan 10, 2025
2edbfcf
use ComputableMask
A9isha Jan 10, 2025
32539c9
fix traced array
A9isha Jan 11, 2025
b47eb81
debugging _ComputableMask but moving to dynamic slice creation
A9isha Jan 13, 2025
968f235
try out new computable mask
A9isha Jan 21, 2025
a026cb1
try running v6e-256 experiment
A9isha Jan 22, 2025
5adc54f
clean up config
A9isha Jan 22, 2025
fca5d0e
try different sa_block values
A9isha Jan 22, 2025
b70ea1e
codestyle changes
A9isha Jan 29, 2025
a829d23
revert remaining jnp to np in mask computation
A9isha Jan 29, 2025
fix permutation
A9isha committed Jan 22, 2025

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit c283c4f8e32a61d00bf32bd75aa86b234e4b9f88
5 changes: 3 additions & 2 deletions MaxText/layers/attentions.py
@@ -306,12 +306,13 @@ def tpu_flash_attention(
) -> Array:
"""TPU Flash Attention."""

+  decoder_segment_ids_permuted = None
+
   # Anisha: reorder tensors, which are currently [B,S,H,KV]
   cp_size = self.mesh.shape["context"]
   if cp_size > 1 and load_balanced_context_parallel:
     query = self.reorder_causal_load_balancing(tensor=query, cp_size=cp_size, seq_dim=1, to_contiguous=False)
-    decoder_segment_ids = self.reorder_causal_load_balancing(tensor=decoder_segment_ids, cp_size=cp_size, seq_dim=1, to_contiguous=False)
+    decoder_segment_ids_permuted = self.reorder_causal_load_balancing(tensor=decoder_segment_ids, cp_size=cp_size, seq_dim=1, to_contiguous=False)
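The `reorder_causal_load_balancing` call above suggests the usual "zigzag" chunk permutation used for load-balanced causal attention under context parallelism: the sequence is split into `2 * cp_size` chunks, and shard `r` receives chunks `r` and `2*cp_size - 1 - r`, so early (cheap) and late (expensive) causal rows are paired on each shard. The sketch below is a minimal NumPy illustration of that idea, not MaxText's actual implementation; the function and helper names mirror the diff but the body is assumed.

```python
import numpy as np


def _zigzag_order(cp_size: int) -> list[int]:
  # Shard r gets chunk r and chunk 2*cp_size - 1 - r.
  order = []
  for r in range(cp_size):
    order.extend([r, 2 * cp_size - 1 - r])
  return order


def reorder_causal_load_balancing(tensor: np.ndarray, cp_size: int,
                                  seq_dim: int = 1,
                                  to_contiguous: bool = False) -> np.ndarray:
  """Zigzag-permute the sequence axis across 2*cp_size chunks.

  to_contiguous=False applies the permutation (before sharding);
  to_contiguous=True inverts it (to recover the original order).
  """
  num_chunks = 2 * cp_size
  assert tensor.shape[seq_dim] % num_chunks == 0
  chunks = np.split(tensor, num_chunks, axis=seq_dim)
  perm = _zigzag_order(cp_size)
  if to_contiguous:
    perm = list(np.argsort(perm))  # inverse permutation
  return np.concatenate([chunks[i] for i in perm], axis=seq_dim)
```

With `cp_size=2` and a length-8 sequence `[0..7]`, the forward permutation yields `[0, 1, 6, 7, 2, 3, 4, 5]`: shard 0 holds positions {0, 1, 6, 7} and shard 1 holds {2, 3, 4, 5}, roughly equalizing causal-mask work per shard.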
@@ -468,7 +469,7 @@ def wrap_flash_attention(query, key, value, decoder_segment_ids_q, decoder_segme

return attention_output

-  x = wrap_flash_attention(query, key, value, decoder_segment_ids, decoder_segment_ids, splash_kernel)
+  x = wrap_flash_attention(query, key, value, decoder_segment_ids_permuted, decoder_segment_ids, splash_kernel)
   # pdb.set_trace()
   # jax.debug.print("{x}", x=x)
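The second hunk passes the permuted segment ids on the query side while leaving the key/value side unpermuted, which is consistent with only the queries having been zigzag-reordered. The toy example below illustrates why the two sides must use matching orderings: the mask is built from per-token position and segment ids, with the q-side arrays permuted and the kv-side arrays contiguous. This is an illustrative assumption about the mask semantics, not the splash-attention kernel's internals; the function name is hypothetical.

```python
import numpy as np


def causal_mask_with_permuted_q(q_positions: np.ndarray,
                                kv_positions: np.ndarray,
                                seg_q: np.ndarray,
                                seg_kv: np.ndarray) -> np.ndarray:
  """Boolean [len(q), len(kv)] mask: attend where segments match
  and the query's ORIGINAL position is >= the key's position.

  q_positions / seg_q are in permuted (zigzag) order; kv_positions /
  seg_kv are contiguous. Carrying original positions through the
  permutation keeps causality correct after reordering.
  """
  same_segment = seg_q[:, None] == seg_kv[None, :]
  causal = q_positions[:, None] >= kv_positions[None, :]
  return same_segment & causal
```

For example, shard 0 of a length-8 sequence under `cp_size=2` zigzag order holds queries with original positions `[0, 1, 6, 7]`; its mask rows then allow 1, 2, 7, and 8 keys respectively, even though the queries are stored out of order.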