Skip to content

Commit

Permalink
remove deprecated averaged_model
Browse files Browse the repository at this point in the history
  • Loading branch information
teticio committed Feb 3, 2023
1 parent d8c6a4b commit 7bd9ee1
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion audiodiffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# from diffusers import AudioDiffusionPipeline
from .pipeline_audio_diffusion import AudioDiffusionPipeline

VERSION = "1.5.0"
VERSION = "1.5.1"


class AudioDiffusion:
Expand Down
6 changes: 4 additions & 2 deletions scripts/train_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,12 @@ def transforms(examples):
if ((epoch + 1) % args.save_model_epochs == 0
or (epoch + 1) % args.save_images_epochs == 0
or epoch == args.num_epochs - 1):
unet = accelerator.unwrap_model(model)
if args.use_ema:
ema_model.copy_to(unet.parameters())
pipeline = AudioDiffusionPipeline(
vqvae=vqvae,
unet=accelerator.unwrap_model(
ema_model.averaged_model if args.use_ema else model),
unet=unet,
mel=mel,
scheduler=noise_scheduler,
)
Expand Down

0 comments on commit 7bd9ee1

Please sign in to comment.