jimmycarter committed on
Commit
aa06741
1 Parent(s): a23a8bc

Upload pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +26 -26
pipeline.py CHANGED
@@ -1611,18 +1611,33 @@ class CustomPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
1611
  prompt_mask_input = prompt_mask
1612
  latent_model_input = latents
1613
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1614
  if do_batch_cfg and guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
1615
  # Concatenate prompt embeddings
1616
  prompt_embeds_input = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1617
  pooled_prompt_embeds_input = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
1618
 
1619
- # # Concatenate text IDs if they are used
1620
- if text_ids is not None and negative_text_ids is not None:
1621
- text_ids_input = torch.cat([negative_text_ids, text_ids], dim=0)
1622
 
1623
  # Concatenate latent image IDs if they are used
1624
- if latent_image_ids is not None:
1625
- latent_image_ids_input = torch.cat([latent_image_ids, latent_image_ids], dim=0)
1626
 
1627
  # Concatenate prompt masks if they are used
1628
  if prompt_mask is not None and negative_mask is not None:
@@ -1643,37 +1658,22 @@ class CustomPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
1643
  # Prepare extra transformer arguments
1644
  extra_transformer_args = {}
1645
  if prompt_mask is not None:
1646
- extra_transformer_args["attention_mask"] = prompt_mask_input.to(device=self.transformer.device)
1647
 
1648
  # Forward pass through the transformer
1649
  noise_pred = self.transformer(
1650
- hidden_states=latent_model_input.to(device=self.transformer.device),
1651
  timestep=timestep / 1000,
1652
  guidance=guidance,
1653
- pooled_projections=pooled_prompt_embeds_input.to(device=self.transformer.device),
1654
- encoder_hidden_states=prompt_embeds_input.to(device=self.transformer.device),
1655
- txt_ids=text_ids_input.to(device=self.transformer.device) if text_ids is not None else None,
1656
- img_ids=latent_image_ids_input.to(device=self.transformer.device) if latent_image_ids is not None else None,
1657
  joint_attention_kwargs=self.joint_attention_kwargs,
1658
  return_dict=False,
1659
  **extra_transformer_args,
1660
  )[0]
1661
 
1662
- if do_batch_cfg and guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
1663
- progress_bar.set_postfix(
1664
- {
1665
- 'ts': timestep.detach().item() / 1000,
1666
- 'cfg': self._guidance_scale_real,
1667
- },
1668
- )
1669
- else:
1670
- progress_bar.set_postfix(
1671
- {
1672
- 'ts': timestep.detach().item() / 1000,
1673
- 'cfg': 'N/A',
1674
- },
1675
- )
1676
-
1677
  # Apply real CFG
1678
  if guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
1679
  if do_batch_cfg:
 
1611
  prompt_mask_input = prompt_mask
1612
  latent_model_input = latents
1613
 
1614
+ if guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
1615
+ progress_bar.set_postfix(
1616
+ {
1617
+ 'ts': timestep.detach().item() / 1000,
1618
+ 'cfg': self._guidance_scale_real,
1619
+ },
1620
+ )
1621
+ else:
1622
+ progress_bar.set_postfix(
1623
+ {
1624
+ 'ts': timestep.detach().item() / 1000,
1625
+ 'cfg': 'N/A',
1626
+ },
1627
+ )
1628
+
1629
  if do_batch_cfg and guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
1630
  # Concatenate prompt embeddings
1631
  prompt_embeds_input = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1632
  pooled_prompt_embeds_input = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
1633
 
1634
+ # Concatenate text IDs if they are used
1635
+ # if text_ids is not None and negative_text_ids is not None:
1636
+ # text_ids_input = torch.cat([negative_text_ids, text_ids], dim=0)
1637
 
1638
  # Concatenate latent image IDs if they are used
1639
+ # if latent_image_ids is not None:
1640
+ # latent_image_ids_input = torch.cat([latent_image_ids, latent_image_ids], dim=0)
1641
 
1642
  # Concatenate prompt masks if they are used
1643
  if prompt_mask is not None and negative_mask is not None:
 
1658
  # Prepare extra transformer arguments
1659
  extra_transformer_args = {}
1660
  if prompt_mask is not None:
1661
+ extra_transformer_args["attention_mask"] = prompt_mask_input.to(device=self.transformer.device).contiguous()
1662
 
1663
  # Forward pass through the transformer
1664
  noise_pred = self.transformer(
1665
+ hidden_states=latent_model_input.to(device=self.transformer.device).contiguous() ,
1666
  timestep=timestep / 1000,
1667
  guidance=guidance,
1668
+ pooled_projections=pooled_prompt_embeds_input.to(device=self.transformer.device).contiguous() ,
1669
+ encoder_hidden_states=prompt_embeds_input.to(device=self.transformer.device).contiguous() ,
1670
+ txt_ids=text_ids_input.to(device=self.transformer.device).contiguous() if text_ids is not None else None,
1671
+ img_ids=latent_image_ids_input.to(device=self.transformer.device).contiguous() if latent_image_ids is not None else None,
1672
  joint_attention_kwargs=self.joint_attention_kwargs,
1673
  return_dict=False,
1674
  **extra_transformer_args,
1675
  )[0]
1676
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1677
  # Apply real CFG
1678
  if guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
1679
  if do_batch_cfg: