teticio commited on
Commit
7bd9ee1
1 Parent(s): d8c6a4b

remove deprecated averaged_model

Browse files
audiodiffusion/__init__.py CHANGED
@@ -9,7 +9,7 @@ from tqdm.auto import tqdm
9
  # from diffusers import AudioDiffusionPipeline
10
  from .pipeline_audio_diffusion import AudioDiffusionPipeline
11
 
12
- VERSION = "1.5.0"
13
 
14
 
15
  class AudioDiffusion:
 
9
  # from diffusers import AudioDiffusionPipeline
10
  from .pipeline_audio_diffusion import AudioDiffusionPipeline
11
 
12
+ VERSION = "1.5.1"
13
 
14
 
15
  class AudioDiffusion:
scripts/train_unet.py CHANGED
@@ -304,10 +304,12 @@ def main(args):
304
  if ((epoch + 1) % args.save_model_epochs == 0
305
  or (epoch + 1) % args.save_images_epochs == 0
306
  or epoch == args.num_epochs - 1):
 
 
 
307
  pipeline = AudioDiffusionPipeline(
308
  vqvae=vqvae,
309
- unet=accelerator.unwrap_model(
310
- ema_model.averaged_model if args.use_ema else model),
311
  mel=mel,
312
  scheduler=noise_scheduler,
313
  )
 
304
  if ((epoch + 1) % args.save_model_epochs == 0
305
  or (epoch + 1) % args.save_images_epochs == 0
306
  or epoch == args.num_epochs - 1):
307
+ unet = accelerator.unwrap_model(model)
308
+ if args.use_ema:
309
+ ema_model.copy_to(unet.parameters())
310
  pipeline = AudioDiffusionPipeline(
311
  vqvae=vqvae,
312
+ unet=unet,
 
313
  mel=mel,
314
  scheduler=noise_scheduler,
315
  )