---
language:
- en
library_name: diffusers
license: other
license_name: flux-1-dev-non-commercial-license
license_link: LICENSE.md
---
LoRA is the de-facto technique for quickly adapting a large pre-trained model to custom use cases. LoRA matrices are typically low-rank. How “low” is low depends on the context, but for a large diffusion model like [Flux](https://huggingface.co/black-forest-labs/FLUX.1-dev), a rank of 128 can be considered high. This matters because users often need to keep multiple LoRAs unfused in memory to switch between them quickly, so the higher the rank, the more memory is needed on top of the base model.
So, what if we could take an existing high-rank LoRA checkpoint and reduce its rank to:
- Reduce the memory requirements
- Enable use cases like `torch.compile()` (which requires all the LoRAs to have the same rank to avoid re-compilation)
This project explores two options for reducing the original LoRA checkpoint to a smaller one:
* Random projections
* SVD
## Random projections
Basic idea:
1. Generate a random projection matrix: `R = torch.randn(new_rank, original_rank, dtype=torch.float32) / torch.sqrt(torch.tensor(new_rank, dtype=torch.float32))`.
2. Then compute the new LoRA up and down matrices:
```python
# We keep R in torch.float32 for numerical stability.
lora_A_new = (R @ lora_A.to(R.dtype)).to(lora_A.dtype)
lora_B_new = (lora_B.to(R.dtype) @ R.T).to(lora_B.dtype)
```
If `lora_A` and `lora_B` have shapes (42, 3072) and (3072, 42), respectively, then with a `new_rank` of 4, `lora_A_new` and `lora_B_new` will have shapes (4, 3072) and (3072, 4), respectively.
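The two steps above can be put together in a minimal, self-contained sketch. The tensors here are random stand-ins for a single layer's LoRA weights (an assumption made only to check shapes and dtypes); the sizes 42, 4, and 3072 are taken from the shape example.

```python
import torch

torch.manual_seed(0)

original_rank, new_rank, hidden = 42, 4, 3072

# Random stand-ins for one layer's LoRA weights (illustrative values only).
lora_A = torch.randn(original_rank, hidden).to(torch.bfloat16)
lora_B = torch.randn(hidden, original_rank).to(torch.bfloat16)

# Gaussian projection scaled by 1 / sqrt(new_rank), kept in float32.
R = torch.randn(new_rank, original_rank, dtype=torch.float32) / (new_rank ** 0.5)

# Project both factors onto the new, smaller rank dimension.
lora_A_new = (R @ lora_A.to(R.dtype)).to(lora_A.dtype)
lora_B_new = (lora_B.to(R.dtype) @ R.T).to(lora_B.dtype)

print(lora_A_new.shape, lora_B_new.shape)
# torch.Size([4, 3072]) torch.Size([3072, 4])
```

Because each entry of `R` has variance `1 / new_rank`, `R.T @ R` equals the identity in expectation, so `lora_B_new @ lora_A_new` matches `lora_B @ lora_A` on average. Any single draw at a rank as low as 4 is still a noisy approximation.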
### Results
Tried on this LoRA: [https://huggingface.co/glif/how2draw](https://huggingface.co/glif/how2draw). Unless explicitly specified, a rank of 4 was used for all experiments. Here’s a side-by-side comparison of the original and the reduced LoRAs (on the same seed).
<details>
<summary>Inference code</summary>
```python
from diffusers import DiffusionPipeline
import torch
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
# Change accordingly.
lora_id = "How2Draw-V2_000002800_svd.safetensors"
pipe.load_lora_weights(lora_id)
prompts = [
    "Yorkshire Terrier with smile, How2Draw",
    "a dolphin, How2Draw",
    "an owl, How3Draw",
    "A silhouette of a girl performing a ballet pose, with elegant lines to suggest grace and movement. The background can include simple outlines of ballet shoes and a music note. The image should convey elegance and poise in a minimalistic style, How2Draw"
]
images = pipe(
    prompts, num_inference_steps=50, max_sequence_length=512, guidance_scale=3.5, generator=torch.manual_seed(0)
).images
```
</details>
<table style="border-collapse: collapse;">
<tbody>
<tr>
<td align="center"><img src="https://huggingface.co/sayakpaul/lower-rank-flux-lora/resolve/main/images/collage_0.png" alt="Image 1"></td>
<td align="center">Yorkshire Terrier with smile, How2Draw</td>
</tr>
<tr>
<td align="center"><img src="https://huggingface.co/sayakpaul/lower-rank-flux-lora/resolve/main/images/collage_1.png" alt="Image 2"></td>
<td align="center">a dolphin, How2Draw</td>
</tr>
<tr>
<td align="center"><img src="https://huggingface.co/sayakpaul/lower-rank-flux-lora/resolve/main/images/collage_2.png" alt="Image 3"></td>
<td align="center">an owl, How3Draw</td>
</tr>
<tr>
<td align="center"><img src="https://huggingface.co/sayakpaul/lower-rank-flux-lora/resolve/main/images/collage_3.png" alt="Image 4"></td>
<td align="center" style="padding: 0; margin: 0;">
A silhouette of a girl performing a ballet pose, with elegant lines to suggest grace and movement.
The background can include simple outlines of ballet shoes and a music note.
The image should convey elegance and poise in a minimalistic style, How2Draw
</td>
</tr>
</tbody>
</table>
Code: [`low_rank_lora.py`](https://huggingface.co/sayakpaul/lower-rank-flux-lora/blob/main/low_rank_lora.py)
### Notes
* One should experiment with the `new_rank` parameter to obtain the desired trade-off between performance and memory. With a `new_rank` of 4, we reduce the size of the LoRA from 451MB to 42MB.
* There is a `use_sparse` option in the script above for using sparse random projection matrices.
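To get a feel for how `new_rank` drives the checkpoint size, here is a rough back-of-the-envelope sketch. It assumes a single square layer with the (42, 3072) and (3072, 42) shapes from the shape example and bfloat16 storage; the actual per-layer shapes and ranks in the checkpoint are not stated here, so treat the numbers as illustrative. `lora_param_count` is a hypothetical helper.

```python
def lora_param_count(rank: int, in_features: int, out_features: int) -> int:
    # lora_A is (rank, in_features) and lora_B is (out_features, rank).
    return rank * (in_features + out_features)


in_features = out_features = 3072
for rank in (42, 4):
    n = lora_param_count(rank, in_features, out_features)
    # bfloat16 uses 2 bytes per parameter.
    print(f"rank={rank}: {n:,} params, {n * 2 / 1024:.0f} KiB in bfloat16")

ratio = lora_param_count(42, 3072, 3072) / lora_param_count(4, 3072, 3072)
print(f"reduction factor: {ratio:.1f}x")
# rank=42: 258,048 params, 504 KiB in bfloat16
# rank=4: 24,576 params, 48 KiB in bfloat16
# reduction factor: 10.5x
```

The size of a LoRA scales linearly with its rank, so under these assumptions the 10.5x factor is in the same ballpark as the reported drop from 451MB to 42MB.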
## SVD
<details>
<summary>Results</summary>
![image.png](https://huggingface.co/sayakpaul/lower-rank-flux-lora/resolve/main/images/How2Draw-V2_000002800_svd_collage_0.png)
![image.png](https://huggingface.co/sayakpaul/lower-rank-flux-lora/resolve/main/images/How2Draw-V2_000002800_svd_collage_1.png)
![image.png](https://huggingface.co/sayakpaul/lower-rank-flux-lora/resolve/main/images/How2Draw-V2_000002800_svd_collage_2.png)
![image.png](https://huggingface.co/sayakpaul/lower-rank-flux-lora/resolve/main/images/How2Draw-V2_000002800_svd_collage_3.png)
</details>
### Randomized SVD
Full SVD can be time-consuming, and truncated SVD is mainly useful for very large sparse matrices. Randomized SVD is significantly faster, with little to no loss in quality.
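As a rough sketch of the idea (not necessarily what the script does), one can form the full update `lora_B @ lora_A`, factor it with either a full or a randomized SVD, and split the top `new_rank` components back into two factors. The helper names, the small layer size, and the choice to split the singular values evenly between the two factors are assumptions made for illustration; `torch.linalg.svd` and `torch.svd_lowrank` are the PyTorch routines used.

```python
import torch

torch.manual_seed(0)

# Small sizes so the sketch runs quickly; random stand-ins for real LoRA weights.
original_rank, new_rank, hidden = 42, 4, 512
lora_A = torch.randn(original_rank, hidden)
lora_B = torch.randn(hidden, original_rank)
delta_w = lora_B @ lora_A  # full weight update, shape (hidden, hidden)


def reduce_full_svd(delta_w, new_rank):
    # Exact SVD, then keep only the top `new_rank` components.
    U, S, Vh = torch.linalg.svd(delta_w, full_matrices=False)
    sqrt_s = S[:new_rank].sqrt()
    return U[:, :new_rank] * sqrt_s, sqrt_s[:, None] * Vh[:new_rank]


def reduce_randomized_svd(delta_w, new_rank, niter=2):
    # Randomized approximation of the top `new_rank` components.
    U, S, V = torch.svd_lowrank(delta_w, q=new_rank, niter=niter)
    sqrt_s = S.sqrt()
    return U * sqrt_s, sqrt_s[:, None] * V.T


def rel_error(lora_B_new, lora_A_new):
    diff = delta_w - lora_B_new @ lora_A_new
    return (torch.linalg.norm(diff) / torch.linalg.norm(delta_w)).item()


B_full, A_full = reduce_full_svd(delta_w, new_rank)
B_rand, A_rand = reduce_randomized_svd(delta_w, new_rank)
err_full = rel_error(B_full, A_full)
err_rand = rel_error(B_rand, A_rand)
print(f"full SVD: {err_full:.4f}, randomized SVD: {err_rand:.4f}")
```

The truncated full SVD gives the best possible rank-`new_rank` approximation in the Frobenius norm, so the randomized error can only match or exceed it. Random Gaussian factors have a fairly flat spectrum, so the errors printed here are likely much larger than what a trained LoRA would show.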
<details>
<summary>Results</summary>
![image.png](https://huggingface.co/sayakpaul/lower-rank-flux-lora/resolve/main/images/How2Draw-V2_000002800_rand_svd_collage_0.png)
![image.png](https://huggingface.co/sayakpaul/lower-rank-flux-lora/resolve/main/images/How2Draw-V2_000002800_rand_svd_collage_1.png)
![image.png](https://huggingface.co/sayakpaul/lower-rank-flux-lora/resolve/main/images/How2Draw-V2_000002800_rand_svd_collage_2.png)
![image.png](https://huggingface.co/sayakpaul/lower-rank-flux-lora/resolve/main/images/How2Draw-V2_000002800_rand_svd_collage_3.png)
</details>
Code: [`svd_low_rank_lora.py`](https://huggingface.co/sayakpaul/lower-rank-flux-lora/blob/main/svd_low_rank_lora.py)
### Tune the knobs in SVD
- `new_rank` as always
- `niter` when using randomized SVD
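The effect of both knobs can be explored on a synthetic matrix before touching a real checkpoint. The sketch below builds an update with a geometrically decaying spectrum (an assumption; real LoRA spectra may look different) and records the relative reconstruction error for several values of `new_rank` and `niter`.

```python
import torch

torch.manual_seed(0)

# Synthetic rank-32 update with singular values 0.8**i (illustrative only).
m, n, r = 256, 256, 32
U0, _ = torch.linalg.qr(torch.randn(m, r))
V0, _ = torch.linalg.qr(torch.randn(n, r))
spectrum = 0.8 ** torch.arange(r, dtype=torch.float32)
delta_w = (U0 * spectrum) @ V0.T


def rel_error(approx):
    return (torch.linalg.norm(delta_w - approx) / torch.linalg.norm(delta_w)).item()


# Knob 1: `new_rank`, using the exact SVD.
U, S, Vh = torch.linalg.svd(delta_w, full_matrices=False)
rank_errors = {}
for new_rank in (2, 4, 8, 16):
    approx = (U[:, :new_rank] * S[:new_rank]) @ Vh[:new_rank]
    rank_errors[new_rank] = rel_error(approx)

# Knob 2: `niter`, using randomized SVD at a fixed rank of 4.
niter_errors = {}
for niter in (0, 1, 2, 4):
    Ur, Sr, Vr = torch.svd_lowrank(delta_w, q=4, niter=niter)
    niter_errors[niter] = rel_error((Ur * Sr) @ Vr.T)

print("error by new_rank:", rank_errors)
print("error by niter:", niter_errors)
```

A higher `new_rank` always lowers the error of the exact truncation, while more power iterations (`niter`) typically bring the randomized result closer to that optimum at the cost of extra matrix multiplications.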
## Reduced checkpoints
* Randomized SVD: [How2Draw-V2_000002800_rand_svd.safetensors](./How2Draw-V2_000002800_rand_svd.safetensors)
* Full SVD: [How2Draw-V2_000002800_svd.safetensors](./How2Draw-V2_000002800_svd.safetensors)
* Random projections: [How2Draw-V2_000002800_reduced.safetensors](./How2Draw-V2_000002800_reduced.safetensors)