Shitao commited on
Commit
6209fe3
1 Parent(s): d765142

bugfix model.py

Browse files
Files changed (1) hide show
  1. OmniGen/model.py +3 -3
OmniGen/model.py CHANGED
@@ -347,7 +347,7 @@ class OmniGen(nn.Module, PeftAdapterMixin):
347
  x = self.final_layer(image_embedding, time_emb)
348
  latents = self.unpatchify(x, shapes[0], shapes[1])
349
 
350
- if past_key_values:
351
  return latents, past_key_values
352
  return latents
353
 
@@ -357,7 +357,7 @@ class OmniGen(nn.Module, PeftAdapterMixin):
357
  Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
358
  """
359
  self.llm.config.use_cache = use_kv_cache
360
- model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values)
361
  if use_img_cfg:
362
  cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
363
  cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
@@ -371,7 +371,7 @@ class OmniGen(nn.Module, PeftAdapterMixin):
371
 
372
 
373
  @torch.no_grad()
374
- def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache):
375
  """
376
  Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
377
  """
 
347
  x = self.final_layer(image_embedding, time_emb)
348
  latents = self.unpatchify(x, shapes[0], shapes[1])
349
 
350
+ if return_past_key_values:
351
  return latents, past_key_values
352
  return latents
353
 
 
357
  Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
358
  """
359
  self.llm.config.use_cache = use_kv_cache
360
+ model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values, return_past_key_values=True)
361
  if use_img_cfg:
362
  cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
363
  cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
 
371
 
372
 
373
  @torch.no_grad()
374
+ def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, return_past_key_values=True):
375
  """
376
  Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
377
  """