35 changes: 34 additions & 1 deletion axlearn/experiments/text/gpt/fuji.py
@@ -15,13 +15,15 @@
import itertools
from typing import Any, Optional, Union

import jax
from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies

from axlearn.common import causal_lm, config
from axlearn.common.attention import (
BaseStackedTransformerLayer,
FusedGroupedQKVLinear,
FusedQKVLinear,
GroupedQKVLinear,
GroupedQueryAttention,
MultiheadAttention,
RepeatedTransformerLayer,
@@ -174,6 +176,12 @@ def get_trainer_kwargs(
train_batch_size=train_batch_size,
max_step=max_step,
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
mesh_rules=(
(
"neuron-(trn2|trn2n).48xlarge-64",
mesh_shape_from_axes(fsdp=-1, model=4),
Reviewer comment (Contributor): Comment on why we set model=4 for neuron?
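Note: the diff itself does not record the reason, but the arithmetic behind the rule is simple; a minimal sketch, assuming each trn2/trn2n.48xlarge-64 slice exposes 64 accelerator devices to JAX and that the fsdp=-1 axis absorbs whatever is left after the fixed axes:

```python
# Illustrative arithmetic only; the 64-device figure is inferred from the "-64"
# suffix in the mesh rule's instance pattern and is an assumption here.
devices = 64             # devices on one trn2/trn2n.48xlarge-64 slice (assumed)
model = 4                # tensor-parallel degree pinned by mesh_shape_from_axes(model=4)
fsdp = devices // model  # fsdp=-1 resolves to the remaining devices: 64 // 4 = 16
assert fsdp * model == devices
```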

),
)
)
elif model_size == "3B":
trainer_kwargs = dict(
@@ -192,6 +200,12 @@
train_batch_size=train_batch_size,
max_step=max_step,
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
mesh_rules=(
(
"neuron-(trn2|trn2n).48xlarge-64",
mesh_shape_from_axes(fsdp=-1, model=4),
),
)
)
elif model_size == "7B":
trainer_kwargs = dict(
@@ -287,6 +301,14 @@ def get_trainer_kwargs(
"gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)",
mesh_shape_from_axes(data=-1, fsdp=8),
),
(
"neuron-(trn2|trn2n).48xlarge-64",
mesh_shape_from_axes(fsdp=-1, model=4),
),
(
"neuron-(trn1|trn1n).32xlarge-64",
mesh_shape_from_axes(fsdp=-1, model=8),
),
),
)
elif model_size == "8B":
@@ -367,6 +389,10 @@ def get_trainer_kwargs(
"gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)",
mesh_shape_from_axes(data=-1, fsdp=8),
),
(
"neuron-(trn2|trn2n).48xlarge-64",
mesh_shape_from_axes(fsdp=-1, model=4),
),
),
)
elif model_size == "70B":
@@ -417,12 +443,17 @@ def get_trainer_kwargs(
"gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)",
mesh_shape_from_axes(data=-1, fsdp=128),
),
(
"neuron-(trn2|trn2n).48xlarge-64",
mesh_shape_from_axes(fsdp=-1, model=4),
),
),
)
else:
raise NotImplementedError(f"Unknown model size {model_size}.")
model_kwargs = trainer_kwargs.pop("model_kwargs")
model_kwargs.setdefault("vocab_size", vocab_size)
model_kwargs.setdefault("stack_cfg", None if backend != "neuron" else StackedTransformerLayer.default_config())
Reviewer comment (Contributor): Will the use of StackedTransformerLayer (vs. RepeatedTransformerLayer) lead to large XLA programs and long compilation time?

Author reply (Contributor Author): We are in the middle of optimizing RepeatedTransformerLayer to use a new hardware feature in TRN2 that makes dynamic memory operations faster. In the meantime, please continue to use StackedTransformerLayer. The Neuron compiler has a module that detects repeating blocks, compiles them once, and reuses the result, so compile time won't grow with the number of layers.

Reviewer reply (Contributor): Nice! Could you add this as a comment?
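One possible wording for the requested comment, sketched against the setdefault call above; the compiler behavior it cites comes from the author's reply and is not independently verified here:

```python
# Neuron backend: use StackedTransformerLayer until RepeatedTransformerLayer is
# optimized for TRN2's faster dynamic memory operations. The Neuron compiler
# detects repeating blocks and compiles them once, so compile time does not
# grow with the number of layers.
model_kwargs.setdefault(
    "stack_cfg",
    None if backend != "neuron" else StackedTransformerLayer.default_config(),
)
```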

trainer_kwargs["model_cfg"] = model_config(**model_kwargs)
trainer_kwargs["learner_cfg"] = adamw_decoupled_learner_config(
max_step=trainer_kwargs["max_step"],
@@ -473,7 +504,9 @@ def model_config(
ffn_dim = scaled_hidden_dim(scale=8 / 3, round_up_to_multiples_of=256)
if num_kv_heads:
atten_cfg = GroupedQueryAttention.default_config()
atten_input_linear = FusedGroupedQKVLinear.default_config().set(num_kv_heads=num_kv_heads)
backend = jax.default_backend()
Reviewer comment (Contributor): The fuji config should not depend on jax.default_backend(), otherwise the golden configs will not reflect the actual config being used. Instead, we can create separate configs for a backend that requires different settings.

Reviewer comment (Contributor): +1. If you really need to overwrite some configs, please follow this example instead: you can add another custom LayerConfigModifier like this one: https://github.com/apple/axlearn/blob/main/axlearn/common/trainer_config_modifier.py#L69

Author reply (@apoorvtintin, Dec 16, 2024): Thanks for the review, will update the PR with a custom LayerConfigModifier.
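A rough sketch of what such a backend-specific override could look like, following the config-modifier idea linked above; the function name, the config path to the transformer stack, and the fields copied over are assumptions for illustration, not part of this PR or of axlearn's trainer_config_modifier module:

```python
# Hypothetical sketch only: `neuron_stack_modifier`, the config path, and the
# copied fields are assumptions, not an existing axlearn API.
from axlearn.common.attention import StackedTransformerLayer


def neuron_stack_modifier(trainer_cfg):
    """Swaps the decoder's transformer stack to StackedTransformerLayer.

    Intended to be applied only when building Neuron trainer configs, so the
    golden configs for other backends stay untouched. Assumes the stack lives
    at trainer_cfg.model.decoder.transformer with `num_layers` and `layer`.
    """
    old_stack = trainer_cfg.model.decoder.transformer
    trainer_cfg.model.decoder.transformer = StackedTransformerLayer.default_config().set(
        num_layers=old_stack.num_layers,
        layer=old_stack.layer,
    )
    return trainer_cfg
```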

qkv_linear = FusedGroupedQKVLinear if backend != "neuron" else GroupedQKVLinear
atten_input_linear = qkv_linear.default_config().set(num_kv_heads=num_kv_heads)
else:
atten_cfg = MultiheadAttention.default_config()
atten_input_linear = FusedQKVLinear.default_config()