-
Notifications
You must be signed in to change notification settings - Fork 374
Add meshes and config for TRN2/1 for Fuji models #885
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,13 +15,15 @@ | |
import itertools | ||
from typing import Any, Optional, Union | ||
|
||
import jax | ||
from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies | ||
|
||
from axlearn.common import causal_lm, config | ||
from axlearn.common.attention import ( | ||
BaseStackedTransformerLayer, | ||
FusedGroupedQKVLinear, | ||
FusedQKVLinear, | ||
GroupedQKVLinear, | ||
GroupedQueryAttention, | ||
MultiheadAttention, | ||
RepeatedTransformerLayer, | ||
|
@@ -174,6 +176,12 @@ def get_trainer_kwargs( | |
train_batch_size=train_batch_size, | ||
max_step=max_step, | ||
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8), | ||
mesh_rules=( | ||
( | ||
"neuron-(trn2|trn2n).48xlarge-64", | ||
mesh_shape_from_axes(fsdp=-1, model=4), | ||
), | ||
) | ||
) | ||
elif model_size == "3B": | ||
trainer_kwargs = dict( | ||
|
@@ -192,6 +200,12 @@ def get_trainer_kwargs( | |
train_batch_size=train_batch_size, | ||
max_step=max_step, | ||
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8), | ||
mesh_rules=( | ||
( | ||
"neuron-(trn2|trn2n).48xlarge-64", | ||
mesh_shape_from_axes(fsdp=-1, model=4), | ||
), | ||
) | ||
) | ||
elif model_size == "7B": | ||
trainer_kwargs = dict( | ||
|
@@ -287,6 +301,14 @@ def get_trainer_kwargs( | |
"gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)", | ||
mesh_shape_from_axes(data=-1, fsdp=8), | ||
), | ||
( | ||
"neuron-(trn2|trn2n).48xlarge-64", | ||
mesh_shape_from_axes(fsdp=-1, model=4), | ||
), | ||
( | ||
"neuron-(trn1|trn1n).32xlarge-64", | ||
mesh_shape_from_axes(fsdp=-1, model=8), | ||
), | ||
), | ||
) | ||
elif model_size == "8B": | ||
|
@@ -367,6 +389,10 @@ def get_trainer_kwargs( | |
"gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)", | ||
mesh_shape_from_axes(data=-1, fsdp=8), | ||
), | ||
( | ||
"neuron-(trn2|trn2n).48xlarge-64", | ||
mesh_shape_from_axes(fsdp=-1, model=4), | ||
), | ||
), | ||
) | ||
elif model_size == "70B": | ||
|
@@ -417,12 +443,17 @@ def get_trainer_kwargs( | |
"gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)", | ||
mesh_shape_from_axes(data=-1, fsdp=128), | ||
), | ||
( | ||
"neuron-(trn2|trn2n).48xlarge-64", | ||
mesh_shape_from_axes(fsdp=-1, model=4), | ||
), | ||
), | ||
) | ||
else: | ||
raise NotImplementedError(f"Unknown model size {model_size}.") | ||
model_kwargs = trainer_kwargs.pop("model_kwargs") | ||
model_kwargs.setdefault("vocab_size", vocab_size) | ||
model_kwargs.setdefault("stack_cfg", None if backend != "neuron" else StackedTransformerLayer.default_config()) | ||
[Reviewer comment] Will the use of `StackedTransformerLayer` (instead of `RepeatedTransformerLayer`) affect compile time on Neuron? |
[Reviewer comment] We are in the middle of optimizing RepeatedTransformer to use a new hardware feature in TRN2 to make dynamic memory operations faster. In the meantime, please continue to use StackedTransformer. The Neuron compiler has a module to detect repeating blocks, compile once, and reuse — so compile time won't grow with the number of layers. |
[Reviewer comment] Nice! Could you add this as a comment? |
||
trainer_kwargs["model_cfg"] = model_config(**model_kwargs) | ||
trainer_kwargs["learner_cfg"] = adamw_decoupled_learner_config( | ||
max_step=trainer_kwargs["max_step"], | ||
|
@@ -473,7 +504,9 @@ def model_config( | |
ffn_dim = scaled_hidden_dim(scale=8 / 3, round_up_to_multiples_of=256) | ||
if num_kv_heads: | ||
atten_cfg = GroupedQueryAttention.default_config() | ||
atten_input_linear = FusedGroupedQKVLinear.default_config().set(num_kv_heads=num_kv_heads) | ||
backend = jax.default_backend() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The fuji config should not depend on Instead, we can create separate configs for a backend that requires different settings. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1, please follow this example instead if you really need to overwrite some configs, you can add another custom There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the review, will update the PR with a custom |
||
qkv_linear = FusedGroupedQKVLinear if backend != "neuron" else GroupedQKVLinear | ||
atten_input_linear = qkv_linear.default_config().set(num_kv_heads=num_kv_heads) | ||
else: | ||
atten_cfg = MultiheadAttention.default_config() | ||
atten_input_linear = FusedQKVLinear.default_config() | ||
|
[Reviewer comment] Comment on why we set `model=4` for neuron?