ironjr commited on
Commit
1fceb42
1 Parent(s): 749679c

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +44 -0
model.py CHANGED
@@ -741,6 +741,50 @@ class StreamMultiDiffusion(nn.Module):
741
  self.ready_checklist['layers_ready'] = True
742
  self.ready_checklist['flushed'] = False
743
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
744
  @torch.no_grad()
745
  def update_single_layer(
746
  self,
 
741
  self.ready_checklist['layers_ready'] = True
742
  self.ready_checklist['flushed'] = False
743
 
744
+ @torch.no_grad()
745
+ def update_masks(
746
+ self,
747
+ masks: Optional[Union[torch.Tensor, Image.Image, List[Image.Image]]] = None,
748
+ mask_strengths: Optional[Union[torch.Tensor, float, List[float]]] = None,
749
+ mask_stds: Optional[Union[torch.Tensor, float, List[float]]] = None,
750
+ ) -> None:
751
+ if not self.ready_checklist['background_registered']:
752
+ print('[WARNING] Register background image first! Request ignored.')
753
+ return
754
+
755
+ ### Register new masks
756
+
757
+ if isinstance(masks, Image.Image):
758
+ masks = [masks]
759
+ n = len(masks) if masks is not None else 0
760
+
761
+ # Modificiation.
762
+ masks, mask_strengths, mask_stds, original_masks = self.process_mask(masks, mask_strengths, mask_stds)
763
+
764
+ self.counts = masks.sum(dim=0) # (T, 1, h, w)
765
+ self.bg_mask = (1 - self.counts).clip_(0, 1) # (T, 1, h, w)
766
+ self.masks = masks # (p, T, 1, h, w)
767
+ self.mask_strengths = mask_strengths # (p,)
768
+ self.mask_stds = mask_stds # (p,)
769
+ self.original_masks = original_masks # (p, 1, h, w)
770
+
771
+ if p > n:
772
+ # Add more masks: counts and bg_masks are not changed, but only masks are changed.
773
+ self.masks = torch.cat((
774
+ self.masks,
775
+ torch.zeros(
776
+ (p - n, self.batch_size, 1, self.latent_height, self.latent_width),
777
+ dtype=self.dtype,
778
+ device=self.device,
779
+ ),
780
+ ), dim=0)
781
+ print(f'[WARNING] Detected more prompts ({p}) than masks ({n}). '
782
+ 'Automatically adds blank masks for the additional prompts.')
783
+ elif p < n:
784
+ # Warns user to add more prompts.
785
+ print(f'[WARNING] Detected more masks ({n}) than prompts ({p}). '
786
+ 'Additional masks are ignored until more prompts are provided.')
787
+
788
  @torch.no_grad()
789
  def update_single_layer(
790
  self,