jimmycarter committed · verified
Commit 7d024b3 · 1 Parent(s): 39f4661

Upload 2 files

Files changed (2)
  1. README.md +8 -2
  2. pipeline.py +83 -56
README.md CHANGED
@@ -49,8 +49,9 @@ negative_prompt = "blurry"
 images = pipe(
     prompt=prompt,
     negative_prompt=negative_prompt,
+    return_dict=False,
 )
-images[0].save('chalkboard.png')
+images[0][0].save('chalkboard.png')
 
 # If you have <=24 GB VRAM, try:
 # ! pip install optimum-quanto
@@ -67,14 +68,19 @@ quantize(
 )
 freeze(pipe.transformer)
 pipe.enable_model_cpu_offload()
+
+# If you are still running out of memory, add do_batch_cfg=False below.
 images = pipe(
     prompt=prompt,
     negative_prompt=negative_prompt,
     device=None,
+    return_dict=False,
 )
-images[0].save('chalkboard.png')
+images[0][0].save('chalkboard.png')
 ```
 
+For usage in ComfyUI, [a single transformer file is provided](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/transformer_legacy.safetensors) but note that ComfyUI does not presently support attention masks so your images may be degraded.
+
 # Non-technical Report on Schnell De-distillation
 
 Welcome to my non-technical report on de-distilling FLUX.1-schnell in the most un-scientific way possible with extremely limited resources. I'm not going to claim I made a good model, but I did make a model. It was trained on about 1,500 H100 hour equivalents.
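The README change above only touches the call site: `return_dict=False` makes the pipeline return a plain tuple whose first element is the list of generated PIL images (hence `images[0][0]`), and `do_batch_cfg=False` can be passed when memory is still tight. Below is a minimal sketch of the updated usage; the loading call and prompt are illustrative stand-ins for the unchanged snippet at the top of the README, and only the `pipe(...)` call and the save line reflect this commit.

```python
# Illustrative setup; follow the unchanged loading snippet at the top of the
# README for the exact arguments. Only the pipe(...) call and the save line
# below correspond to this commit's README change.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "jimmycarter/LibreFLUX",                  # repo id, per the ComfyUI link above
    custom_pipeline="jimmycarter/LibreFLUX",  # assumption: the repo's pipeline.py is used
    torch_dtype=torch.bfloat16,
)

prompt = "a chalkboard that says 'hello world'"  # illustrative prompt
negative_prompt = "blurry"

images = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    return_dict=False,      # new: the pipeline now returns a plain tuple
    # do_batch_cfg=False,   # new: uncomment if you are still running out of memory
)

# With return_dict=False, the first tuple element is the list of PIL images.
images[0][0].save('chalkboard.png')
```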
pipeline.py CHANGED
@@ -1376,8 +1376,7 @@ class CustomPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         no_cfg_until_timestep: int = 0,
-        use_prompt_mask: bool = True,
-        zero_using_prompt_mask: bool = False,
+        do_batch_cfg: bool=True,
         device=torch.device('cuda'), # TODO let this work with non-cuda stuff? Might if you set this to None
     ):
         r"""
@@ -1510,6 +1509,7 @@ class CustomPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
         )
         if _prompt_mask is not None:
             prompt_mask = _prompt_mask
+        assert prompt_mask is not None
 
         if negative_prompt_2 == "" and negative_prompt != "":
             negative_prompt_2 = negative_prompt
@@ -1537,6 +1537,8 @@
         if _neg_prompt_mask is not None:
             negative_mask = _neg_prompt_mask
 
+        assert negative_mask is not None
+
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
         latents, latent_image_ids = self.prepare_latents(
@@ -1601,56 +1603,63 @@
                 if self.interrupt:
                     continue
 
-                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latents.shape[0]).to(latents.dtype)
-
-                assert prompt_mask is not None
+                # Prepare the latent model input
+                prompt_embeds_input = prompt_embeds
+                pooled_prompt_embeds_input = pooled_prompt_embeds
+                text_ids_input = text_ids
+                latent_image_ids_input = latent_image_ids
+                prompt_mask_input = prompt_mask
+                latent_model_input = latents
+
+                if do_batch_cfg and guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
+                    # Concatenate prompt embeddings
+                    prompt_embeds_input = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+                    pooled_prompt_embeds_input = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+
+                    # # Concatenate text IDs if they are used
+                    # if text_ids is not None and negative_text_ids is not None:
+                    #     text_ids_input = torch.cat([negative_text_ids, text_ids], dim=0)
+
+                    # Concatenate latent image IDs if they are used
+                    # if latent_image_ids is not None:
+                    #     latent_image_ids_input = torch.cat([latent_image_ids, latent_image_ids], dim=0)
+
+                    # Concatenate prompt masks if they are used
+                    if prompt_mask is not None and negative_mask is not None:
+                        prompt_mask_input = torch.cat([negative_mask, prompt_mask], dim=0)
+                    # Duplicate latents for unconditional and conditional inputs
+                    latent_model_input = torch.cat([latents] * 2)
+
+                # Expand timestep to match batch size
+                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+
+                # Handle guidance
+                if self.transformer.config.guidance_embeds:
+                    guidance = torch.tensor([guidance_scale], device=self.transformer.device)
+                    guidance = guidance.expand(latent_model_input.shape[0])
+                else:
+                    guidance = None
 
+                # Prepare extra transformer arguments
                 extra_transformer_args = {}
-                if use_prompt_mask and prompt_mask is not None and not zero_using_prompt_mask:
-                    extra_transformer_args["attention_mask"] = prompt_mask
-                elif use_prompt_mask and prompt_mask is not None and zero_using_prompt_mask:
-                    mask_tens = prompt_mask.unsqueeze(-1).to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
-                    prompt_embeds = prompt_embeds * mask_tens
+                if prompt_mask is not None:
+                    extra_transformer_args["attention_mask"] = prompt_mask_input.to(device=self.transformer.device)
 
+                # Forward pass through the transformer
                 noise_pred = self.transformer(
-                    hidden_states=latents,
-                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
+                    hidden_states=latent_model_input.to(device=self.transformer.device),
                     timestep=timestep / 1000,
                     guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids.to(device=device),
+                    pooled_projections=pooled_prompt_embeds_input.to(device=self.transformer.device),
+                    encoder_hidden_states=prompt_embeds_input.to(device=self.transformer.device),
+                    txt_ids=text_ids_input.to(device=self.transformer.device) if text_ids is not None else None,
+                    img_ids=latent_image_ids_input.to(device=self.transformer.device) if latent_image_ids is not None else None,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                     **extra_transformer_args,
                 )[0]
 
-                # TODO optionally use batch prediction to speed this up.
-                if self._guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
-                    extra_transformer_args_neg = {}
-                    if negative_mask is not None:
-                        extra_transformer_args_neg["attention_mask"] = negative_mask
-                        extra_transformer_args_neg["attention_mask"] is not None
-
-                    noise_pred_uncond = self.transformer(
-                        hidden_states=latents,
-                        # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
-                        timestep=timestep / 1000,
-                        guidance=guidance,
-                        pooled_projections=negative_pooled_prompt_embeds,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        txt_ids=negative_text_ids,
-                        img_ids=latent_image_ids.to(device=device),
-                        joint_attention_kwargs=self.joint_attention_kwargs,
-                        return_dict=False,
-                        **extra_transformer_args_neg,
-                    )[0]
-
-                    noise_pred = noise_pred_uncond + self._guidance_scale_real * (
-                        noise_pred - noise_pred_uncond
-                    )
+                if do_batch_cfg and guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
                 progress_bar.set_postfix(
                     {
                         'ts': timestep.detach().item() / 1000,
@@ -1665,32 +1674,50 @@
                     },
                 )
 
-                # compute the previous noisy sample x_t -> x_t-1
+                # Apply real CFG
+                if guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
+                    if do_batch_cfg:
+                        # Batched CFG: Split the noise prediction into unconditional and conditional parts
+                        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + guidance_scale_real * (noise_pred_cond - noise_pred_uncond)
+                    else:
+                        # Sequential CFG: Compute unconditional noise prediction separately
+                        noise_pred_uncond = self.transformer(
+                            hidden_states=latents.to(device=self.transformer.device),
+                            timestep=timestep / 1000,
+                            guidance=guidance,
+                            pooled_projections=negative_pooled_prompt_embeds.to(device=self.transformer.device),
+                            encoder_hidden_states=negative_prompt_embeds.to(device=self.transformer.device),
+                            txt_ids=negative_text_ids.to(device=self.transformer.device) if negative_text_ids is not None else None,
+                            img_ids=latent_image_ids.to(device=self.transformer.device) if latent_image_ids is not None else None,
+                            joint_attention_kwargs=self.joint_attention_kwargs,
+                            return_dict=False,
+                        )[0]
+
+                        # Combine conditional and unconditional predictions
+                        noise_pred = noise_pred_uncond + guidance_scale_real * (noise_pred - noise_pred_uncond)
+
+                # Compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
-                latents = self.scheduler.step(
-                    noise_pred, t, latents, return_dict=False
-                )[0]
+                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
+                # Ensure latents have the correct dtype
                 if latents.dtype != latents_dtype:
                     if torch.backends.mps.is_available():
-                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                         latents = latents.to(latents_dtype)
 
+                # Callback at the end of the step, if provided
                 if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
+                    callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs}
                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+                    latents = callback_outputs.get("latents", latents)
+                    prompt_embeds = callback_outputs.get("prompt_embeds", prompt_embeds)
 
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-
-                # call the callback, if provided
-                if i == len(timesteps) - 1 or (
-                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
-                ):
+                # Update the progress bar
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                # Mark step for XLA devices
                 if XLA_AVAILABLE:
                     xm.mark_step()
 
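The pipeline.py change swaps the removed `use_prompt_mask`/`zero_using_prompt_mask` switches for a single `do_batch_cfg` flag. With it enabled, the negative and positive conditioning are concatenated, the transformer runs once on a doubled batch, and the prediction is chunked back into unconditional and conditional halves before the real-CFG blend `uncond + scale * (cond - uncond)`; with it disabled, the old two-pass path is kept as a lower-memory fallback. The following self-contained sketch illustrates that pattern with a toy stand-in for the transformer; the names, shapes, and guidance value are illustrative and are not the pipeline's API.

```python
# Sketch of batched vs. sequential classifier-free guidance, as adopted by this
# commit. `toy_model` is a hypothetical stand-in for the FLUX transformer.
import torch

def toy_model(latents: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
    # Stand-in denoiser: any deterministic, per-sample function of
    # (latents, conditioning) works for demonstrating the equivalence.
    return latents * 0.9 + cond.mean(dim=1, keepdim=True)

guidance_scale_real = 3.5              # illustrative guidance scale
latents = torch.randn(1, 4, 8, 8)      # current latents for one image
cond = torch.randn(1, 16, 8, 8)        # "positive" conditioning
uncond = torch.zeros_like(cond)        # "negative"/unconditional conditioning

# Batched CFG (do_batch_cfg=True): one forward pass over a doubled batch,
# with the negative conditioning stacked first, matching the pipeline.
latent_in = torch.cat([latents, latents], dim=0)
cond_in = torch.cat([uncond, cond], dim=0)
noise_pred = toy_model(latent_in, cond_in)

# Split into unconditional/conditional halves and blend.
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred_batched = noise_pred_uncond + guidance_scale_real * (noise_pred_cond - noise_pred_uncond)

# Sequential CFG (do_batch_cfg=False): two passes over a batch of one,
# slower but with roughly half the peak activation memory.
uncond_pred = toy_model(latents, uncond)
cond_pred = toy_model(latents, cond)
noise_pred_sequential = uncond_pred + guidance_scale_real * (cond_pred - uncond_pred)

# Both paths produce the same guided prediction.
assert torch.allclose(noise_pred_batched, noise_pred_sequential, atol=1e-6)
```

Batching halves the number of forward passes per step at the cost of roughly doubling activation memory, which is why the README change suggests `do_batch_cfg=False` when VRAM is tight.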