Skip to content

Commit 3d23a2d

Browse files
committed
finalize config
1 parent d4c65ac commit 3d23a2d

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/scripts/train/OLMo2-26B.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
6 6

7 7
from olmo_core.config import DType
8 8
from olmo_core.distributed.parallel import DataParallelType
9+
from olmo_core.float8 import Float8Config
9 10
from olmo_core.internal.experiment import CommonComponents, main
10 11
from olmo_core.nn.transformer import (
11 12
TransformerActivationCheckpointingConfig,
@@ -30,13 +31,14 @@ def build_model_config(common: CommonComponents) -> TransformerConfig:
30 31
ac_config=TransformerActivationCheckpointingConfig(
31 32
mode=TransformerActivationCheckpointingMode.full
32 33
),
34+
float8_config=Float8Config(compile=True),
33 35
)
34 36

35 37

36 38
def build_optim_config(common: CommonComponents) -> AdamWConfig:
37 39
del common
38 40
return AdamWConfig(
39-
lr=3e-4,
41+
lr=6e-4,
40 42
weight_decay=0.1,
41 43
betas=(0.9, 0.95),
42 44
group_overrides=[
@@ -50,7 +52,7 @@ def build_trainer_config(common: CommonComponents) -> TrainerConfig:
50 52
return (
51 53
TrainerConfig(
52 54
save_folder=common.save_folder,
53-
rank_microbatch_size=1 * 4096,
55+
rank_microbatch_size=4 * 4096,
54 56
save_overwrite=True,
55 57
metrics_collect_interval=10,
56 58
cancel_check_interval=1,

0 commit comments

Comments
 (0)