Linoy Tsaban commited on
Commit
0052810
1 Parent(s): b7b2a49

Update pipeline_semantic_stable_diffusion_img2img_solver.py

Browse files

merging Manuel's updates - removing edit_momentum and adjustments to attention store

pipeline_semantic_stable_diffusion_img2img_solver.py CHANGED
@@ -36,21 +36,19 @@ class AttentionStore():
36
 
37
  def __call__(self, attn, is_cross: bool, place_in_unet: str, editing_prompts, PnP=False):
38
  # attn.shape = batch_size * head_size, seq_len query, seq_len_key
39
- bs = 2 + int(PnP) + editing_prompts
40
- skip = 2 if PnP else 1 # skip PnP & unconditional
41
-
42
- head_size = int(attn.shape[0] / self.batch_size)
43
- attn = torch.stack(attn.split(self.batch_size)).permute(1, 0, 2, 3)
44
- source_batch_size = int(attn.shape[1] // bs)
45
- self.forward(
46
- attn[:, skip * source_batch_size:],
47
- is_cross,
48
- place_in_unet)
49
 
50
  def forward(self, attn, is_cross: bool, place_in_unet: str):
51
  key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
52
- if attn.shape[1] <= 32 ** 2: # avoid memory overhead
53
- self.step_store[key].append(attn)
54
 
55
  def between_steps(self, store_step=True):
56
  if store_step:
@@ -96,12 +94,13 @@ class AttentionStore():
96
  out = out.sum(1) / out.shape[1]
97
  return out
98
 
99
- def __init__(self, average: bool, batch_size=1):
100
  self.step_store = self.get_empty_store()
101
  self.attention_store = []
102
  self.cur_step = 0
103
  self.average = average
104
  self.batch_size = batch_size
 
105
 
106
 
107
  class CrossAttnProcessor:
@@ -433,10 +432,10 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
433
 
434
  # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
435
  def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, latents):
436
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
437
 
438
- if latents.shape != shape:
439
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
440
 
441
  latents = latents.to(device)
442
 
@@ -456,7 +455,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
456
  else:
457
  continue
458
 
459
- if "attn2" in name:
460
  attn_procs[name] = CrossAttnProcessor(
461
  attention_store=attention_store,
462
  place_in_unet=place_in_unet,
@@ -488,12 +487,11 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
488
  editing_prompt_embeddings: Optional[torch.Tensor] = None,
489
  reverse_editing_direction: Optional[Union[bool, List[bool]]] = False,
490
  edit_guidance_scale: Optional[Union[float, List[float]]] = 5,
491
- edit_warmup_steps: Optional[Union[int, List[int]]] = 10,
492
  edit_cooldown_steps: Optional[Union[int, List[int]]] = None,
493
  edit_threshold: Optional[Union[float, List[float]]] = 0.9,
494
  user_mask: Optional[torch.FloatTensor] = None,
495
- edit_momentum_scale: Optional[float] = 0.1,
496
- edit_mom_beta: Optional[float] = 0.4,
497
  edit_weights: Optional[List[float]] = None,
498
  sem_guidance: Optional[List[torch.Tensor]] = None,
499
  verbose=True,
@@ -788,8 +786,6 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
788
  # 6. Prepare extra step kwargs.
789
  extra_step_kwargs = self.prepare_extra_step_kwargs(eta)
790
 
791
- # Initialize edit_momentum to None
792
- edit_momentum = None
793
 
794
  self.uncond_estimates = None
795
  self.text_estimates = None
@@ -833,12 +829,10 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
833
  self.text_estimates = torch.zeros((len(timesteps), *noise_pred_text.shape))
834
  self.text_estimates[i] = noise_pred_text.detach().cpu()
835
 
836
- if edit_momentum is None:
837
- edit_momentum = torch.zeros_like(noise_guidance)
838
 
839
  if sem_guidance is not None and len(sem_guidance) > i:
840
  edit_guidance = sem_guidance[i].to(self.device)
841
- edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * edit_guidance
842
  noise_guidance = noise_guidance + edit_guidance
843
 
844
  elif enable_edit_guidance:
 
36
 
37
  def __call__(self, attn, is_cross: bool, place_in_unet: str, editing_prompts, PnP=False):
38
  # attn.shape = batch_size * head_size, seq_len query, seq_len_key
39
+ if attn.shape[1] <= self.max_size:
40
+ bs = 1 + int(PnP) + editing_prompts
41
+ skip = 2 if PnP else 1 # skip PnP & unconditional
42
+ attn = torch.stack(attn.split(self.batch_size)).permute(1, 0, 2, 3)
43
+ source_batch_size = int(attn.shape[1] // bs)
44
+ self.forward(
45
+ attn[:, skip * source_batch_size:],
46
+ is_cross,
47
+ place_in_unet)
 
48
 
49
  def forward(self, attn, is_cross: bool, place_in_unet: str):
50
  key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
51
+ self.step_store[key].append(attn)
 
52
 
53
  def between_steps(self, store_step=True):
54
  if store_step:
 
94
  out = out.sum(1) / out.shape[1]
95
  return out
96
 
97
+ def __init__(self, average: bool, batch_size=1, max_resolution=16):
98
  self.step_store = self.get_empty_store()
99
  self.attention_store = []
100
  self.cur_step = 0
101
  self.average = average
102
  self.batch_size = batch_size
103
+ self.max_size = max_resolution ** 2
104
 
105
 
106
  class CrossAttnProcessor:
 
432
 
433
  # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
434
  def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, latents):
435
+ # shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
436
 
437
+ # if latents.shape != shape:
438
+ # raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
439
 
440
  latents = latents.to(device)
441
 
 
455
  else:
456
  continue
457
 
458
+ if "attn2" in name and place_in_unet != 'mid':
459
  attn_procs[name] = CrossAttnProcessor(
460
  attention_store=attention_store,
461
  place_in_unet=place_in_unet,
 
487
  editing_prompt_embeddings: Optional[torch.Tensor] = None,
488
  reverse_editing_direction: Optional[Union[bool, List[bool]]] = False,
489
  edit_guidance_scale: Optional[Union[float, List[float]]] = 5,
490
+ edit_warmup_steps: Optional[Union[int, List[int]]] = 0,
491
  edit_cooldown_steps: Optional[Union[int, List[int]]] = None,
492
  edit_threshold: Optional[Union[float, List[float]]] = 0.9,
493
  user_mask: Optional[torch.FloatTensor] = None,
494
+
 
495
  edit_weights: Optional[List[float]] = None,
496
  sem_guidance: Optional[List[torch.Tensor]] = None,
497
  verbose=True,
 
786
  # 6. Prepare extra step kwargs.
787
  extra_step_kwargs = self.prepare_extra_step_kwargs(eta)
788
 
 
 
789
 
790
  self.uncond_estimates = None
791
  self.text_estimates = None
 
829
  self.text_estimates = torch.zeros((len(timesteps), *noise_pred_text.shape))
830
  self.text_estimates[i] = noise_pred_text.detach().cpu()
831
 
832
+
 
833
 
834
  if sem_guidance is not None and len(sem_guidance) > i:
835
  edit_guidance = sem_guidance[i].to(self.device)
 
836
  noise_guidance = noise_guidance + edit_guidance
837
 
838
  elif enable_edit_guidance: