bugfix model.py
Browse files- OmniGen/model.py +3 -3
OmniGen/model.py
CHANGED
@@ -347,7 +347,7 @@ class OmniGen(nn.Module, PeftAdapterMixin):
|
|
347 |
x = self.final_layer(image_embedding, time_emb)
|
348 |
latents = self.unpatchify(x, shapes[0], shapes[1])
|
349 |
|
350 |
-
if
|
351 |
return latents, past_key_values
|
352 |
return latents
|
353 |
|
@@ -357,7 +357,7 @@ class OmniGen(nn.Module, PeftAdapterMixin):
|
|
357 |
Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
|
358 |
"""
|
359 |
self.llm.config.use_cache = use_kv_cache
|
360 |
-
model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values)
|
361 |
if use_img_cfg:
|
362 |
cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
|
363 |
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
|
@@ -371,7 +371,7 @@ class OmniGen(nn.Module, PeftAdapterMixin):
|
|
371 |
|
372 |
|
373 |
@torch.no_grad()
|
374 |
-
def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache):
|
375 |
"""
|
376 |
Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
|
377 |
"""
|
|
|
347 |
x = self.final_layer(image_embedding, time_emb)
|
348 |
latents = self.unpatchify(x, shapes[0], shapes[1])
|
349 |
|
350 |
+
if return_past_key_values:
|
351 |
return latents, past_key_values
|
352 |
return latents
|
353 |
|
|
|
357 |
Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
|
358 |
"""
|
359 |
self.llm.config.use_cache = use_kv_cache
|
360 |
+
model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values, return_past_key_values=True)
|
361 |
if use_img_cfg:
|
362 |
cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
|
363 |
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
|
|
|
371 |
|
372 |
|
373 |
@torch.no_grad()
|
374 |
+
def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, return_past_key_values=True):
|
375 |
"""
|
376 |
Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
|
377 |
"""
|