diff --git a/muse_maskgit_pytorch/trainers/vqvae_trainers.py b/muse_maskgit_pytorch/trainers/vqvae_trainers.py index ae0bc89..8d9d916 100644 --- a/muse_maskgit_pytorch/trainers/vqvae_trainers.py +++ b/muse_maskgit_pytorch/trainers/vqvae_trainers.py @@ -402,7 +402,7 @@ def train(self): if self.validation_folder_at_end_of_epoch: vae_folder_validation(self.accelerator, self.model, self.validation_folder_at_end_of_epoch, self.args, - checkpoint_name=os.path.join(self.results_dir, f'vae.{steps}.pt'), + checkpoint_name=os.path.join(self.results_dir, f'vae.{steps}_E{epoch + 1}.pt'), ) diff --git a/muse_maskgit_pytorch/utils.py b/muse_maskgit_pytorch/utils.py index 3b96b1a..928348f 100644 --- a/muse_maskgit_pytorch/utils.py +++ b/muse_maskgit_pytorch/utils.py @@ -147,7 +147,7 @@ def remove_duplicate_weights(ema_state_dict, non_ema_state_dict): del ema_state_dict_copy[key] return ema_state_dict_copy -def vae_folder_validation(accelerator, vae, dataset, args=None, checkpoint_name="vae", save_originals=False): +def vae_folder_validation(accelerator, vae, dataset, args=None, checkpoint_name="vae", save_originals=False, epoch=None): # Create output directory and save input images and reconstructions as grids output_dir = os.path.join(args.results_dir, "outputs", @@ -202,7 +202,7 @@ def vae_folder_validation(accelerator, vae, dataset, args=None, checkpoint_name= now = datetime.now().strftime("%m-%d-%Y_%H-%M-%S") hash = hashlib.sha1(input_image.tobytes()).hexdigest() - filename = f"{hash}_{now}-{os.path.basename(checkpoint_name)}.png" + filename = f"{hash}_{now}{'-' + epoch if epoch else ''}-{os.path.basename(checkpoint_name)}.png" grid_image.save(f"{output_dir}/{filename}", format="PNG") if not save_originals: