jimmycarter committed: Upload 2 files

Files changed:
- README.md (+8 -2)
- pipeline.py (+83 -56)
README.md CHANGED

````diff
@@ -49,8 +49,9 @@ negative_prompt = "blurry"
 images = pipe(
     prompt=prompt,
     negative_prompt=negative_prompt,
+    return_dict=False,
 )
-images[0].save('chalkboard.png')
+images[0][0].save('chalkboard.png')
 
 # If you have <=24 GB VRAM, try:
 # ! pip install optimum-quanto
@@ -67,14 +68,19 @@ quantize(
 )
 freeze(pipe.transformer)
 pipe.enable_model_cpu_offload()
+
+# If you are still running out of memory, add do_batch_cfg=False below.
 images = pipe(
     prompt=prompt,
     negative_prompt=negative_prompt,
     device=None,
+    return_dict=False,
 )
-images[0].save('chalkboard.png')
+images[0][0].save('chalkboard.png')
 ```
 
+For usage in ComfyUI, [a single transformer file is provided](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/transformer_legacy.safetensors) but note that ComfyUI does not presently support attention masks so your images may be degraded.
+
 # Non-technical Report on Schnell De-distillation
 
 Welcome to my non-technical report on de-distilling FLUX.1-schnell in the most un-scientific way possible with extremely limited resources. I'm not going to claim I made a good model, but I did make a model. It was trained on about 1,500 H100 hour equivalents.
````
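Taken together, the README changes amount to: pass `return_dict=False`, index the returned tuple as `images[0][0]`, and optionally pass `do_batch_cfg=False` on low-VRAM machines. Below is a minimal sketch of the updated call. The `DiffusionPipeline.from_pretrained` loading lines and the example prompt are not part of this diff and are assumptions about the surrounding README; `pipe`, `negative_prompt`, `return_dict`, and `do_batch_cfg` come from the hunks above.

```python
import torch
from diffusers import DiffusionPipeline

# Assumed setup (not shown in this diff): load the custom LibreFLUX pipeline.
pipe = DiffusionPipeline.from_pretrained(
    "jimmycarter/LibreFLUX",
    custom_pipeline="jimmycarter/LibreFLUX",
    torch_dtype=torch.bfloat16,
)

prompt = "a chalkboard with writing on it"  # placeholder; the README defines its own prompt
negative_prompt = "blurry"

# With return_dict=False the pipeline returns a plain tuple, so the image list
# is element 0 and the first image is images[0][0].
images = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    return_dict=False,
    # do_batch_cfg=False,  # uncomment if you are running out of memory
)
images[0][0].save('chalkboard.png')
```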
pipeline.py CHANGED

````diff
@@ -1376,8 +1376,7 @@ class CustomPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         no_cfg_until_timestep: int = 0,
-
-        zero_using_prompt_mask: bool = False,
+        do_batch_cfg: bool=True,
         device=torch.device('cuda'), # TODO let this work with non-cuda stuff? Might if you set this to None
     ):
         r"""
@@ -1510,6 +1509,7 @@ class CustomPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
         )
         if _prompt_mask is not None:
             prompt_mask = _prompt_mask
+        assert prompt_mask is not None
 
         if negative_prompt_2 == "" and negative_prompt != "":
             negative_prompt_2 = negative_prompt
@@ -1537,6 +1537,8 @@ class CustomPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
         if _neg_prompt_mask is not None:
             negative_mask = _neg_prompt_mask
 
+        assert negative_mask is not None
+
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
         latents, latent_image_ids = self.prepare_latents(
@@ -1601,56 +1603,63 @@ class CustomPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
                 if self.interrupt:
                     continue
 
-                #
-
-
-
+                # Prepare the latent model input
+                prompt_embeds_input = prompt_embeds
+                pooled_prompt_embeds_input = pooled_prompt_embeds
+                text_ids_input = text_ids
+                latent_image_ids_input = latent_image_ids
+                prompt_mask_input = prompt_mask
+                latent_model_input = latents
+
+                if do_batch_cfg and guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
+                    # Concatenate prompt embeddings
+                    prompt_embeds_input = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+                    pooled_prompt_embeds_input = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+
+                    # # Concatenate text IDs if they are used
+                    # if text_ids is not None and negative_text_ids is not None:
+                    #     text_ids_input = torch.cat([negative_text_ids, text_ids], dim=0)
+
+                    # Concatenate latent image IDs if they are used
+                    # if latent_image_ids is not None:
+                    #     latent_image_ids_input = torch.cat([latent_image_ids, latent_image_ids], dim=0)
+
+                    # Concatenate prompt masks if they are used
+                    if prompt_mask is not None and negative_mask is not None:
+                        prompt_mask_input = torch.cat([negative_mask, prompt_mask], dim=0)
+                    # Duplicate latents for unconditional and conditional inputs
+                    latent_model_input = torch.cat([latents] * 2)
+
+                # Expand timestep to match batch size
+                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+
+                # Handle guidance
+                if self.transformer.config.guidance_embeds:
+                    guidance = torch.tensor([guidance_scale], device=self.transformer.device)
+                    guidance = guidance.expand(latent_model_input.shape[0])
+                else:
+                    guidance = None
 
+                # Prepare extra transformer arguments
                 extra_transformer_args = {}
-                if
-                    extra_transformer_args["attention_mask"] =
-                elif use_prompt_mask and prompt_mask is not None and zero_using_prompt_mask:
-                    mask_tens = prompt_mask.unsqueeze(-1).to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
-                    prompt_embeds = prompt_embeds * mask_tens
+                if prompt_mask is not None:
+                    extra_transformer_args["attention_mask"] = prompt_mask_input.to(device=self.transformer.device)
 
+                # Forward pass through the transformer
                 noise_pred = self.transformer(
-                    hidden_states=
-                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
+                    hidden_states=latent_model_input.to(device=self.transformer.device),
                     timestep=timestep / 1000,
                     guidance=guidance,
-                    pooled_projections=
-                    encoder_hidden_states=
-                    txt_ids=text_ids,
-                    img_ids=
+                    pooled_projections=pooled_prompt_embeds_input.to(device=self.transformer.device),
+                    encoder_hidden_states=prompt_embeds_input.to(device=self.transformer.device),
+                    txt_ids=text_ids_input.to(device=self.transformer.device) if text_ids is not None else None,
+                    img_ids=latent_image_ids_input.to(device=self.transformer.device) if latent_image_ids is not None else None,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                     **extra_transformer_args,
                 )[0]
 
-
-                if self._guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
-                    extra_transformer_args_neg = {}
-                    if negative_mask is not None:
-                        extra_transformer_args_neg["attention_mask"] = negative_mask
-                        extra_transformer_args_neg["attention_mask"] is not None
-
-                    noise_pred_uncond = self.transformer(
-                        hidden_states=latents,
-                        # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
-                        timestep=timestep / 1000,
-                        guidance=guidance,
-                        pooled_projections=negative_pooled_prompt_embeds,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        txt_ids=negative_text_ids,
-                        img_ids=latent_image_ids.to(device=device),
-                        joint_attention_kwargs=self.joint_attention_kwargs,
-                        return_dict=False,
-                        **extra_transformer_args_neg,
-                    )[0]
-
-                    noise_pred = noise_pred_uncond + self._guidance_scale_real * (
-                        noise_pred - noise_pred_uncond
-                    )
+                if do_batch_cfg and guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
                     progress_bar.set_postfix(
                         {
                             'ts': timestep.detach().item() / 1000,
@@ -1665,32 +1674,50 @@ class CustomPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
                         },
                     )
 
-                #
+                # Apply real CFG
+                if guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
+                    if do_batch_cfg:
+                        # Batched CFG: Split the noise prediction into unconditional and conditional parts
+                        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + guidance_scale_real * (noise_pred_cond - noise_pred_uncond)
+                    else:
+                        # Sequential CFG: Compute unconditional noise prediction separately
+                        noise_pred_uncond = self.transformer(
+                            hidden_states=latents.to(device=self.transformer.device),
+                            timestep=timestep / 1000,
+                            guidance=guidance,
+                            pooled_projections=negative_pooled_prompt_embeds.to(device=self.transformer.device),
+                            encoder_hidden_states=negative_prompt_embeds.to(device=self.transformer.device),
+                            txt_ids=negative_text_ids.to(device=self.transformer.device) if negative_text_ids is not None else None,
+                            img_ids=latent_image_ids.to(device=self.transformer.device) if latent_image_ids is not None else None,
+                            joint_attention_kwargs=self.joint_attention_kwargs,
+                            return_dict=False,
+                        )[0]
+
+                        # Combine conditional and unconditional predictions
+                        noise_pred = noise_pred_uncond + guidance_scale_real * (noise_pred - noise_pred_uncond)
+
+                # Compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
-                latents = self.scheduler.step(
-                    noise_pred, t, latents, return_dict=False
-                )[0]
+                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
+                # Ensure latents have the correct dtype
                 if latents.dtype != latents_dtype:
                     if torch.backends.mps.is_available():
-                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                         latents = latents.to(latents_dtype)
 
+                # Callback at the end of the step, if provided
                 if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
+                    callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs}
                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+                    latents = callback_outputs.get("latents", latents)
+                    prompt_embeds = callback_outputs.get("prompt_embeds", prompt_embeds)
 
-
-
-
-                # call the callback, if provided
-                if i == len(timesteps) - 1 or (
-                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
-                ):
+                # Update the progress bar
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                # Mark step for XLA devices
                 if XLA_AVAILABLE:
                     xm.mark_step()
 
````
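The substance of the pipeline.py change is the new `do_batch_cfg` flag. When it is true, the negative and positive prompt embeddings (plus pooled embeddings and attention masks) are concatenated along the batch dimension and the latents are duplicated, so the transformer runs once per step and the two halves of its prediction are recombined with `guidance_scale_real`; when it is false, the unconditional prediction is computed in a separate second forward pass, which lowers peak activation memory at the cost of extra compute. The sketch below only illustrates that arithmetic: the `toy_denoiser` function and the tensor shapes are invented for the example and are not the FLUX transformer interface.

```python
import torch

def toy_denoiser(latents: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
    """Stand-in for the transformer: returns a 'noise prediction' shaped like the latents."""
    return latents * 0.9 + text_emb.mean(dim=(1, 2)).view(-1, 1, 1, 1)

guidance_scale_real = 6.0
latents = torch.randn(1, 16, 64, 64)                # dummy latents
prompt_embeds = torch.randn(1, 512, 4096)            # dummy conditional text embeddings
negative_prompt_embeds = torch.randn(1, 512, 4096)   # dummy unconditional text embeddings

# Batched CFG (do_batch_cfg=True): one forward pass over a doubled batch,
# negative embeddings first, then split the prediction back apart.
embeds_input = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
latent_model_input = torch.cat([latents] * 2)
noise_pred = toy_denoiser(latent_model_input, embeds_input)
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
batched = noise_pred_uncond + guidance_scale_real * (noise_pred_cond - noise_pred_uncond)

# Sequential CFG (do_batch_cfg=False): two forward passes at half the batch size,
# which reduces peak memory but runs the model twice per step.
cond = toy_denoiser(latents, prompt_embeds)
uncond = toy_denoiser(latents, negative_prompt_embeds)
sequential = uncond + guidance_scale_real * (cond - uncond)

# Both paths produce the same guided prediction.
assert torch.allclose(batched, sequential, atol=1e-5)
```

In the actual pipeline the same concatenation also covers the pooled projections and the prompt attention masks, as the hunk around `latent_model_input = torch.cat([latents] * 2)` shows.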