Skip to content

Commit

Permalink
Small QOL when using the validation_folder_at_end_of_epoch argument. …
Browse files Browse the repository at this point in the history
…We now append the epoch number to the filename so we can identify at what epoch the image was saved. (Sygil-Dev#74)
  • Loading branch information
ZeroCool940711 committed Sep 27, 2023
2 parents c219b44 + f670611 commit bb96aad
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion muse_maskgit_pytorch/trainers/vqvae_trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def train(self):
if self.validation_folder_at_end_of_epoch:
vae_folder_validation(self.accelerator, self.model, self.validation_folder_at_end_of_epoch,
self.args,
checkpoint_name=os.path.join(self.results_dir, f'vae.{steps}.pt'),
checkpoint_name=os.path.join(self.results_dir, f'vae.{steps}_E{epoch + 1}.pt'),

)

Expand Down
4 changes: 2 additions & 2 deletions muse_maskgit_pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def remove_duplicate_weights(ema_state_dict, non_ema_state_dict):
del ema_state_dict_copy[key]
return ema_state_dict_copy

def vae_folder_validation(accelerator, vae, dataset, args=None, checkpoint_name="vae", save_originals=False):
def vae_folder_validation(accelerator, vae, dataset, args=None, checkpoint_name="vae", save_originals=False, epoch=None):

# Create output directory and save input images and reconstructions as grids
output_dir = os.path.join(args.results_dir, "outputs",
Expand Down Expand Up @@ -202,7 +202,7 @@ def vae_folder_validation(accelerator, vae, dataset, args=None, checkpoint_name=
now = datetime.now().strftime("%m-%d-%Y_%H-%M-%S")
hash = hashlib.sha1(input_image.tobytes()).hexdigest()

filename = f"{hash}_{now}-{os.path.basename(checkpoint_name)}.png"
filename = f"{hash}_{now}{'-' + epoch if epoch else ''}-{os.path.basename(checkpoint_name)}.png"
grid_image.save(f"{output_dir}/{filename}", format="PNG")

if not save_originals:
Expand Down

0 comments on commit bb96aad

Please sign in to comment.