From 299852a93125f9b0a49e9170cd2fd2a345fb3896 Mon Sep 17 00:00:00 2001
From: sanchit-gandhi
Date: Thu, 30 May 2024 11:20:22 +0100
Subject: [PATCH] remove checkpoints from push

---
 training/run_parler_tts_training.py | 12 +++++++++---
 training/utils.py                   |  6 ++++--
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/training/run_parler_tts_training.py b/training/run_parler_tts_training.py
index 22e091f..ed85fba 100644
--- a/training/run_parler_tts_training.py
+++ b/training/run_parler_tts_training.py
@@ -680,6 +680,8 @@ def compute_metrics(audios, descriptions, prompts, device="cpu"):
         checkpoint = last_checkpoint
 
     if accelerator.is_main_process:
+        if training_args.output_dir is not None:
+            os.makedirs(training_args.output_dir, exist_ok=True)
         if training_args.push_to_hub:
             api = HfApi(token=training_args.hub_token)
 
@@ -692,8 +694,6 @@ def compute_metrics(audios, descriptions, prompts, device="cpu"):
             with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
                 if "wandb" not in gitignore:
                     gitignore.write("wandb\n")
-        elif training_args.output_dir is not None:
-            os.makedirs(training_args.output_dir, exist_ok=True)
     accelerator.wait_for_everyone()
 
     # Now save everything to be able to create a single processor later
@@ -755,6 +755,9 @@ def compute_metrics(audios, descriptions, prompts, device="cpu"):
         # This fix the issue.
         "min_new_tokens": num_codebooks + 1,
     }
+    generation_config = model.generation_config
+    for key in gen_kwargs:
+        setattr(generation_config, key, gen_kwargs[key])
 
     # Define gradient update step fn
     def train_step(
@@ -883,9 +886,11 @@ def generate_step(batch):
                     # safe_serialization=False to avoid shared tensors saving issue (TODO(YL): it's a temporary fix)
                     # https://github.com/huggingface/transformers/issues/27293#issuecomment-1872560074
                     accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
+                    config.save_pretrained(intermediate_dir)
+                    generation_config.save_pretrained(intermediate_dir)
                     accelerator.wait_for_everyone()
                     if accelerator.is_main_process:
-                        rotate_checkpoints(
+                        checkpoints_to_be_deleted = rotate_checkpoints(
                             training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger
                         )
 
@@ -900,6 +905,7 @@ def generate_step(batch):
                                 folder_path=training_args.output_dir,
                                 commit_message=f"Saving train state of step {cur_step}",
                                 run_as_future=True,
+                                delete_patterns=checkpoints_to_be_deleted,
                             )
 
                 if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
diff --git a/training/utils.py b/training/utils.py
index 2328575..1415ad8 100644
--- a/training/utils.py
+++ b/training/utils.py
@@ -3,7 +3,7 @@
 import shutil
 from pathlib import Path
 from dataclasses import field
-from typing import Dict, List
+from typing import Dict, List, Union
 
 import torch
 from wandb import Audio
@@ -44,7 +44,7 @@ def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint") -> List[
     return checkpoints_sorted
 
 
-def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix="checkpoint", logger=None) -> None:
+def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix="checkpoint", logger=None) -> Union[List, None]:
     """Helper function to delete old checkpoints."""
     if save_total_limit is None or save_total_limit <= 0:
         return
@@ -58,6 +58,8 @@ def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix
     for checkpoint in checkpoints_to_be_deleted:
         logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
         shutil.rmtree(checkpoint, ignore_errors=True)
+    checkpoints_to_be_deleted = [f"*{Path(checkpoint).absolute().name}*" for checkpoint in checkpoints_to_be_deleted]
+    return checkpoints_to_be_deleted
 
 
 def log_metric(
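
Note on the delete_patterns plumbing introduced above: rotate_checkpoints now returns the
locally deleted checkpoint directories as glob patterns (e.g. "*checkpoint-500*"), and the
training loop forwards them to HfApi.upload_folder, so checkpoints rotated out on disk are
also removed from the Hub repository in the same push commit. Below is a minimal,
standalone sketch of that flow; the repo id, output directory, pattern and step number are
hypothetical placeholders for illustration, not values taken from the patch.

    from huggingface_hub import HfApi

    api = HfApi()  # the training script constructs this with token=training_args.hub_token

    # What rotate_checkpoints() returns after deleting stale local checkpoints,
    # e.g. when save_total_limit keeps only the most recent checkpoint (hypothetical value).
    checkpoints_to_be_deleted = ["*checkpoint-500*"]

    # upload_folder() pushes the output directory and, via delete_patterns, deletes any
    # remote files matching the rotated-out checkpoint patterns in the same commit.
    api.upload_folder(
        repo_id="username/parler-tts-finetuned",  # hypothetical Hub repo
        folder_path="./output",                   # hypothetical training output_dir
        commit_message="Saving train state of step 1000",
        run_as_future=True,
        delete_patterns=checkpoints_to_be_deleted,
    )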