Anisha seq parallel #1147
Draft · wants to merge 28 commits into base: main
Changes from 1 commit

Commits (28)
05ef9c7
Add sequence parallelism sharding q
raymondzouu Nov 5, 2024
c87ecf9
Add segmentid implementation
raymondzouu Nov 5, 2024
a965a13
add debug statements
A9isha Jan 3, 2025
61d5ef8
Add segmentid implementation
raymondzouu Nov 5, 2024
d01af2e
temp commit to fix splash attention kernel call
A9isha Jan 4, 2025
a0e093e
fix seq sharding parameter in make_splash_mha
A9isha Jan 6, 2025
ce2c42f
update
A9isha Jan 7, 2025
ece53a5
fix the -1 value for context parallelism
A9isha Jan 7, 2025
54e9a55
add load_balancing in causal mask
A9isha Jan 8, 2025
4241570
reorder mask first commit
A9isha Jan 8, 2025
567233d
try to make static
A9isha Jan 8, 2025
b70e50b
static argnames in progress
A9isha Jan 9, 2025
106ae91
made multi_head_mask np.ndarray
A9isha Jan 9, 2025
442c27a
wrap ndarray
A9isha Jan 10, 2025
0ca559a
fix permuted mask
A9isha Jan 10, 2025
c283c4f
fix permutation
A9isha Jan 10, 2025
4d52390
clean up
A9isha Jan 10, 2025
7feeb07
fix non load balanced case
A9isha Jan 10, 2025
bf2d21a
clean up
A9isha Jan 10, 2025
2edbfcf
use ComputableMask
A9isha Jan 10, 2025
32539c9
fix traced array
A9isha Jan 11, 2025
b47eb81
debugging _ComputableMask but moving to dynamic slice creation
A9isha Jan 13, 2025
968f235
try out new computable mask
A9isha Jan 21, 2025
a026cb1
try running v6e-256 experiment
A9isha Jan 22, 2025
5adc54f
clean up config
A9isha Jan 22, 2025
fca5d0e
try different sa_block values
A9isha Jan 22, 2025
b70ea1e
codestyle changes
A9isha Jan 29, 2025
a829d23
revert remaining jnp to np in mask computation
A9isha Jan 29, 2025
add load_balancing in causal mask
A9isha committed Jan 22, 2025
commit 54e9a55c1dfe65cbc3124c95bed6c073f0045304
MaxText/layers/attentions.py: 90 changes (89 additions, 1 deletion)
@@ -301,8 +301,21 @@ def tpu_flash_attention(
value: Array,
decoder_segment_ids: Array | None,
attn_logits_soft_cap: float | None = None,
load_balanced_context_parallel: bool = True,
) -> Array:
"""TPU Flash Attention."""


#Anisha: reorder tensors which is currently [B,S,H,KV]
cp_size = self.mesh.shape["context"]
if cp_size>1 and load_balanced_context_parallel:
query = self.reorder_causal_load_balancing(tensor = query, cp_size= cp_size, seq_dim= 1, to_contiguous=False)
decoder_segment_ids = self.reorder_causal_load_balancing(tensor = decoder_segment_ids, cp_size= cp_size, seq_dim= 1, to_contiguous=False)





# Transpose to ('batch', 'heads', 'length', 'kv')
query = jnp.transpose(query, axes=(0, 2, 1, 3))
key = jnp.transpose(key, axes=(0, 2, 1, 3))
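
A back-of-the-envelope sketch of why the reorder above balances work across context-parallel ranks (illustrative only, not code from this PR; the per-chunk work model is an assumption): under a causal mask, sequence chunk i attends to roughly i + 1 chunks of keys, so a contiguous split overloads the last rank, while pairing chunk r with chunk 2*cp_size - 1 - r equalizes the totals.

cp_size = 4
num_chunks = 2 * cp_size
# Causal-attention work per sequence chunk: chunk i attends to i + 1 key chunks.
work = [i + 1 for i in range(num_chunks)]
contiguous = [work[2 * r] + work[2 * r + 1] for r in range(cp_size)]
balanced = [work[r] + work[num_chunks - 1 - r] for r in range(cp_size)]
print(contiguous)  # [3, 7, 11, 15] -> the last rank does 5x the work of the first
print(balanced)    # [9, 9, 9, 9]   -> every rank does the same amount of work
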
@@ -340,6 +353,8 @@ def tpu_flash_attention(
"Batch dimension should be shardable among the devices in data and fsdp" " axis"
)

# Create the splash attention kernel.
block_sizes = splash_attention_kernel.BlockSizes(
block_q=min(global_block_q, query.shape[2]),
@@ -358,6 +373,11 @@ def tpu_flash_attention(
# jax.debug.print("query.shape = {qs}, key.shape = {ks}", qs = query.shape, ks = key.shape)
mask = splash_attention_mask.CausalMask(shape=(query.shape[2], key.shape[2]))

# Permute the causal mask to match the load-balanced sequence order.
if cp_size > 1 and load_balanced_context_parallel:
  mask = self.reorder_causal_load_balancing(mask, cp_size, 1, to_contiguous=False)

# TODO: work out local sliding attention combined with load balancing; the default here is global attention.
# Apply local masking if local sliding attention is enabled.
if self.attention_type == AttentionType.LOCAL_SLIDING:
if self.sliding_window_size is None:
@@ -373,7 +393,7 @@ def tpu_flash_attention(
splash_kernel = splash_attention_kernel.make_splash_mha(
mask=multi_head_mask,
head_shards=1, # we would need to change this to the size of the axis if sharding over heads
q_seq_shards=self.mesh.shape["context"], #axis for sequence sharding
q_seq_shards=cp_size, #axis for sequence sharding
block_sizes=block_sizes,
attn_logits_soft_cap=attn_logits_soft_cap,
)
@@ -417,8 +437,76 @@ def wrap_flash_attention(query, key, value, decoder_segment_ids_q, decoder_segme
# jax.debug.print("{x}", x=x)

x = jnp.transpose(x, axes=(0, 2, 1, 3))

if cp_size > 1 and load_balanced_context_parallel:
  # Inverse reorder: restore the contiguous sequence order after load-balanced attention.
  x = self.reorder_causal_load_balancing(tensor=x, cp_size=cp_size, seq_dim=1, to_contiguous=True)

return x


def reorder_causal_load_balancing(self, tensor, cp_size: int, seq_dim: int, to_contiguous: bool):
"""Reorders a tensor for load balancing the compute of causal attention."""

if tensor is None:
  return tensor

if cp_size == 1:
  return tensor

if cp_size % 2 != 0:
  raise ValueError(f"{cp_size=} must be a multiple of 2.")

# Each rank must receive two sequence chunks to swap, so the sequence length
# has to divide evenly into 2*cp_size chunks.
if tensor.shape[seq_dim] % (cp_size * 2) != 0:
  raise ValueError(f"{tensor.shape=} is not a multiple of {cp_size*2=}")

# Split the sequence dimension into 2*cp_size chunks. For example, with seq_dim=1:
# [B, S, H, D] -> [B, 2*cp_size, S/(2*cp_size), H, D]   (the layout used here)
# and with seq_dim=0:
# [S, B, H, D] -> [2*cp_size, S/(2*cp_size), B, H, D]

ori_tensor_shape = tensor.shape
tensor = jnp.reshape(
    tensor,
    (
        *ori_tensor_shape[:seq_dim],
        2 * cp_size,
        ori_tensor_shape[seq_dim] // (2 * cp_size),
        *ori_tensor_shape[seq_dim + 1 :],
    ),
)
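# Example with assumed shapes: for B=1, S=16, cp_size=2 and seq_dim=1, a
# [1, 16, H, D] tensor becomes [1, 4, 4, H, D]: 2*cp_size = 4 sequence chunks
# of length 4, which the loops below gather two at a time.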

parts = []
if not to_contiguous:
  # Load-balancing reorder: rank cp_rank takes chunks cp_rank and
  # (2*cp_size - 1 - cp_rank), pairing a cheap early chunk of the causal
  # mask with an expensive late one so every rank does similar work.
  for cp_rank in range(cp_size):
    index = jnp.array([cp_rank, (2 * cp_size - cp_rank - 1)])
    parts.append(jnp.take(tensor, index, axis=seq_dim))
else:
  # Inverse reorder: gather the chunk pairs back into their original,
  # contiguous order.
  for cp_rank in range(cp_size // 2):
    base = 4 * cp_rank
    index = jnp.array([base, base + 2])
    parts.append(jnp.take(tensor, index, axis=seq_dim))
  for cp_rank in range(cp_size // 2):
    base = 2 * cp_size - 1 - 4 * cp_rank
    index = jnp.array([base, base - 2])
    parts.append(jnp.take(tensor, index, axis=seq_dim))

# Stack the gathered chunk pairs along the sequence dimension and restore
# the original tensor shape.
combined = jnp.stack(parts, axis=seq_dim)

return jnp.reshape(combined, ori_tensor_shape)
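
For intuition, a minimal standalone sketch of the chunk permutation this method applies (the zigzag_chunk_order helper below is hypothetical and only mirrors the index arithmetic above; it is not part of this PR). For cp_size=2 the sequence is viewed as 2*cp_size = 4 chunks, the load-balancing pass reorders them to [0, 3, 1, 2], and the to_contiguous pass applies the inverse permutation.

import numpy as np

def zigzag_chunk_order(cp_size: int, to_contiguous: bool) -> np.ndarray:
  """Hypothetical helper mirroring the chunk-index arithmetic above."""
  if not to_contiguous:
    pairs = [(r, 2 * cp_size - r - 1) for r in range(cp_size)]
  else:
    pairs = [(4 * r, 4 * r + 2) for r in range(cp_size // 2)]
    pairs += [(2 * cp_size - 1 - 4 * r, 2 * cp_size - 3 - 4 * r) for r in range(cp_size // 2)]
  return np.asarray(pairs).reshape(-1)

fwd = zigzag_chunk_order(cp_size=2, to_contiguous=False)  # array([0, 3, 1, 2])
inv = zigzag_chunk_order(cp_size=2, to_contiguous=True)   # array([0, 2, 3, 1])
chunks = np.arange(4)
assert np.array_equal(chunks[fwd][inv], chunks)  # round trip restores contiguous order
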





def cudnn_flash_attention(
self,
query: Array,