Diffusers
English
sayakpaul HF staff commited on
Commit
f5864a2
1 Parent(s): 01adc64

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +70 -1
README.md CHANGED
@@ -19,6 +19,9 @@ This project explores two options to reduce the original LoRA checkpoint into an
19
  * Random projections
20
  * SVD
21
 
 
 
 
22
  ## Random projections
23
 
24
  Basic idea:
@@ -140,4 +143,70 @@ Code: [`svd_low_rank_lora.py`](https://huggingface.co/sayakpaul/lower-rank-flux-
140
 
141
  * Randomized SVD: [How2Draw-V2_000002800_rand_svd.safetensors](./How2Draw-V2_000002800_rand_svd.safetensors)
142
  * Full SVD: [How2Draw-V2_000002800_svd.safetensors](./How2Draw-V2_000002800_svd.safetensors)
143
- * Random projections: [How2Draw-V2_000002800_reduced.safetensors](./How2Draw-V2_000002800_reduced.safetensors)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  * Random projections
20
  * SVD
21
 
22
+ > [!TIP]
23
+ > We have also explored the opposite direction of the above i.e., take a low-rank LoRA and increase its rank with orthoginal completion. Check out [this section](#lora-rank-upsampling) for more details (code, results, etc.).
24
+
25
  ## Random projections
26
 
27
  Basic idea:
 
143
 
144
  * Randomized SVD: [How2Draw-V2_000002800_rand_svd.safetensors](./How2Draw-V2_000002800_rand_svd.safetensors)
145
  * Full SVD: [How2Draw-V2_000002800_svd.safetensors](./How2Draw-V2_000002800_svd.safetensors)
146
+ * Random projections: [How2Draw-V2_000002800_reduced.safetensors](./How2Draw-V2_000002800_reduced.safetensors)
147
+
148
+ ## LoRA rank upsampling
149
+
150
+ We also explored the opposite direction of what we presented above. We do this by using "orthogonal extension" across
151
+ the rank dimensions. Since we are increasing the ranks, we thought "rank upsampling" was a cool name! Check out [upsample_lora_rank.py](./upsample_lora_rank.py) script for
152
+ the implementation.
153
+
154
+ We applied this technique to [`cocktailpeanut/optimus`](https://huggingface.co/cocktailpeanut/optimus) to increase the rank from 4 to 16. You can find the
155
+ checkpoint [here](https://huggingface.co/sayakpaul/flux-lora-resizing/blob/main/optimus_16.safetensors.
156
+
157
+ ### Results
158
+
159
+ Right: original Left: upsampled
160
+
161
+ <table style="border-collapse: collapse;">
162
+ <tbody>
163
+ <tr>
164
+ <td align="center"><img src="https://huggingface.co/sayakpaul/flux-lora-resizing/resolve/main/upsampled_lora/0_collage.png" alt="Image 1"></td>
165
+ <td align="center">optimus is cleaning the house with broomstick</td>
166
+ </tr>
167
+ <tr>
168
+ <td align="center"><img src="https://huggingface.co/sayakpaul/flux-lora-resizing/resolve/main/upsampled_lora/1_collage.png" alt="Image 2"></td>
169
+ <td align="center">optimus is a DJ performing at a hip nightclub</td>
170
+ </tr>
171
+ <tr>
172
+ <td align="center"><img src="https://huggingface.co/sayakpaul/flux-lora-resizing/resolve/main/upsampled_lora/2_collage.png" alt="Image 3"></td>
173
+ <td align="center">optimus is competing in a bboy break dancing competition</td>
174
+ </tr>
175
+ <tr>
176
+ <td align="center"><img src="https://huggingface.co/sayakpaul/flux-lora-resizing/resolve/main/upsampled_lora/3_collage.png" alt="Image 4"></td>
177
+ <td align="center">optimus is playing tennis in a tennis court</td>
178
+ </tr>
179
+ </tbody>
180
+ </table>
181
+
182
+ <details>
183
+ <summary>Code</summary>
184
+
185
+ ```python
186
+ from diffusers import FluxPipeline
187
+ import torch
188
+
189
+ pipeline = FluxPipeline.from_pretrained(
190
+ "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
191
+ ).to("cuda")
192
+ # Change this.
193
+ pipeline.load_lora_weights("optimus_16.safetensors")
194
+
195
+ prompts = [
196
+ "optimus is cleaning the house with broomstick",
197
+ "optimus is a DJ performing at a hip nightclub",
198
+ "optimus is competing in a bboy break dancing competition",
199
+ "optimus is playing tennis in a tennis court"
200
+ ]
201
+ images = pipeline(
202
+ prompts,
203
+ num_inference_steps=50,
204
+ guidance_scale=3.5,
205
+ max_sequence_length=512,
206
+ generator=torch.manual_seed(0)
207
+ ).images
208
+ for i, image in enumerate(images):
209
+ image.save(f"{i}_{'upsampled' if upsample else 'non_upsampled'}.png")
210
+ ```
211
+
212
+ </details>