Skip to content

Commit

Permalink
allow failing when losing a diloco worker
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Aug 12, 2024
1 parent 3e1fa45 commit eaae47a
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions open_diloco/train_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class HvConfig(BaseConfig):
skip_load_from_peers: bool = False
world_rank: int
galaxy_size: int
fail_rank_drop: bool = False # fail if we lose a diloco worker

@model_validator(mode="before")
def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
Expand Down Expand Up @@ -369,6 +370,9 @@ def scheduler_fn(opt):

loss_batch = 0

if world_messenger_hv:
max_num_peers = 0

for step, batch in enumerate(iterable=train_dataloader, start=start_step * gradient_accumulation_steps):
real_step = (step + 1) // gradient_accumulation_steps
is_accumulating = bool((step + 1) % gradient_accumulation_steps)
Expand Down Expand Up @@ -448,6 +452,9 @@ def scheduler_fn(opt):
if world_messenger_hv:
outer_lr = [group["lr"] for group in optimizer.state_averager.optimizer.param_groups][0]
num_peers = optimizer.tracker.global_progress.num_peers

max_num_peers = max(max_num_peers, num_peers)

if num_peers == 0:
num_peers = 1

Expand All @@ -457,6 +464,13 @@ def scheduler_fn(opt):
if logging_activations_steps:
metrics.update(activation_monitor.log_activations)

if world_messenger_hv and num_peers < max_num_peers:
log(message=f"Lost a diloco worker, num_peers: {num_peers}, galaxy_size: {config.hv.galaxy_size}")
if config.hv.fail_rank_drop:
raise ValueError(
f"Lost a diloco worker, num_peers: {num_peers}, galaxy_size: {config.hv.galaxy_size}"
)

current_time = time.time()

wandb.log(metrics)
Expand Down

0 comments on commit eaae47a

Please sign in to comment.