fffiloni commited on
Commit
c6e7b6c
·
verified ·
1 Parent(s): 21cbd99

update LatentNoiseTrainer class to include iteration callbacks

Browse files
Files changed (1) hide show
  1. training/trainer.py +4 -1
training/trainer.py CHANGED
@@ -51,6 +51,7 @@ class LatentNoiseTrainer:
51
  prompt: str,
52
  optimizer: torch.optim.Optimizer,
53
  save_dir: Optional[str] = None,
 
54
  ) -> Tuple[PIL.Image.Image, Dict[str, float], Dict[str, float]]:
55
  logging.info(f"Optimizing latents for prompt '{prompt}'.")
56
  best_loss = torch.inf
@@ -120,6 +121,8 @@ class LatentNoiseTrainer:
120
  image_numpy = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
121
  image_pil = DiffusionPipeline.numpy_to_pil(image_numpy)[0]
122
  image_pil.save(f"{save_dir}/{iteration}.png")
 
 
123
  image_numpy = best_image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
124
  image_pil = DiffusionPipeline.numpy_to_pil(image_numpy)[0]
125
- return image_pil, initial_rewards, best_rewards
 
51
  prompt: str,
52
  optimizer: torch.optim.Optimizer,
53
  save_dir: Optional[str] = None,
54
+ progress_callback=None,
55
  ) -> Tuple[PIL.Image.Image, Dict[str, float], Dict[str, float]]:
56
  logging.info(f"Optimizing latents for prompt '{prompt}'.")
57
  best_loss = torch.inf
 
121
  image_numpy = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
122
  image_pil = DiffusionPipeline.numpy_to_pil(image_numpy)[0]
123
  image_pil.save(f"{save_dir}/{iteration}.png")
124
+ if progress_callback:
125
+ progress_callback(iteration + 1)
126
  image_numpy = best_image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
127
  image_pil = DiffusionPipeline.numpy_to_pil(image_numpy)[0]
128
+ return image_pil, initial_rewards, best_rewards