Commit 0ef2a8a

Add OLMo2-26B config
1 parent 08c8073 commit 0ef2a8a

2 files changed: +111, −0 lines

src/olmo_core/nn/transformer/config.py

Lines changed: 14 additions & 0 deletions
@@ -409,6 +409,20 @@ def olmo2_13B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
             **kwargs,
         )
 
+    @classmethod
+    def olmo2_26B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
+        """
+        A 26B OLMo model config.
+        """
+        return cls.llama2_26B(
+            vocab_size,
+            block_name=kwargs.pop("block_name", TransformerBlockType.reordered_norm),
+            qk_norm=kwargs.pop("qk_norm", True),
+            rope_theta=kwargs.pop("rope_theta", 500_000),
+            layer_norm_eps=1e-6,
+            **kwargs,
+        )
+
     @classmethod
     def ngpt_271M(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
         """

src/scripts/train/OLMo2-26B.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
"""
Train a 26B OLMo model. Run this script without any arguments to see usage info.
"""

import logging

from olmo_core.config import DType
from olmo_core.distributed.parallel import DataParallelType
from olmo_core.internal.experiment import CommonComponents, main
from olmo_core.nn.transformer import (
    TransformerActivationCheckpointingConfig,
    TransformerActivationCheckpointingMode,
    TransformerConfig,
    TransformerDataParallelConfig,
)
from olmo_core.optim import AdamWConfig, OptimGroupOverride
from olmo_core.train import TrainerConfig
from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback

log = logging.getLogger(__name__)


def build_model_config(common: CommonComponents) -> TransformerConfig:
    return TransformerConfig.olmo2_26B(
        vocab_size=common.tokenizer.padded_vocab_size(),
        compile=True,
        dp_config=TransformerDataParallelConfig(
            name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32
        ),
        ac_config=TransformerActivationCheckpointingConfig(
            mode=TransformerActivationCheckpointingMode.full
        ),
    )


def build_optim_config(common: CommonComponents) -> AdamWConfig:
    del common
    return AdamWConfig(
        lr=3e-4,
        weight_decay=0.1,
        betas=(0.9, 0.95),
        group_overrides=[
            OptimGroupOverride(params=["embeddings.weight"], opts=dict(weight_decay=0.0))
        ],
        fused=True,
    )


def build_trainer_config(common: CommonComponents) -> TrainerConfig:
    return (
        TrainerConfig(
            save_folder=common.save_folder,
            rank_microbatch_size=1 * 4096,
            save_overwrite=True,
            metrics_collect_interval=10,
            cancel_check_interval=1,
            z_loss_multiplier=1e-5,
            compile_loss=True,
        )
        .with_callback(
            "checkpointer",
            CheckpointerCallback(
                save_interval=10_000,
                ephemeral_save_interval=250,
                save_async=True,
            ),
        )
        .with_callback(
            "comet",
            CometCallback(
                name=common.run_name,
                workspace="ai2",
                project="OLMo-core-26B",
                enabled=True,
                cancel_check_interval=10,
            ),
        )
        .with_callback(
            "wandb",
            WandBCallback(
                name=common.run_name,
                entity="ai2-llm",
                project="OLMo-core-26B",
                enabled=False,
                cancel_check_interval=10,
            ),
        )
    )


if __name__ == "__main__":
    main(
        global_batch_size=2048 * 4096,
        model_config_builder=build_model_config,
        optim_config_builder=build_optim_config,
        trainer_config_builder=build_trainer_config,
    )
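
As a sanity check on the batch settings above, a small arithmetic sketch (the world size is a made-up value for illustration; the commit does not specify how many ranks the script runs on):

# Sketch of the batch arithmetic implied by the config above.
# world_size is a hypothetical value for illustration only; the real rank
# count depends on the launch configuration, not on this commit.
global_batch_size = 2048 * 4096    # tokens per optimizer step, from main(...)
rank_microbatch_size = 1 * 4096    # tokens per rank per micro-batch, from TrainerConfig
world_size = 1024                  # hypothetical number of data-parallel ranks

tokens_per_rank = global_batch_size // world_size            # 8,192
grad_accum_steps = tokens_per_rank // rank_microbatch_size   # 2 micro-batches per rank per step

print(global_batch_size, tokens_per_rank, grad_accum_steps)  # 8388608 8192 2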
