Merge pull request Sygil-Dev#67 from ZeroCool940711/dev
Changes all `print` statements to use `accelerator.print` for compatibility with the latest accelerate library and Colab.
ZeroCool940711 committed Sep 3, 2023
2 parents d78a914 + 3842bfb, commit c61f185
Showing 2 changed files with 19 additions and 19 deletions.
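
Background for the change: under `accelerate launch`, a plain `print` runs on every process, so each log line is repeated once per worker on multi-GPU or TPU (Colab) runs, while `accelerator.print` forwards to `print` only on the main process. A minimal sketch of the pattern the commit adopts (hypothetical standalone example, not a file from this repository):

```python
# Minimal sketch of the pattern this commit adopts (hypothetical example,
# not a file from this repository).
from accelerate import Accelerator

accelerator = Accelerator()

# Plain print() would run on every process, repeating the message once per
# worker; accelerator.print() forwards to print() only on the main process.
accelerator.print(f"Running on {accelerator.num_processes} process(es)")
```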
train_muse_maskgit.py: 24 changes (12 additions, 12 deletions)

@@ -590,10 +590,10 @@ def main():
     # Load the VAE
     with accelerator.main_process_first():
         if args.vae_path:
-            print("Loading Muse VQGanVAE")
+            accelerator.print("Loading Muse VQGanVAE")
 
             if args.latest_checkpoint:
-                print("Finding latest VAE checkpoint...")
+                accelerator.print("Finding latest VAE checkpoint...")
                 orig_vae_path = args.vae_path
 
                 if os.path.isfile(args.vae_path) or ".pt" in args.vae_path:
@@ -610,7 +610,7 @@ def main():
                     if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(
                         latest_checkpoint_file, os.R_OK
                     ):
-                        print(
+                        accelerator.print(
                             f"Warning: latest VAE checkpoint {latest_checkpoint_file} is empty or unreadable."
                         )
                         if len(checkpoint_files) > 1:
@@ -619,19 +619,19 @@ def main():
                                 checkpoint_files[:-1],
                                 key=lambda x: int(re.search(r"vae\.(\d+)\.pt", x).group(1)),
                             )
-                            print("Using second last VAE checkpoint: ", latest_checkpoint_file)
+                            accelerator.print("Using second last VAE checkpoint: ", latest_checkpoint_file)
                         else:
-                            print("No usable checkpoint found.")
+                            accelerator.print("No usable checkpoint found.")
                     elif latest_checkpoint_file != orig_vae_path:
-                        print("Resuming VAE from latest checkpoint: ", latest_checkpoint_file)
+                        accelerator.print("Resuming VAE from latest checkpoint: ", latest_checkpoint_file)
                     else:
-                        print("Using VAE checkpoint specified in vae_path: ", orig_vae_path)
+                        accelerator.print("Using VAE checkpoint specified in vae_path: ", orig_vae_path)
 
                     args.vae_path = latest_checkpoint_file
                 else:
-                    print("No VAE checkpoints found in directory: ", args.vae_path)
+                    accelerator.print("No VAE checkpoints found in directory: ", args.vae_path)
             else:
-                print("Resuming VAE from: ", args.vae_path)
+                accelerator.print("Resuming VAE from: ", args.vae_path)
 
             # use config next to checkpoint if there is one and merge the cli arguments to it
             # the cli arguments will take priority so we can use it to override any value we want.
@@ -654,7 +654,7 @@ def main():
             vae.load(args.vae_path)
 
         elif args.taming_model_path is not None and args.taming_config_path is not None:
-            print(f"Using Taming VQGanVAE, loading from {args.taming_model_path}")
+            accelerator.print(f"Using Taming VQGanVAE, loading from {args.taming_model_path}")
             vae = VQGanVAETaming(
                 vqgan_model_path=args.taming_model_path,
                 vqgan_config_path=args.taming_config_path,
@@ -804,7 +804,7 @@ def main():
     total_params = sum(p.numel() for p in maskgit.parameters())
     args.total_params = total_params
 
-    print(f"Total number of parameters: {format(total_params, ',d')}")
+    accelerator.print(f"Total number of parameters: {format(total_params, ',d')}")
 
     # Create the dataset objects
     with accelerator.main_process_first():
@@ -893,7 +893,7 @@ def main():
     is_local_main = accelerator.is_local_main_process
 
     with accelerator.local_main_process_first():
-        print(
+        accelerator.print(
             f"[P{proc_idx:03d}]: PID {proc_idx}/{n_procs}, local #{local_proc_idx}, ",
             f"XLA ord={xm_ord}/{xm_world}, local={xm_local_ord}, master={xm_master_ord} ",
             f"Accelerate: main={is_main}, local main={is_local_main} ",
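
The hunks above only swap the logging calls, but the surrounding checkpoint-selection logic is easy to lose track of across fragments. Below is a rough standalone sketch of that behaviour, assuming the `vae.N.pt` naming visible in the diff; the helper name is made up, and the loop generalises the diff's "use second last checkpoint" fallback into walking back until a readable, non-empty file is found:

```python
import glob
import os
import re
from typing import Optional


def find_usable_vae_checkpoint(vae_dir: str) -> Optional[str]:
    """Illustrative sketch: pick the newest readable, non-empty vae.N.pt file."""

    def step_of(path: str) -> int:
        # Same numbering scheme as the re.search(r"vae\.(\d+)\.pt", ...) in the diff.
        return int(re.search(r"vae\.(\d+)\.pt", path).group(1))

    checkpoints = sorted(glob.glob(os.path.join(vae_dir, "vae.*.pt")), key=step_of)

    # Walk from newest to oldest, mirroring the os.path.getsize / os.access
    # guards against empty or unreadable checkpoints in the diff.
    for path in reversed(checkpoints):
        if os.path.getsize(path) > 0 and os.access(path, os.R_OK):
            return path
    return None
```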
train_muse_vae.py: 14 changes (7 additions, 7 deletions)

@@ -392,7 +392,7 @@ def main():
     args = parser.parse_args(namespace=Arguments())
 
     if args.config_path:
-        print("Using config file and ignoring CLI args")
+        accelerator.print("Using config file and ignoring CLI args")
 
         try:
             conf = OmegaConf.load(args.config_path)
@@ -403,10 +403,10 @@ def main():
                 try:
                     args_to_convert[key] = conf[key]
                 except KeyError:
-                    print(f"Error parsing config - {key}: {conf[key]} | Using default or parsed")
+                    accelerator.print(f"Error parsing config - {key}: {conf[key]} | Using default or parsed")
 
         except FileNotFoundError:
-            print("Could not find config, using default and parsed values...")
+            accelerator.print("Could not find config, using default and parsed values...")
 
     project_config = ProjectConfiguration(
         project_dir=args.logging_dir if args.logging_dir else os.path.join(args.results_dir, "logs"),
@@ -457,7 +457,7 @@ def main():
         ]
         if args.streaming:
             if dataset.info.dataset_size is None:
-                print("Dataset doesn't support streaming, disabling streaming")
+                accelerator.print("Dataset doesn't support streaming, disabling streaming")
                 args.streaming = False
                 if args.cache_path:
                     dataset = load_dataset(args.dataset_name, cache_dir=args.cache_path)[args.hf_split_name]
@@ -536,7 +536,7 @@ def main():
             current_step = 0
 
         elif args.taming_model_path is not None and args.taming_config_path is not None:
-            print(f"Using Taming VQGanVAE, loading from {args.taming_model_path}")
+            accelerator.print(f"Using Taming VQGanVAE, loading from {args.taming_model_path}")
             vae = VQGanVAETaming(
                 vqgan_model_path=args.taming_model_path,
                 vqgan_config_path=args.taming_config_path,
@@ -547,7 +547,7 @@ def main():
 
             current_step = 0
         else:
-            print("Initialising empty VAE")
+            accelerator.print("Initialising empty VAE")
             vae = VQGanVAE(
                 dim=args.dim,
                 vq_codebook_dim=args.vq_codebook_dim,
@@ -564,7 +564,7 @@ def main():
     total_params = sum(p.numel() for p in vae.parameters())
     args.total_params = total_params
 
-    print(f"Total number of parameters: {format(total_params, ',d')}")
+    accelerator.print(f"Total number of parameters: {format(total_params, ',d')}")
 
     dataset = ImageDataset(
         dataset,
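
The first two train_muse_vae.py hunks sit inside the config-override path: values from an OmegaConf YAML file overwrite the parsed CLI arguments, and anything missing from the file keeps its argparse value. A condensed sketch of that flow, assuming only the behaviour visible in the hunks; the function name and log messages are illustrative, a membership check replaces the diff's try/except, and `accelerator.print` can be passed as `log` to match the commit's logging change:

```python
from omegaconf import OmegaConf


def apply_config_overrides(args, config_path, log=print):
    """Illustrative sketch: overlay OmegaConf values onto parsed CLI args."""
    try:
        conf = OmegaConf.load(config_path)
    except FileNotFoundError:
        log("Could not find config, using default and parsed values...")
        return args

    for key in vars(args):
        if key in conf:
            # Config value takes priority over the CLI default.
            setattr(args, key, conf[key])
        else:
            # Key absent from the YAML file; keep the argparse value.
            log(f"No config value for {key}, keeping default or parsed value")
    return args
```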
