Commit f9d581a

use FusedAdam from TE
Signed-off-by: Peter St. John <[email protected]>
1 parent da4686a commit f9d581a

File tree

4 files changed: +7 -7 lines changed

bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml
Lines changed: 1 addition & 1 deletion

@@ -53,7 +53,7 @@ fp8_config:
 # Optimizer config
 adamw_kwargs:
   lr: 4e-4
-  fused: true
+  adam_w_mode: true
   betas: [0.9, 0.98]
   eps: 1e-8
   weight_decay: 0.01

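The renamed config key tracks the constructor of the new optimizer: "fused: true" is a flag specific to torch.optim.AdamW, while Transformer Engine's FusedAdam always runs fused GPU kernels and instead exposes adam_w_mode, which selects decoupled (AdamW-style) weight decay. A minimal before/after sketch, assuming only the keyword names visible in this diff and a CUDA device:

# Sketch only: compares which kwargs each optimizer accepts (requires a CUDA device).
import torch
from transformer_engine.pytorch.optimizers import FusedAdam

params = [torch.nn.Parameter(torch.randn(8, 8, device="cuda"))]

# Old config: "fused: true" maps to torch.optim.AdamW's fused-kernel flag.
old_opt = torch.optim.AdamW(params, lr=4e-4, betas=(0.9, 0.98), eps=1e-8, weight_decay=0.01, fused=True)

# New config: "adam_w_mode: true" keeps decoupled (AdamW-style) weight decay in FusedAdam.
new_opt = FusedAdam(params, lr=4e-4, betas=(0.9, 0.98), eps=1e-8, weight_decay=0.01, adam_w_mode=True)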
bionemo-recipes/recipes/esm2_native_te/train_ddp.py
Lines changed: 2 additions & 2 deletions

@@ -21,8 +21,8 @@
 import transformer_engine.pytorch
 from omegaconf import DictConfig
 from torch.distributed.device_mesh import init_device_mesh
-from torch.optim import AdamW
 from transformer_engine.common.recipe import Format
+from transformer_engine.pytorch.optimizers import FusedAdam
 from transformers import AutoConfig, AutoModelForMaskedLM

 from checkpoint import load_checkpoint_ddp, save_checkpoint_ddp, save_final_model_ddp, should_save_checkpoint
@@ -81,7 +81,7 @@ def main(args: DictConfig) -> float | None:
         pass

     # Create optimizer.
-    optimizer = AdamW(model.parameters(), **args.adamw_kwargs)
+    optimizer = FusedAdam(model.parameters(), **args.adamw_kwargs)
     scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs)

     model = model.to(device=device)

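Because the training scripts forward adamw_kwargs straight into the optimizer constructor, the Hydra keys must match FusedAdam's signature. A hedged usage sketch of that pattern; the config and model below are placeholders, not the recipe's actual objects:

# Illustrative only: stand-in config and model (FusedAdam needs Transformer Engine and a CUDA device).
import torch
from omegaconf import OmegaConf
from transformer_engine.pytorch.optimizers import FusedAdam
from transformers import get_linear_schedule_with_warmup

# Placeholder kwargs mirroring hydra_config/defaults.yaml after this commit.
adamw_kwargs = OmegaConf.create(
    {"lr": 4e-4, "adam_w_mode": True, "betas": [0.9, 0.98], "eps": 1e-8, "weight_decay": 0.01}
)
model = torch.nn.Linear(16, 16).cuda()  # placeholder for the ESM-2 model

optimizer = FusedAdam(model.parameters(), **OmegaConf.to_container(adamw_kwargs, resolve=True))
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=10, num_training_steps=100)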
bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py
Lines changed: 2 additions & 2 deletions

@@ -22,8 +22,8 @@
 from omegaconf import DictConfig, OmegaConf
 from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import fully_shard
-from torch.optim import AdamW
 from transformer_engine.common.recipe import Format
+from transformer_engine.pytorch.optimizers import FusedAdam
 from transformers import AutoConfig, AutoModelForMaskedLM

 # This import seems to be needed with meta device init and AutoModel.from_config
@@ -87,7 +87,7 @@ def main(args: DictConfig) -> float | None:  # noqa: C901
     fully_shard(model, mesh=device_mesh["dp"])

     # Create optimizer. Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873).
-    optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True))  # type: ignore
+    optimizer = FusedAdam(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True))  # type: ignore
     scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs)

     if args.use_meta_device:

bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py
Lines changed: 2 additions & 2 deletions

@@ -23,8 +23,8 @@
 from megatron_fsdp.fully_shard import fully_shard
 from omegaconf import DictConfig, OmegaConf
 from torch.distributed.device_mesh import init_device_mesh
-from torch.optim import AdamW
 from transformer_engine.common.recipe import Format
+from transformer_engine.pytorch.optimizers import FusedAdam
 from transformers import AutoConfig, AutoModelForMaskedLM

 from checkpoint import load_checkpoint_mfsdp, save_checkpoint_mfsdp, save_final_model_mfsdp, should_save_checkpoint
@@ -85,7 +85,7 @@ def main(args: DictConfig) -> float | None:
     logger.info("Initialized Model:\n%s", model)

     # Create optimizer. Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873).
-    optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True))  # type: ignore
+    optimizer = FusedAdam(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True))  # type: ignore

     # Wrap model in megatron-fsdp
     model, optimizer = fully_shard(
