Anisha seq parallel #1147

Draft. Wants to merge 28 commits into base: main.

Commits (28)
05ef9c7
Add sequence parallelism sharding q
raymondzouu Nov 5, 2024
c87ecf9
Add segmentid implementation
raymondzouu Nov 5, 2024
a965a13
add debug statements
A9isha Jan 3, 2025
61d5ef8
Add segmentid implementation
raymondzouu Nov 5, 2024
d01af2e
temp commit to fix splash attention kernel call
A9isha Jan 4, 2025
a0e093e
fix seq sharding parameter in make_splash_mha
A9isha Jan 6, 2025
ce2c42f
update
A9isha Jan 7, 2025
ece53a5
fix the -1 value for context parallelism
A9isha Jan 7, 2025
54e9a55
add load_balancing in causal mask
A9isha Jan 8, 2025
4241570
reorder mask first commit
A9isha Jan 8, 2025
567233d
try to make static
A9isha Jan 8, 2025
b70e50b
static argnames in progress
A9isha Jan 9, 2025
106ae91
made multi_head_mask np.ndarray
A9isha Jan 9, 2025
442c27a
wrap ndarray
A9isha Jan 10, 2025
0ca559a
fix permuted mask
A9isha Jan 10, 2025
c283c4f
fix permutation
A9isha Jan 10, 2025
4d52390
clean up
A9isha Jan 10, 2025
7feeb07
fix non load balanced case
A9isha Jan 10, 2025
bf2d21a
clean up
A9isha Jan 10, 2025
2edbfcf
use ComputableMask
A9isha Jan 10, 2025
32539c9
fix traced array
A9isha Jan 11, 2025
b47eb81
debugging _ComputableMask but moving to dynamic slice creation
A9isha Jan 13, 2025
968f235
try out new computable mask
A9isha Jan 21, 2025
a026cb1
try running v6e-256 experiment
A9isha Jan 22, 2025
5adc54f
clean up config
A9isha Jan 22, 2025
fca5d0e
try different sa_block values
A9isha Jan 22, 2025
b70ea1e
codestyle changes
A9isha Jan 29, 2025
a829d23
revert remaining jnp to np in mask computation
A9isha Jan 29, 2025
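Several commits above (load balancing in the causal mask, the mask reorder and permutation fixes) appear to implement the usual load-balancing trick for causal attention under context parallelism: split the sequence into 2 * cp chunks and pair a cheap early chunk with an expensive late chunk on each shard. A rough numpy sketch of that permutation, illustrative only and not code from this PR:

```python
import numpy as np

def load_balanced_permutation(seq_len: int, cp_size: int) -> np.ndarray:
  """Token order for load-balanced causal attention under context parallelism.

  The sequence is split into 2 * cp_size chunks; shard i takes chunks i and
  2 * cp_size - 1 - i, so an early (cheap) and a late (expensive) region of the
  causal mask always land on the same shard.
  """
  assert seq_len % (2 * cp_size) == 0, "seq_len must divide evenly into 2*cp chunks"
  chunk = seq_len // (2 * cp_size)
  chunk_order = []
  for i in range(cp_size):
    chunk_order += [i, 2 * cp_size - 1 - i]
  return np.concatenate([np.arange(c * chunk, (c + 1) * chunk) for c in chunk_order])

# Example: seq_len=8, cp_size=2 gives chunk order [0, 3, 1, 2], i.e. token
# indices [0, 1, 6, 7, 2, 3, 4, 5]; the same index permutes the mask rows.
```

With this ordering each context-parallel rank sees a similar number of unmasked query/key pairs, which is presumably why the commits permute the mask rather than sharding contiguous sequence slices.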
1 change: 1 addition & 0 deletions MaxText/common_types.py

```diff
@@ -36,6 +36,7 @@
 
 BATCH = "activation_batch"
 LENGTH = "activation_length"
+KV_LENGTH = "activation_length_kv"
 EMBED = "activation_embed"
 HEAD = "activation_heads"
 PREFILL_KV_BATCH = "activation_prefill_kv_batch"
```
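The new KV_LENGTH constant gives key/value activations a logical length axis of their own ("activation_length_kv"), separate from the query length. A minimal sketch of how such a name is typically applied with Flax's logical partitioning; the helper below is hypothetical and not part of the diff:

```python
import flax.linen as nn
from common_types import BATCH, KV_LENGTH, HEAD  # names defined above

def constrain_kv(key, value):
  # 'activation_length_kv' maps to no mesh axis in base.yml
  # (['activation_length_kv', []]), so key/value length stays unsharded while
  # the query length can be sharded over the new 'context' mesh axis.
  key = nn.with_logical_constraint(key, (BATCH, KV_LENGTH, HEAD, None))
  value = nn.with_logical_constraint(value, (BATCH, KV_LENGTH, HEAD, None))
  return key, value
```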
25 changes: 15 additions & 10 deletions MaxText/configs/base.yml

```diff
@@ -228,15 +228,18 @@ jax_cache_dir: "~/jax_cache"
 hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu'
 
 # Parallelism
-mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_sequence', 'expert', 'autoregressive']
+mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor', 'tensor_sequence', 'expert', 'autoregressive']
 logical_axis_rules: [
 ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
 ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
 ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
 ['activation_heads', ['tensor','sequence','tensor_sequence']],
 ['activation_kv_heads', ['tensor','sequence','tensor_sequence']],
-['activation_length', ['sequence']],
 ['activation_norm_length', ['tensor_sequence', 'sequence']],
+['activation_length', ['sequence', 'context']],
+['activation_length', ['context']],
+['activation_length_q', ['context']],
+['activation_length_kv', []],
 ['activation_embed', 'tensor'],
 ['activation_mlp', ['tensor', 'tensor_sequence']],
 ['activation_kv', ['tensor', 'tensor_sequence']],
@@ -251,13 +254,13 @@ logical_axis_rules: [
 ['activation_exp', 'expert'],
 ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
 ['vocab', ['tensor', 'tensor_sequence', 'autoregressive']],
-['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']],
-['embed', ['fsdp', 'sequence', 'expert']],
-['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence']],
-['embed_no_exp', ['fsdp', 'sequence']],
-['norm', ['tensor', 'tensor_sequence']],
-['q_heads', ['tensor', 'tensor_sequence', 'autoregressive']],
-['heads', ['tensor', 'tensor_sequence', 'autoregressive']],
+['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+['embed', ['fsdp', 'sequence', 'context', 'expert']],
+['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
+['embed_no_exp', ['fsdp', 'sequence', 'context']],
+['norm', 'tensor'],
+['q_heads', ['tensor', 'autoregressive']],
+['heads', ['tensor', 'autoregressive']],
 ['layers', 'stage'],
 ['kv', []],
 ['kv_heads', ['tensor', 'tensor_sequence', 'autoregressive']],
@@ -270,7 +273,7 @@ logical_axis_rules: [
 ['exp', 'expert'],
 ]
 # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
-data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_sequence', 'expert', 'autoregressive']]
+data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor', 'tensor_sequence', 'expert', 'autoregressive']]
 
 # sharding tolerance: float between 0.0 and 1.0 representing the allowed percentage of non-sharded parameters.
 sharding_tolerance: 0.02
@@ -288,6 +291,7 @@ dcn_tensor_sequence_parallelism: 1 # never recommended
 dcn_pipeline_parallelism: 1
 dcn_expert_parallelism: 1
 dcn_autoregressive_parallelism: 1 # never recommended
+dcn_context_parallelism: 1
 ici_data_parallelism: 1
 ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
 ici_fsdp_transpose_parallelism: 1
@@ -297,6 +301,7 @@ ici_tensor_sequence_parallelism: 1
 ici_autoregressive_parallelism: 1
 ici_pipeline_parallelism: 1
 ici_expert_parallelism: 1
+ici_context_parallelism: 1
 
 # The number of TPU slices is automatically determined, you should not set this explicitly. For ahead of time compilation,
 # you should set compile_toplogy_num_slices, which will in turn set this value. For non-TPU environments this is set to 1.
```
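The new 'context' entry in mesh_axes and data_sharding, together with the ici_context_parallelism / dcn_context_parallelism knobs, adds one more dimension to the device mesh; the logical_axis_rules above then route 'activation_length' (queries) onto it while 'activation_length_kv' stays unsharded. A rough single-slice sketch of how such a mesh could be built; MaxText's real mesh construction is more involved and this helper is illustrative only:

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

def make_single_slice_mesh(ici_context_parallelism: int = 4) -> Mesh:
  # Axis order mirrors mesh_axes above; 'fsdp' absorbs the remaining devices
  # (the -1 / auto-sharded axis in base.yml), 'context' takes the requested
  # factor, and every other axis is left at 1.
  axis_names = ('data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence',
                'context', 'tensor', 'tensor_sequence', 'expert', 'autoregressive')
  n_devices = jax.device_count()
  mesh_shape = (1, 1, n_devices // ici_context_parallelism, 1, 1,
                ici_context_parallelism, 1, 1, 1, 1)
  devices = mesh_utils.create_device_mesh(mesh_shape)
  return Mesh(devices, axis_names)
```

Under such a mesh, the rule ['activation_length', ['sequence', 'context']] shards the query sequence dimension across the 'context' axis, which appears to be the layout the splash-attention and mask changes in the commit list target.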