---
language:
- en
library_name: diffusers
license: other
license_name: flux-1-dev-non-commercial-license
license_link: LICENSE.md
---

LoRA is the de facto technique for quickly adapting a large pre-trained model to custom use cases. LoRA matrices are typically low-rank. What counts as “low” depends on the context, but for a large diffusion model like [Flux](https://huggingface.co/black-forest-labs/FLUX.1-dev), a rank of 128 can usually be considered high. This matters because users often need to keep multiple LoRAs unfused in memory so they can switch between them quickly. The higher the rank, the more memory each LoRA adds on top of the base model.
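To make the memory argument concrete, here is a small sketch that counts the bytes one LoRA pair adds to a single linear layer. It assumes a 3072 × 3072 projection (the hidden size used in the shape example in the random projections section) and bf16 storage at 2 bytes per parameter. `lora_pair_bytes` is a hypothetical helper, not part of any library.

```python
def lora_pair_bytes(d_in: int, d_out: int, rank: int, bytes_per_param: int = 2) -> int:
    """Bytes needed to store lora_A (rank x d_in) and lora_B (d_out x rank)."""
    num_params = rank * d_in + d_out * rank
    return num_params * bytes_per_param


# Hypothetical layer: a 3072 x 3072 projection stored in bf16 (2 bytes per parameter).
for rank in (128, 4):
    size = lora_pair_bytes(3072, 3072, rank)
    print(f"rank {rank:>3}: {size:,} bytes ({size / 2**20:.2f} MiB)")
```

The size grows linearly with the rank. Going from rank 128 to rank 4 therefore shrinks every LoRA pair by 32×, regardless of the layer shape.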

So, what if we could take an existing LoRA checkpoint with a high rank and reduce its rank even further to:

- Reduce the memory requirements
- Enable use cases like `torch.compile()`, which requires all LoRAs to have the same rank to avoid recompilation

This project explores two options for reducing the original LoRA checkpoint to an even smaller one:

* Random projections
* SVD

> [!TIP]
> We have also explored the opposite direction, i.e., taking a low-rank LoRA and increasing its rank with orthogonal completion. Check out [this section](#lora-rank-upsampling) for more details (code, results, etc.).

## Random projections

Basic idea:

1. Generate a random projection matrix: `R = torch.randn(new_rank, original_rank, dtype=torch.float32) / torch.sqrt(torch.tensor(new_rank, dtype=torch.float32))`.
2. Then compute the new LoRA up and down matrices:
    
    ```python
    # We keep R in torch.float32 for numerical stability.
    lora_A_new = (R @ lora_A.to(R.dtype)).to(lora_A.dtype)
    lora_B_new = (lora_B.to(R.dtype) @ R.T).to(lora_B.dtype)
    ```
    
    If `lora_A` and `lora_B` have shapes (42, 3072) and (3072, 42), respectively, then with a `new_rank` of 4, `lora_A_new` and `lora_B_new` will have shapes (4, 3072) and (3072, 4), respectively.
    

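The two steps above can be packaged into a self-contained function. This is a sketch: `reduce_rank_random_projection` is a hypothetical name, and the explicit seeding and device handling are our own assumptions. The actual `low_rank_lora.py` script may organize this differently. Scaling the Gaussian matrix by `1/sqrt(new_rank)` makes `E[R.T @ R]` the identity, so `lora_B_new @ lora_A_new = lora_B @ R.T @ R @ lora_A` matches `lora_B @ lora_A` in expectation.

```python
import torch


def reduce_rank_random_projection(lora_A, lora_B, new_rank, seed=0):
    """Project lora_A (r, d_in) and lora_B (d_out, r) down to rank `new_rank`."""
    original_rank = lora_A.shape[0]
    generator = torch.Generator().manual_seed(seed)  # CPU generator for reproducibility
    # Gaussian entries scaled by 1/sqrt(new_rank), so that E[R.T @ R] = I.
    R = torch.randn(new_rank, original_rank, generator=generator, dtype=torch.float32)
    R = R / torch.sqrt(torch.tensor(new_rank, dtype=torch.float32))
    R = R.to(lora_A.device)
    # Keep R in float32 for numerical stability and cast the results back.
    lora_A_new = (R @ lora_A.to(R.dtype)).to(lora_A.dtype)
    lora_B_new = (lora_B.to(R.dtype) @ R.T).to(lora_B.dtype)
    return lora_A_new, lora_B_new


lora_A = torch.randn(42, 3072, dtype=torch.bfloat16)
lora_B = torch.randn(3072, 42, dtype=torch.bfloat16)
lora_A_new, lora_B_new = reduce_rank_random_projection(lora_A, lora_B, new_rank=4)
print(lora_A_new.shape, lora_B_new.shape)  # torch.Size([4, 3072]) torch.Size([3072, 4])
```

For a full checkpoint, the same function would be applied to every `lora_A`/`lora_B` pair in the state dict, ideally with a different seed per layer.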
### Results

We tried this on the following LoRA: [https://huggingface.co/glif/how2draw](https://huggingface.co/glif/how2draw). Unless otherwise specified, a rank of 4 was used for all experiments. Here’s a side-by-side comparison of the original and the reduced LoRAs (using the same seed).

<details>
<summary>Inference code</summary>
  
```python
from diffusers import DiffusionPipeline 
import torch 

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
# Change accordingly.
lora_id = "How2Draw-V2_000002800_svd.safetensors"
pipe.load_lora_weights(lora_id)

prompts = [
    "Yorkshire Terrier with smile, How2Draw",
    "a dolphin, How2Draw",
    "an owl, How3Draw",
    "A silhouette of a girl performing a ballet pose, with elegant lines to suggest grace and movement. The background can include simple outlines of ballet shoes and a music note. The image should convey elegance and poise in a minimalistic style, How2Draw"
]
images = pipe(
    prompts, num_inference_steps=50, max_sequence_length=512, guidance_scale=3.5, generator=torch.manual_seed(0)
).images
```

</details>

<table style="border-collapse: collapse;">
  <tbody>
    <tr>
      <td align="center"><img src="https://huggingface.co/sayakpaul/lower-rank-flux-lora/resolve/main/images/collage_0.png" alt="Image 1"></td>
      <td align="center">Yorkshire Terrier with smile, How2Draw</td>
    </tr>
    <tr>
      <td align="center"><img src="https://huggingface.co/sayakpaul/lower-rank-flux-lora/resolve/main/images/collage_1.png" alt="Image 2"></td>
      <td align="center">a dolphin, How2Draw</td>
    </tr>
    <tr>
      <td align="center"><img src="https://huggingface.co/sayakpaul/lower-rank-flux-lora/resolve/main/images/collage_2.png" alt="Image 3"></td>
      <td align="center">an owl, How3Draw</td>
    </tr>
    <tr>
      <td align="center"><img src="https://huggingface.co/sayakpaul/lower-rank-flux-lora/resolve/main/images/collage_3.png" alt="Image 4"></td>
      <td align="center" style="padding: 0; margin: 0;">
        A silhouette of a girl performing a ballet pose, with elegant lines to suggest grace and movement.
        The background can include simple outlines of ballet shoes and a music note.
        The image should convey elegance and poise in a minimalistic style, How2Draw
      </td>
    </tr>
  </tbody>
</table>                                               


Code:  [`low_rank_lora.py`](https://huggingface.co/sayakpaul/lower-rank-flux-lora/blob/main/low_rank_lora.py)

### Notes

* One should experiment with the `new_rank` parameter to obtain the desired trade-off between performance and memory. With a `new_rank` of 4, we reduce the size of the LoRA from 451MB to 42MB. 
* There is a `use_sparse` option in the script above for using sparse random projection matrices.
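As a hedged illustration of what a sparse projection matrix can look like, here is the Achlioptas construction. It is shown as an assumption; the script's `use_sparse` option may build its matrix differently. Each entry is drawn from {+1, 0, -1} with probabilities 1/6, 2/3, 1/6 and then rescaled by sqrt(3), which gives every entry unit variance. The same `1/sqrt(new_rank)` factor as in the dense case therefore keeps `E[R.T @ R]` equal to the identity, while about two thirds of the entries are zero. `sparse_random_projection` is a hypothetical name.

```python
import torch


def sparse_random_projection(new_rank, original_rank, seed=0):
    """Hypothetical helper: Achlioptas-style sparse matrix of shape (new_rank, original_rank)."""
    generator = torch.Generator().manual_seed(seed)
    u = torch.rand(new_rank, original_rank, generator=generator)
    R = torch.zeros(new_rank, original_rank)
    R[u < 1 / 6] = 1.0    # +1 with probability 1/6
    R[u >= 5 / 6] = -1.0  # -1 with probability 1/6; 0 otherwise (probability 2/3)
    # sqrt(3) gives unit variance per entry; 1/sqrt(new_rank) matches the dense recipe.
    return R * (3.0 / new_rank) ** 0.5


R = sparse_random_projection(new_rank=4, original_rank=42)
print(R.shape)  # torch.Size([4, 42])
```

The resulting `R` can replace the dense Gaussian matrix in the `lora_A_new` / `lora_B_new` computation from step 2.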

## SVD

<details>
<summary>Results</summary>

![image.png](https://huggingface.co/sayakpaul/lower-rank-flux-lora/resolve/main/images/How2Draw-V2_000002800_svd_collage_0.png)

![image.png](https://huggingface.co/sayakpaul/lower-rank-flux-lora/resolve/main/images/How2Draw-V2_000002800_svd_collage_1.png)

![image.png](https://huggingface.co/sayakpaul/lower-rank-flux-lora/resolve/main/images/How2Draw-V2_000002800_svd_collage_2.png)

![image.png](https://huggingface.co/sayakpaul/lower-rank-flux-lora/resolve/main/images/How2Draw-V2_000002800_svd_collage_3.png)

</details>
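The SVD route can be sketched as follows. This is an assumption about the approach, not a copy of `svd_low_rank_lora.py`. The sketch forms the update `delta_W = lora_B @ lora_A` and keeps its top `new_rank` singular triplets, which by the Eckart–Young theorem is the best rank-`new_rank` approximation in the Frobenius norm. It then splits the singular values evenly between the two new factors. `reduce_rank_svd` is a hypothetical name.

```python
import torch


def reduce_rank_svd(lora_A, lora_B, new_rank):
    """Truncate lora_B @ lora_A to rank `new_rank` with a full SVD."""
    A = lora_A.to(torch.float32)  # (r, d_in)
    B = lora_B.to(torch.float32)  # (d_out, r)
    delta_W = B @ A  # (d_out, d_in): the update this LoRA adds to the base weight
    U, S, Vh = torch.linalg.svd(delta_W, full_matrices=False)
    U, S, Vh = U[:, :new_rank], S[:new_rank], Vh[:new_rank, :]
    # Split the singular values evenly: B_new = U sqrt(S), A_new = sqrt(S) Vh.
    sqrt_S = S.sqrt()
    lora_B_new = (U * sqrt_S).to(lora_B.dtype)  # (d_out, new_rank)
    lora_A_new = (sqrt_S[:, None] * Vh).to(lora_A.dtype)  # (new_rank, d_in)
    return lora_A_new, lora_B_new


torch.manual_seed(0)
lora_A = torch.randn(16, 256)
lora_B = torch.randn(256, 16)
lora_A_new, lora_B_new = reduce_rank_svd(lora_A, lora_B, new_rank=4)
print(lora_A_new.shape, lora_B_new.shape)  # torch.Size([4, 256]) torch.Size([256, 4])
```

Materializing `delta_W` costs `d_out * d_in` memory per layer. A cheaper equivalent is to QR-factorize `lora_B` and `lora_A.T`, then take the SVD of the small `r x r` core. This yields the same singular values.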

### Randomized SVD

Full SVD can be time-consuming. Truncated SVD is useful for very large sparse matrices. We can use randomized SVD instead, with little to no loss in quality but a significantly faster speed.
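A randomized variant of the same idea can be sketched with PyTorch's `torch.svd_lowrank`, which returns `U, S, V` with `delta_W ≈ U @ diag(S) @ V.T`. Its `q` argument sets the size of the random subspace, and `niter` sets the number of subspace (power) iterations. Oversampling `q` slightly beyond `new_rank` is common practice; the value of 4 used here is an assumption. Whether the accompanying script calls this exact function is also an assumption, and `reduce_rank_randomized_svd` is a hypothetical name.

```python
import torch


def reduce_rank_randomized_svd(lora_A, lora_B, new_rank, niter=2, oversample=4):
    """Approximate rank-`new_rank` truncation of lora_B @ lora_A via randomized SVD."""
    A = lora_A.to(torch.float32)
    B = lora_B.to(torch.float32)
    delta_W = B @ A  # (d_out, d_in)
    q = min(new_rank + oversample, *delta_W.shape)  # size of the random subspace
    U, S, V = torch.svd_lowrank(delta_W, q=q, niter=niter)  # delta_W ≈ U diag(S) V^T
    U, S, V = U[:, :new_rank], S[:new_rank], V[:, :new_rank]
    sqrt_S = S.sqrt()
    lora_B_new = (U * sqrt_S).to(lora_B.dtype)  # (d_out, new_rank)
    lora_A_new = (V * sqrt_S).T.to(lora_A.dtype)  # (new_rank, d_in)
    return lora_A_new, lora_B_new


torch.manual_seed(0)  # svd_lowrank draws random test matrices internally
lora_A = torch.randn(16, 256)
lora_B = torch.randn(256, 16)
lora_A_new, lora_B_new = reduce_rank_randomized_svd(lora_A, lora_B, new_rank=4)
print(lora_A_new.shape, lora_B_new.shape)  # torch.Size([4, 256]) torch.Size([256, 4])
```

Increasing `niter` trades speed for a closer match to the full-SVD result.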

<details>
<summary>Results</summary>

![image.png](https://huggingface.co/sayakpaul/lower-rank-flux-lora/resolve/main/images/How2Draw-V2_000002800_rand_svd_collage_0.png)

![image.png](https://huggingface.co/sayakpaul/lower-rank-flux-lora/resolve/main/images/How2Draw-V2_000002800_rand_svd_collage_1.png)

![image.png](https://huggingface.co/sayakpaul/lower-rank-flux-lora/resolve/main/images/How2Draw-V2_000002800_rand_svd_collage_2.png)

![image.png](https://huggingface.co/sayakpaul/lower-rank-flux-lora/resolve/main/images/How2Draw-V2_000002800_rand_svd_collage_3.png)
  
</details>

Code: [`svd_low_rank_lora.py`](https://huggingface.co/sayakpaul/lower-rank-flux-lora/blob/main/svd_low_rank_lora.py) 

### Tune the knobs in SVD

- `new_rank` as always
- `niter` when using randomized SVD

## Reduced checkpoints

* Randomized SVD: [How2Draw-V2_000002800_rand_svd.safetensors](./How2Draw-V2_000002800_rand_svd.safetensors)
* Full SVD: [How2Draw-V2_000002800_svd.safetensors](./How2Draw-V2_000002800_svd.safetensors)
* Random projections: [How2Draw-V2_000002800_reduced.safetensors](./How2Draw-V2_000002800_reduced.safetensors)

## LoRA rank upsampling

We also explored the opposite direction of what we presented above, using "orthogonal extension" across the rank dimension. Since this increases the rank, we thought "rank upsampling" was a cool name! Check out the [upsample_lora_rank.py](./upsample_lora_rank.py) script for the implementation.

We applied this technique to [`cocktailpeanut/optimus`](https://huggingface.co/cocktailpeanut/optimus) to increase the rank from 4 to 16. You can find the
checkpoint [here](https://huggingface.co/sayakpaul/flux-lora-resizing/blob/main/optimus_16.safetensors). 
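Here is one way to realize an orthogonal extension. It is sketched under our own assumptions and is not necessarily what `upsample_lora_rank.py` does. New rows are appended to `lora_A`; these rows are orthonormal to each other and orthogonal to the existing rows. Matching zero columns are appended to `lora_B`. The zero columns guarantee that `lora_B_new @ lora_A_new` equals `lora_B @ lora_A`, so the upsampled LoRA initially behaves like the original. `upsample_lora_rank_sketch` is a hypothetical name.

```python
import torch


def upsample_lora_rank_sketch(lora_A, lora_B, new_rank, seed=0):
    """Grow lora_A (r, d_in) / lora_B (d_out, r) to rank `new_rank` without changing B @ A."""
    r, d_in = lora_A.shape
    extra = new_rank - r
    if extra <= 0 or new_rank > d_in:
        raise ValueError("need r < new_rank <= d_in")
    A = lora_A.to(torch.float32)
    B = lora_B.to(torch.float32)
    # Orthonormal basis Q (d_in, r) of the row space of A (assumes A has full row rank).
    Q, _ = torch.linalg.qr(A.T)
    generator = torch.Generator().manual_seed(seed)
    G = torch.randn(d_in, extra, generator=generator).to(A.device)
    G = G - Q @ (Q.T @ G)  # remove components that lie in A's row space
    G, _ = torch.linalg.qr(G)  # orthonormalize the new directions
    A_new = torch.cat([A, G.T], dim=0)  # (new_rank, d_in)
    zeros = torch.zeros(B.shape[0], extra, device=B.device)
    B_new = torch.cat([B, zeros], dim=1)  # (d_out, new_rank)
    return A_new.to(lora_A.dtype), B_new.to(lora_B.dtype)


lora_A = torch.randn(4, 3072)
lora_B = torch.randn(3072, 4)
lora_A_up, lora_B_up = upsample_lora_rank_sketch(lora_A, lora_B, new_rank=16)
print(lora_A_up.shape, lora_B_up.shape)  # torch.Size([16, 3072]) torch.Size([3072, 16])
```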

### Results

Left: upsampled; right: original.

<table style="border-collapse: collapse;">
  <tbody>
    <tr>
      <td align="center"><img src="https://huggingface.co/sayakpaul/flux-lora-resizing/resolve/main/upsampled_lora/0_collage.png" alt="Image 1"></td>
      <td align="center">optimus is cleaning the house with broomstick</td>
    </tr>
    <tr>
      <td align="center"><img src="https://huggingface.co/sayakpaul/flux-lora-resizing/resolve/main/upsampled_lora/1_collage.png" alt="Image 2"></td>
      <td align="center">optimus is a DJ performing at a hip nightclub</td>
    </tr>
    <tr>
      <td align="center"><img src="https://huggingface.co/sayakpaul/flux-lora-resizing/resolve/main/upsampled_lora/2_collage.png" alt="Image 3"></td>
      <td align="center">optimus is competing in a bboy break dancing competition</td>
    </tr>
    <tr>
      <td align="center"><img src="https://huggingface.co/sayakpaul/flux-lora-resizing/resolve/main/upsampled_lora/3_collage.png" alt="Image 4"></td>
      <td align="center">optimus is playing tennis in a tennis court</td>
    </tr>
  </tbody>
</table> 

<details>
  <summary>Code</summary>

```python
from diffusers import FluxPipeline
import torch 

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
# Change this accordingly.
pipeline.load_lora_weights("optimus_16.safetensors")

prompts = [
    "optimus is cleaning the house with broomstick",
    "optimus is a DJ performing at a hip nightclub",
    "optimus is competing in a bboy break dancing competition",
    "optimus is playing tennis in a tennis court"
]
images = pipeline(
    prompts, 
    num_inference_steps=50,
    guidance_scale=3.5,
    max_sequence_length=512,
    generator=torch.manual_seed(0)
).images
for i, image in enumerate(images):
    image.save(f"{i}.png")
```

</details>