jimmycarter
committed on
Commit
•
aa06741
1
Parent(s):
a23a8bc
Upload pipeline.py
Browse files - pipeline.py +26 -26
pipeline.py
CHANGED
@@ -1611,18 +1611,33 @@ class CustomPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
|
|
1611 |
prompt_mask_input = prompt_mask
|
1612 |
latent_model_input = latents
|
1613 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1614 |
if do_batch_cfg and guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
|
1615 |
# Concatenate prompt embeddings
|
1616 |
prompt_embeds_input = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
1617 |
pooled_prompt_embeds_input = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
1618 |
|
1619 |
-
#
|
1620 |
-
if text_ids is not None and negative_text_ids is not None:
|
1621 |
-
|
1622 |
|
1623 |
# Concatenate latent image IDs if they are used
|
1624 |
-
if latent_image_ids is not None:
|
1625 |
-
|
1626 |
|
1627 |
# Concatenate prompt masks if they are used
|
1628 |
if prompt_mask is not None and negative_mask is not None:
|
@@ -1643,37 +1658,22 @@ class CustomPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
|
|
1643 |
# Prepare extra transformer arguments
|
1644 |
extra_transformer_args = {}
|
1645 |
if prompt_mask is not None:
|
1646 |
-
extra_transformer_args["attention_mask"] = prompt_mask_input.to(device=self.transformer.device)
|
1647 |
|
1648 |
# Forward pass through the transformer
|
1649 |
noise_pred = self.transformer(
|
1650 |
-
hidden_states=latent_model_input.to(device=self.transformer.device),
|
1651 |
timestep=timestep / 1000,
|
1652 |
guidance=guidance,
|
1653 |
-
pooled_projections=pooled_prompt_embeds_input.to(device=self.transformer.device),
|
1654 |
-
encoder_hidden_states=prompt_embeds_input.to(device=self.transformer.device),
|
1655 |
-
txt_ids=text_ids_input.to(device=self.transformer.device) if text_ids is not None else None,
|
1656 |
-
img_ids=latent_image_ids_input.to(device=self.transformer.device) if latent_image_ids is not None else None,
|
1657 |
joint_attention_kwargs=self.joint_attention_kwargs,
|
1658 |
return_dict=False,
|
1659 |
**extra_transformer_args,
|
1660 |
)[0]
|
1661 |
|
1662 |
-
if do_batch_cfg and guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
|
1663 |
-
progress_bar.set_postfix(
|
1664 |
-
{
|
1665 |
-
'ts': timestep.detach().item() / 1000,
|
1666 |
-
'cfg': self._guidance_scale_real,
|
1667 |
-
},
|
1668 |
-
)
|
1669 |
-
else:
|
1670 |
-
progress_bar.set_postfix(
|
1671 |
-
{
|
1672 |
-
'ts': timestep.detach().item() / 1000,
|
1673 |
-
'cfg': 'N/A',
|
1674 |
-
},
|
1675 |
-
)
|
1676 |
-
|
1677 |
# Apply real CFG
|
1678 |
if guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
|
1679 |
if do_batch_cfg:
|
|
|
1611 |
prompt_mask_input = prompt_mask
|
1612 |
latent_model_input = latents
|
1613 |
|
1614 |
+
if guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
|
1615 |
+
progress_bar.set_postfix(
|
1616 |
+
{
|
1617 |
+
'ts': timestep.detach().item() / 1000,
|
1618 |
+
'cfg': self._guidance_scale_real,
|
1619 |
+
},
|
1620 |
+
)
|
1621 |
+
else:
|
1622 |
+
progress_bar.set_postfix(
|
1623 |
+
{
|
1624 |
+
'ts': timestep.detach().item() / 1000,
|
1625 |
+
'cfg': 'N/A',
|
1626 |
+
},
|
1627 |
+
)
|
1628 |
+
|
1629 |
if do_batch_cfg and guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
|
1630 |
# Concatenate prompt embeddings
|
1631 |
prompt_embeds_input = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
1632 |
pooled_prompt_embeds_input = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
1633 |
|
1634 |
+
# Concatenate text IDs if they are used
|
1635 |
+
# if text_ids is not None and negative_text_ids is not None:
|
1636 |
+
# text_ids_input = torch.cat([negative_text_ids, text_ids], dim=0)
|
1637 |
|
1638 |
# Concatenate latent image IDs if they are used
|
1639 |
+
# if latent_image_ids is not None:
|
1640 |
+
# latent_image_ids_input = torch.cat([latent_image_ids, latent_image_ids], dim=0)
|
1641 |
|
1642 |
# Concatenate prompt masks if they are used
|
1643 |
if prompt_mask is not None and negative_mask is not None:
|
|
|
1658 |
# Prepare extra transformer arguments
|
1659 |
extra_transformer_args = {}
|
1660 |
if prompt_mask is not None:
|
1661 |
+
extra_transformer_args["attention_mask"] = prompt_mask_input.to(device=self.transformer.device).contiguous()
|
1662 |
|
1663 |
# Forward pass through the transformer
|
1664 |
noise_pred = self.transformer(
|
1665 |
+
hidden_states=latent_model_input.to(device=self.transformer.device).contiguous() ,
|
1666 |
timestep=timestep / 1000,
|
1667 |
guidance=guidance,
|
1668 |
+
pooled_projections=pooled_prompt_embeds_input.to(device=self.transformer.device).contiguous() ,
|
1669 |
+
encoder_hidden_states=prompt_embeds_input.to(device=self.transformer.device).contiguous() ,
|
1670 |
+
txt_ids=text_ids_input.to(device=self.transformer.device).contiguous() if text_ids is not None else None,
|
1671 |
+
img_ids=latent_image_ids_input.to(device=self.transformer.device).contiguous() if latent_image_ids is not None else None,
|
1672 |
joint_attention_kwargs=self.joint_attention_kwargs,
|
1673 |
return_dict=False,
|
1674 |
**extra_transformer_args,
|
1675 |
)[0]
|
1676 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1677 |
# Apply real CFG
|
1678 |
if guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
|
1679 |
if do_batch_cfg:
|