diff --git a/open_diloco/ckpt_utils.py b/open_diloco/ckpt_utils.py
index 261b7c4..97688c6 100644
--- a/open_diloco/ckpt_utils.py
+++ b/open_diloco/ckpt_utils.py
@@ -40,6 +40,7 @@ def save_checkpoint(
     model: torch.nn.Module,
     optimizer: torch.optim.Optimizer,
     scheduler: torch.optim.lr_scheduler.LambdaLR,
+    outer_scheduler: torch.optim.lr_scheduler.LambdaLR | None = None,
     outer_optimizer: torch.optim.Optimizer | None = None,
     scaler: torch.cuda.amp.GradScaler | None = None,
     loss: float | None = None,
@@ -81,6 +82,8 @@ def save_checkpoint(
 
     # 2. Save global states
     global_state_dict = {"scheduler": scheduler.state_dict(), "loss": loss if loss is not None else 0}
+    if outer_scheduler is not None:
+        global_state_dict["outer_scheduler"] = outer_scheduler.state_dict()
     if outer_optimizer is not None:
         global_state_dict["outer_optimizer"] = outer_optimizer.state_dict()
     if scaler is not None:
@@ -95,6 +98,7 @@ def load_checkpoint(
     model: torch.nn.Module,
     optimizer: torch.optim.Optimizer,
     scheduler: torch.optim.lr_scheduler.LambdaLR | None = None,
+    outer_scheduler: torch.optim.lr_scheduler.LambdaLR | None = None,
     outer_optimizer: torch.optim.Optimizer | None = None,
     scaler: torch.cuda.amp.GradScaler | None = None,
     data_loader: StatefulDataLoader | None = None,
@@ -139,8 +143,13 @@ def load_checkpoint(
     if scheduler is not None:
         scheduler.load_state_dict(global_state_dict["scheduler"])
         optimizer.param_groups[0]["lr"] = scheduler.get_last_lr()[0]
+
     if outer_optimizer is not None:
         outer_optimizer.load_state_dict(global_state_dict["outer_optimizer"])
+        if outer_scheduler is not None:
+            outer_scheduler.load_state_dict(global_state_dict["outer_scheduler"])
+            outer_optimizer.param_groups[0]["lr"] = outer_scheduler.get_last_lr()[0]
+
     if scaler is not None:
         scaler.load_state_dict(global_state_dict["scaler"])
     return global_state_dict["loss"]
diff --git a/open_diloco/hivemind_diloco.py b/open_diloco/hivemind_diloco.py
index 308608b..74d94d5 100644
--- a/open_diloco/hivemind_diloco.py
+++ b/open_diloco/hivemind_diloco.py
@@ -334,6 +334,7 @@ def __init__(
         inner_optimizer: OptimizerFactory,
         params: Optional[Union[Parameters, ParamGroups]] = None,
         scheduler: Optional[SchedulerFactory] = None,
+        outer_scheduler: Optional[SchedulerFactory] = None,
         averager_opts: Optional[dict] = None,
         grad_compression: CompressionBase = NoCompression(),
         tracker_opts: Optional[dict] = None,
@@ -365,7 +366,7 @@ def __init__(
         # since we have two optimizers, we need to persist the params to a list
         self.num_inner_steps = num_inner_steps
 
-        for opt_or_scheduler in [outer_optimizer, scheduler]:
+        for opt_or_scheduler in [outer_optimizer, scheduler, outer_scheduler]:
             if not (callable(opt_or_scheduler) or opt_or_scheduler is None):
                 raise TypeError("You need to pass inner and outer optimizer as well as scheduler as callable")
 
@@ -405,6 +406,8 @@ def __init__(
         )
         self.diloco_grad_averager = self._make_gradient_averager(compression=grad_compression)
 
+        self.outer_scheduler = outer_scheduler(self.state_averager.optimizer) if outer_scheduler else None
+
     def _check_kwargs(self, kwargs) -> None:
         """DiLoCo Optimizer only support a subset of Hivemind Optimizer kwargs.
         This function raise an error if some kwargs are not supported"""
@@ -555,6 +558,9 @@ def step(
 
         if self.tracker.ready_to_update_epoch:
             self._update_global_epoch()
+            if self.outer_scheduler is not None:
+                self.outer_scheduler.step()
+
         return loss
 
     def _compute_schema_hash(self) -> int:
diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py
index ab4efe2..4383e7c 100644
--- a/open_diloco/train_fsdp.py
+++ b/open_diloco/train_fsdp.py
@@ -7,6 +7,7 @@
 """
 
 from functools import partial
+import math
 import os
 import time
 from contextlib import nullcontext
@@ -26,7 +27,6 @@
     DataCollatorForLanguageModeling,
     LlamaConfig,
     LlamaForCausalLM,
-    get_cosine_schedule_with_warmup,
 )
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
@@ -46,6 +46,7 @@
 )
 
 from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoOptimizer
 from open_diloco.utils import WandbLogger, DummyLogger
+from torch.optim.lr_scheduler import LambdaLR
 from hivemind.dht.dht import DHT
 from hivemind.utils.networking import log_visible_maddrs
@@ -90,6 +91,8 @@ class HvConfig(BaseConfig):
     world_rank: int
     galaxy_size: int
     fail_rank_drop: bool = False  # fail if we lose a diloco worker
+    warmup_outerstep: int = 10
+    outer_scheduler: bool = False
 
     @model_validator(mode="before")
     def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
@@ -173,6 +176,61 @@ def get_model(config: Config) -> LlamaForCausalLM:
     return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model)
 
 
+def _get_cosine_schedule_with_warmup_lr_lambda(
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_cycles: float,
+    min_lr_rate: float = 0.0,
+):
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+
+    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+    factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+    factor = factor * (1 - min_lr_rate) + min_lr_rate
+    return max(0, factor)
+
+
+def get_cosine_schedule_with_warmup(optimizer, config: Config):
+    lambda_lr = partial(
+        _get_cosine_schedule_with_warmup_lr_lambda,
+        num_warmup_steps=config.warmup_steps,
+        num_training_steps=config.total_steps,
+        num_cycles=0.5,
+    )
+    return LambdaLR(optimizer, lambda_lr, -1)
+
+
+def _get_lr_outer(
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_cycles: float,
+    min_lr_rate: float = 0.0,
+):
+    if current_step < num_warmup_steps:
+        return 1
+
+    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+    factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+    factor = factor * (1 - min_lr_rate) + min_lr_rate
+    return max(0, factor)
+
+
+def get_lr_outer(optimizer, config: Config):
+    lambda_lr = partial(
+        _get_lr_outer,
+        num_warmup_steps=config.warmup_steps,
+        # num_training_steps=config.total_steps,
+        num_training_steps=config.total_steps,
+        num_cycles=0.5,
+    )
+    return LambdaLR(optimizer, lambda_lr, -1)
+
+
 def train(config: Config):
     sharding_strategy = get_sharding_strategy(config.sharding_strategy)
     local_rank = int(os.environ["LOCAL_RANK"])
@@ -252,10 +310,12 @@ def train(config: Config):
     def scheduler_fn(opt):
         return get_cosine_schedule_with_warmup(
             opt,
-            num_warmup_steps=config.warmup_steps,
-            num_training_steps=config.total_steps,
+            config=config,
         )
 
+    def outer_scheduler_fn(opt):
+        return get_lr_outer(opt, config=config)
+
     if config.hv is not None:
         if config.ckpt.resume:
             # We need to load with a fake optimizer to set the model parameters correctly before initializing the DiLoCoOptimizer
@@ -281,6 +341,7 @@ def scheduler_fn(opt):
             outer_optimizer=outer_optimizer,
             inner_optimizer=inner_optimizer,
             scheduler=None,
+            outer_scheduler=outer_scheduler_fn if config.hv.outer_scheduler else None,
             params=model.parameters(),
             delay_optimizer_step=False,
             delay_grad_averaging=False,
@@ -311,6 +372,7 @@ def scheduler_fn(opt):
                 model=model,
                 optimizer=optimizer.inner_optimizer,
                 scheduler=scheduler,
+                outer_scheduler=optimizer.outer_scheduler,
                 outer_optimizer=optimizer.state_averager.optimizer,
                 scaler=scaler,
                 data_loader=train_dataloader,
@@ -400,6 +462,7 @@ def scheduler_fn(opt):
 
                     scaler.update()
                     scheduler.step()
+
                     optimizer.zero_grad()
 
             if config.hv is not None:
@@ -476,6 +539,7 @@ def scheduler_fn(opt):
                         model=model,
                         optimizer=optimizer.inner_optimizer,
                         scheduler=scheduler,
+                        outer_scheduler=optimizer.outer_scheduler,
                         outer_optimizer=optimizer.state_averager.optimizer,
                         loss=loss_batch.item(),
                         scaler=scaler,
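
Note on the schedule wiring above: _get_lr_outer returns a factor of 1 while current_step < num_warmup_steps, so the outer learning rate stays at its base value during warmup and then decays along a cosine curve, and DiLoCoOptimizer steps the outer scheduler each time the global epoch updates. The snippet below is a minimal standalone sketch of that behaviour, not part of the diff: a plain SGD optimizer stands in for the DiLoCo outer optimizer, and the learning rate, momentum, warmup and total-step values are illustrative assumptions.

import math
from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR


def _get_lr_outer(current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0):
    # Same shape as the lambda added in train_fsdp.py: factor 1.0 during warmup, cosine decay afterwards.
    if current_step < num_warmup_steps:
        return 1
    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
    return max(0, factor * (1 - min_lr_rate) + min_lr_rate)


# SGD stands in for the DiLoCo outer optimizer; the hyperparameters here are placeholders.
params = [torch.nn.Parameter(torch.zeros(1))]
outer_optimizer = torch.optim.SGD(params, lr=0.7, momentum=0.9, nesterov=True)
outer_scheduler = LambdaLR(
    outer_optimizer,
    partial(_get_lr_outer, num_warmup_steps=10, num_training_steps=100, num_cycles=0.5),
)

for outer_step in range(100):
    outer_optimizer.step()  # in the real code the outer step happens inside DiLoCoOptimizer
    outer_scheduler.step()  # mirrors the outer_scheduler.step() call added in hivemind_diloco.py
    if outer_step % 20 == 0:
        print(outer_step, outer_scheduler.get_last_lr()[0])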