Commit

refactor checkpointing: remove date and run id
samsja committed Aug 6, 2024
1 parent 269cf90 commit ba63895
Showing 1 changed file with 16 additions and 26 deletions.
42 changes: 16 additions & 26 deletions open_diloco/train_fsdp.py
@@ -22,6 +22,7 @@
from datasets.distributed import split_dataset_by_node
from fsspec.generic import GenericFileSystem
from torch.distributed import destroy_process_group, init_process_group

from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import (
AutoTokenizer,
@@ -69,26 +70,18 @@ def log(message):
logger.info(f"[rank {os.environ['LOCAL_RANK']}] {message}")


def get_ckpt_folder(checkpoint_path, training_date, project, run_id):
return os.path.join(checkpoint_path, training_date, project, run_id)


def check_checkpoint_path_access(checkpoint_path: str, training_date, project, run_id, rank):
dummy_file_path = os.path.join(
get_ckpt_folder(
checkpoint_path=checkpoint_path,
training_date=training_date,
project=project,
run_id=run_id,
),
f"dummy_file_{rank}.txt",
)
def check_checkpoint_path_access(checkpoint_path: str, rank: int):
dummy_file_path = os.path.join(checkpoint_path, f"dummy_file_{rank}.txt")
with fsspec.open(dummy_file_path, "w") as f:
f.write("This is a dummy file for testing access.")
gfs = GenericFileSystem()
gfs.rm(dummy_file_path)


def get_diloco_rank_dir_name(world_rank_diloco: int) -> str:
return f"diloco_rank_{world_rank_diloco}"


class HvConfig(BaseConfig):
outer_lr: float = 0.7
local_steps: int = 500
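
A brief, hedged sketch of how the two helpers in this hunk are used further down in the diff: check_checkpoint_path_access now probes the bare checkpoint_path (no date/project/run-id nesting), and get_diloco_rank_dir_name supplies the per-worker sub-directory that the save and load calls join onto the step directory. All concrete values below are made up for illustration.

import os


def get_diloco_rank_dir_name(world_rank_diloco: int) -> str:
    return f"diloco_rank_{world_rank_diloco}"


# Hypothetical stand-ins for config.checkpoint_path, config.hv.world_rank,
# and the real_step counter used in the training loop.
checkpoint_path = "/checkpoints/run"
world_rank = 3
real_step = 2000

# Save side (world-messenger branch below): one directory per outer step,
# with a per-DiLoCo-rank sub-directory inside it.
save_dir = os.path.join(
    checkpoint_path, f"model_step_{int(real_step)}", get_diloco_rank_dir_name(world_rank)
)
print(save_dir)  # /checkpoints/run/model_step_2000/diloco_rank_3

# Resume side: config.resume_from_checkpoint presumably points at one
# model_step_<N> directory, and the same rank sub-directory is appended.
resume_dir = os.path.join("/checkpoints/run/model_step_2000", get_diloco_rank_dir_name(world_rank))
print(resume_dir)  # /checkpoints/run/model_step_2000/diloco_rank_3
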
@@ -202,10 +195,6 @@ def train(config: Config):
assert batch_size % config.per_device_train_batch_size == 0
gradient_accumulation_steps = batch_size // config.per_device_train_batch_size

training_date = datetime.datetime.now().strftime(
"%Y-%m-%d"
) # we define the date at the beginning of training in case the training takes several days

if config.hv is not None:
sharding_strategy = ShardingStrategy.NO_SHARD
log("Hivemind is used, ShardingStrategy.NO_SHARD is used")
@@ -232,7 +221,7 @@ def train(config: Config):
log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=False)

if local_rank == 0:
check_checkpoint_path_access(config.checkpoint_path, training_date, config.project, run_id, rank)
check_checkpoint_path_access(config.checkpoint_path, rank)

# DataLoader preparation
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True)
@@ -290,7 +279,9 @@ def scheduler_fn(opt):
# Otherwise the world messenger will get lonely and hang
fake_optimizer = inner_optimizer(model.parameters())
last_loss = load_checkpoint(
checkpoint_path=config.resume_from_checkpoint,
checkpoint_path=os.path.join(
config.resume_from_checkpoint, get_diloco_rank_dir_name(config.hv.world_rank)
),
model=model,
optimizer=fake_optimizer,
)
@@ -329,7 +320,9 @@ def scheduler_fn(opt):

if config.resume_from_checkpoint:
last_loss = load_checkpoint(
checkpoint_path=config.resume_from_checkpoint,
checkpoint_path=os.path.join(
config.resume_from_checkpoint, get_diloco_rank_dir_name(config.hv.world_rank)
),
model=model,
optimizer=optimizer.inner_optimizer,
scheduler=scheduler,
@@ -470,16 +463,13 @@ def scheduler_fn(opt):
# Save checkpoint every 'checkpoint_interval' steps
if config.checkpoint_interval is not None and real_step % config.checkpoint_interval == 0:
log(f"saving at step {real_step}, step {step+1}")
ckpt_path = os.path.join(
get_ckpt_folder(config.checkpoint_path, training_date, config.project, run_id),
f"model_step_{int(real_step)}",
)
ckpt_path = os.path.join(config.checkpoint_path, f"model_step_{int(real_step)}")

if world_messenger_hv:
assert isinstance(optimizer, DiLoCoOptimizer)
with optimizer.tracker.pause_updates():
save_checkpoint(
checkpoint_path=ckpt_path,
checkpoint_path=os.path.join(ckpt_path, get_diloco_rank_dir_name(config.hv.world_rank)),
model=model,
optimizer=optimizer.inner_optimizer,
scheduler=scheduler,
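For context, a small before/after comparison of the on-disk layout this commit changes, reconstructed from the removed get_ckpt_folder helper and the new save path above; every concrete value is an illustrative stand-in, not taken from a real run.

import os

# Illustrative stand-ins for the config values and run metadata.
checkpoint_path = "/checkpoints/run"
training_date, project, run_id = "2024-08-06", "open_diloco", "abc123"
real_step, world_rank = 2000, 3

# Before: date, project, and run id were baked into every checkpoint path via
# get_ckpt_folder, so resuming required knowing all three.
old_path = os.path.join(checkpoint_path, training_date, project, run_id, f"model_step_{real_step}")
print(old_path)  # /checkpoints/run/2024-08-06/open_diloco/abc123/model_step_2000

# After: checkpoints sit directly under checkpoint_path, and the world-messenger
# save adds a per-DiLoCo-rank sub-directory.
new_path = os.path.join(checkpoint_path, f"model_step_{real_step}", f"diloco_rank_{world_rank}")
print(new_path)  # /checkpoints/run/model_step_2000/diloco_rank_3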
