Spaces:
Runtime error
Runtime error
Update model.py
Browse files
model.py
CHANGED
@@ -741,6 +741,50 @@ class StreamMultiDiffusion(nn.Module):
|
|
741 |
self.ready_checklist['layers_ready'] = True
|
742 |
self.ready_checklist['flushed'] = False
|
743 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
744 |
@torch.no_grad()
|
745 |
def update_single_layer(
|
746 |
self,
|
|
|
741 |
self.ready_checklist['layers_ready'] = True
|
742 |
self.ready_checklist['flushed'] = False
|
743 |
|
744 |
+
@torch.no_grad()
|
745 |
+
def update_masks(
|
746 |
+
self,
|
747 |
+
masks: Optional[Union[torch.Tensor, Image.Image, List[Image.Image]]] = None,
|
748 |
+
mask_strengths: Optional[Union[torch.Tensor, float, List[float]]] = None,
|
749 |
+
mask_stds: Optional[Union[torch.Tensor, float, List[float]]] = None,
|
750 |
+
) -> None:
|
751 |
+
if not self.ready_checklist['background_registered']:
|
752 |
+
print('[WARNING] Register background image first! Request ignored.')
|
753 |
+
return
|
754 |
+
|
755 |
+
### Register new masks
|
756 |
+
|
757 |
+
if isinstance(masks, Image.Image):
|
758 |
+
masks = [masks]
|
759 |
+
n = len(masks) if masks is not None else 0
|
760 |
+
|
761 |
+
# Modificiation.
|
762 |
+
masks, mask_strengths, mask_stds, original_masks = self.process_mask(masks, mask_strengths, mask_stds)
|
763 |
+
|
764 |
+
self.counts = masks.sum(dim=0) # (T, 1, h, w)
|
765 |
+
self.bg_mask = (1 - self.counts).clip_(0, 1) # (T, 1, h, w)
|
766 |
+
self.masks = masks # (p, T, 1, h, w)
|
767 |
+
self.mask_strengths = mask_strengths # (p,)
|
768 |
+
self.mask_stds = mask_stds # (p,)
|
769 |
+
self.original_masks = original_masks # (p, 1, h, w)
|
770 |
+
|
771 |
+
if p > n:
|
772 |
+
# Add more masks: counts and bg_masks are not changed, but only masks are changed.
|
773 |
+
self.masks = torch.cat((
|
774 |
+
self.masks,
|
775 |
+
torch.zeros(
|
776 |
+
(p - n, self.batch_size, 1, self.latent_height, self.latent_width),
|
777 |
+
dtype=self.dtype,
|
778 |
+
device=self.device,
|
779 |
+
),
|
780 |
+
), dim=0)
|
781 |
+
print(f'[WARNING] Detected more prompts ({p}) than masks ({n}). '
|
782 |
+
'Automatically adds blank masks for the additional prompts.')
|
783 |
+
elif p < n:
|
784 |
+
# Warns user to add more prompts.
|
785 |
+
print(f'[WARNING] Detected more masks ({n}) than prompts ({p}). '
|
786 |
+
'Additional masks are ignored until more prompts are provided.')
|
787 |
+
|
788 |
@torch.no_grad()
|
789 |
def update_single_layer(
|
790 |
self,
|