refactor: move code to ckpt_utils.py
samsja committed Aug 19, 2024
1 parent 9cd4a43 commit 8b7b6a8
Showing 2 changed files with 80 additions and 66 deletions.
open_diloco/ckpt_utils.py: 68 additions, 0 deletions
@@ -1,11 +1,38 @@
import fsspec
from pydantic_config import BaseConfig
import torch
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
import torch.distributed.checkpoint as dcp
import os
from torchdata.stateful_dataloader import StatefulDataLoader
from fsspec.generic import GenericFileSystem


GLOBAL_STATE_FILE = "global_state_dict.pt"
CKPT_PREFIX = "model_step"


class CkptConfig(BaseConfig):
resume: str | bool | None = None # if resume is a boolean, it means we should resume from the last checkpoint
interval: int | None = None
path: str = "outputs"
topk: int | None = None # how many checkpoints to keep

def get_resume_path(self):
if self.resume is None:
raise ValueError("Resume path is not set")
elif isinstance(self.resume, bool):
# Using fsspec to list directory contents
fs = GenericFileSystem()
ckpt_files = [f for f in fs.ls(self.path, detail=False) if filter_ckpt_files(f)]

if len(ckpt_files) == 0:
raise ValueError(f"No checkpoints found in {self.path}")

latest_ckpt = max(ckpt_files, key=lambda f: int(f.split("_")[-1]))
return latest_ckpt

return self.resume


def save_checkpoint(
@@ -117,3 +144,44 @@ def load_checkpoint(
if scaler is not None:
scaler.load_state_dict(global_state_dict["scaler"])
return global_state_dict["loss"]


def filter_ckpt_files(f):
if CKPT_PREFIX not in f:
return False
else:
try:
int(f.split("_")[-1])
return True
except ValueError:
return False


def delete_old_checkpoints(checkpoint_path: str, topk: int) -> list[str]:
fs = GenericFileSystem()
ckpt_files = [f for f in fs.ls(checkpoint_path, detail=False) if filter_ckpt_files(f)]
ckpt_files.sort(key=lambda x: int(x.split("_")[-1]))

ckpt_deleted = []
for ckpt_file in ckpt_files[:-topk]:
fs.rm(ckpt_file, recursive=True)
ckpt_deleted.append(ckpt_file)
return ckpt_deleted


def check_checkpoint_path_access(checkpoint_path: str, rank: int, world_rank_hv: int | None = None):
if world_rank_hv:
dummy_file_path = os.path.join(
checkpoint_path, get_diloco_rank_dir_name(world_rank_hv), f"dummy_file_{rank}.txt"
)
else:
dummy_file_path = os.path.join(checkpoint_path, f"dummy_file_{rank}.txt")

with fsspec.open(dummy_file_path, "w") as f:
f.write("This is a dummy file for testing access.")
gfs = GenericFileSystem()
gfs.rm(dummy_file_path)


def get_diloco_rank_dir_name(world_rank_diloco: int) -> str:
return f"diloco_rank_{world_rank_diloco}"
open_diloco/train_fsdp.py: 12 additions, 66 deletions
@@ -13,14 +13,12 @@
import datetime
from typing import Any, Literal

import fsspec
from pydantic import model_validator
import torch
import wandb
from pydantic_config import parse_argv, BaseConfig
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from fsspec.generic import GenericFileSystem
from torch.distributed import destroy_process_group, init_process_group

from torchdata.stateful_dataloader import StatefulDataLoader
@@ -38,7 +36,15 @@
)
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed import broadcast_object_list
from open_diloco.ckpt_utils import load_checkpoint, save_checkpoint
from open_diloco.ckpt_utils import (
CKPT_PREFIX,
CkptConfig,
check_checkpoint_path_access,
delete_old_checkpoints,
get_diloco_rank_dir_name,
load_checkpoint,
save_checkpoint,
)
from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoOptimizer


@@ -58,7 +64,6 @@
TIMEOUT_NCCL_MINUTES = os.environ.get("TIMEOUT_NCCL_MINUTES", 120)
TARGET_LAYER_ACTIVATIONS = ["self_attn", "lm_head"]
TEST_VOCAB_SIZE = 1024
CKPT_PREFIX = "model_step"


# Function to initialize the distributed process group
@@ -71,33 +76,6 @@ def log(message):
logger.info(f"[rank {os.environ['LOCAL_RANK']}] {message}")


def check_checkpoint_path_access(checkpoint_path: str, rank: int, world_rank_hv: int | None = None):
if world_rank_hv:
dummy_file_path = os.path.join(
checkpoint_path, get_diloco_rank_dir_name(world_rank_hv), f"dummy_file_{rank}.txt"
)
else:
dummy_file_path = os.path.join(checkpoint_path, f"dummy_file_{rank}.txt")

with fsspec.open(dummy_file_path, "w") as f:
f.write("This is a dummy file for testing access.")
gfs = GenericFileSystem()
gfs.rm(dummy_file_path)


def get_diloco_rank_dir_name(world_rank_diloco: int) -> str:
return f"diloco_rank_{world_rank_diloco}"


def delete_old_checkpoints(checkpoint_path: str, topk: int):
fs = GenericFileSystem()
ckpt_files = [f for f in fs.ls(checkpoint_path, detail=False) if filter_ckpt_files(f)]
ckpt_files.sort(key=lambda x: int(x.split("_")[-1]))
for ckpt_file in ckpt_files[:-topk]:
log(f"Deleting old checkpoint {ckpt_file}")
fs.rm(ckpt_file, recursive=True)


class HvConfig(BaseConfig):
outer_lr: float = 0.7
local_steps: int = 500
@@ -123,40 +101,6 @@ def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
return values


def filter_ckpt_files(f):
if CKPT_PREFIX not in f:
return False
else:
try:
int(f.split("_")[-1])
return True
except ValueError:
return False


class CkptConfig(BaseConfig):
resume: str | bool | None = None # if resume is a boolean, it means we should resume from the last checkpoint
interval: int | None = None
path: str = "outputs"
topk: int | None = None # how many checkpoints to keep

def get_resume_path(self):
if self.resume is None:
raise ValueError("Resume path is not set")
elif isinstance(self.resume, bool):
# Using fsspec to list directory contents
fs = GenericFileSystem()
ckpt_files = [f for f in fs.ls(self.path, detail=False) if filter_ckpt_files(f)]

if len(ckpt_files) == 0:
raise ValueError(f"No checkpoints found in {self.path}")

latest_ckpt = max(ckpt_files, key=lambda f: int(f.split("_")[-1]))
return latest_ckpt

return self.resume


class Config(BaseConfig):
path_model: str = "PrimeIntellect/llama-150m-fresh"
torch_compile: bool = True
@@ -559,7 +503,9 @@ def scheduler_fn(opt):
if local_rank == 0:
# only the rank 0 deletes the checkpoints
if config.ckpt.topk is not None:
delete_old_checkpoints(config.ckpt.path, config.ckpt.topk)
ckpt_deleted = delete_old_checkpoints(config.ckpt.path, config.ckpt.topk)
if ckpt_deleted:
log(f"Deleted old checkpoints: {ckpt_deleted}")

loss_batch = 0

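
check_checkpoint_path_access was also moved unchanged into ckpt_utils.py. A hedged sketch of calling it at startup, assuming torchrun sets LOCAL_RANK and the checkpoint path is writable:

import os

from open_diloco.ckpt_utils import check_checkpoint_path_access

local_rank = int(os.environ.get("LOCAL_RANK", "0"))

# Plain FSDP run: writes and removes a dummy file directly under the path.
check_checkpoint_path_access("outputs", rank=local_rank)

# Hivemind/DiLoCo run: the dummy file goes under the per-worker directory,
# e.g. "outputs/diloco_rank_3/dummy_file_<rank>.txt".
check_checkpoint_path_access("outputs", rank=local_rank, world_rank_hv=3)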