Spaces:
Running
on
A10G
Running
on
A10G
Commit
•
f33c43f
1
Parent(s):
178e606
Performance PR
Browse files- Swap `LCM LoRA` to `SDXL Lightening 2 steps` (faster, more quality)
- Switch regular VAE to tiny TAESD VAE
- Add header and mention to blog
- Add keyboard navigation (`A` to Dislike, `Space` for Neither and `L` to like)
- Disable Safety Filter (redundant in SDXL for this use-case and lots of false positives)
Performance result on A10G:
- < 1s per image
app.py
CHANGED
@@ -6,7 +6,7 @@ from sklearn.svm import LinearSVC
|
|
6 |
from sklearn import preprocessing
|
7 |
import pandas as pd
|
8 |
|
9 |
-
from diffusers import LCMScheduler
|
10 |
from diffusers.models import ImageProjection
|
11 |
from patch_sdxl import SDEmb
|
12 |
import torch
|
@@ -22,6 +22,9 @@ from PIL import Image
|
|
22 |
import requests
|
23 |
from io import BytesIO, StringIO
|
24 |
|
|
|
|
|
|
|
25 |
prompt_list = [p for p in list(set(
|
26 |
pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
|
27 |
|
@@ -29,12 +32,16 @@ start_time = time.time()
|
|
29 |
|
30 |
####################### Setup Model
|
31 |
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
pipe.
|
|
|
|
|
|
|
37 |
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
|
|
|
38 |
output_hidden_state = False
|
39 |
#######################
|
40 |
|
@@ -53,7 +60,7 @@ def predict(
|
|
53 |
ip_adapter_emb=im_emb.to('cuda'),
|
54 |
height=1024,
|
55 |
width=1024,
|
56 |
-
num_inference_steps=
|
57 |
guidance_scale=0,
|
58 |
).images[0]
|
59 |
im_emb, _ = pipe.encode_image(
|
@@ -61,12 +68,6 @@ def predict(
|
|
61 |
)
|
62 |
return image, im_emb.to(DEVICE)
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
# TODO add to state instead of shared across all
|
71 |
glob_idx = 0
|
72 |
|
@@ -133,9 +134,9 @@ def next_image(embs, ys, calibrate_prompts):
|
|
133 |
def start(_, embs, ys, calibrate_prompts):
|
134 |
image, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
|
135 |
return [
|
136 |
-
gr.Button(value='Like', interactive=True),
|
137 |
-
gr.Button(value='Neither', interactive=True),
|
138 |
-
gr.Button(value='Dislike', interactive=True),
|
139 |
gr.Button(value='Start', interactive=False),
|
140 |
image,
|
141 |
embs,
|
@@ -157,9 +158,32 @@ def choose(choice, embs, ys, calibrate_prompts):
|
|
157 |
img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
|
158 |
return img, embs, ys, calibrate_prompts
|
159 |
|
160 |
-
css =
|
161 |
-
|
162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
embs = gr.State([])
|
164 |
ys = gr.State([])
|
165 |
calibrate_prompts = gr.State([
|
@@ -177,9 +201,9 @@ with gr.Blocks(css=css) as demo:
|
|
177 |
with gr.Row(elem_id='output-image'):
|
178 |
img = gr.Image(interactive=False, elem_id='output-image',width=700)
|
179 |
with gr.Row(equal_height=True):
|
180 |
-
b3 = gr.Button(value='Dislike', interactive=False,)
|
181 |
-
b2 = gr.Button(value='Neither', interactive=False,)
|
182 |
-
b1 = gr.Button(value='Like', interactive=False,)
|
183 |
b1.click(
|
184 |
choose,
|
185 |
[b1, embs, ys, calibrate_prompts],
|
|
|
6 |
from sklearn import preprocessing
|
7 |
import pandas as pd
|
8 |
|
9 |
+
from diffusers import LCMScheduler, AutoencoderTiny, EulerDiscreteScheduler, UNet2DConditionModel
|
10 |
from diffusers.models import ImageProjection
|
11 |
from patch_sdxl import SDEmb
|
12 |
import torch
|
|
|
22 |
import requests
|
23 |
from io import BytesIO, StringIO
|
24 |
|
25 |
+
from huggingface_hub import hf_hub_download
|
26 |
+
from safetensors.torch import load_file
|
27 |
+
|
28 |
prompt_list = [p for p in list(set(
|
29 |
pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
|
30 |
|
|
|
32 |
|
33 |
####################### Setup Model
|
34 |
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
35 |
+
sdxl_lightening = "ByteDance/SDXL-Lightning"
|
36 |
+
ckpt = "sdxl_lightning_2step_unet.safetensors"
|
37 |
+
unet = UNet2DConditionModel.from_config(model_id, subfolder="unet").to("cuda", torch.float16)
|
38 |
+
unet.load_state_dict(load_file(hf_hub_download(sdxl_lightening, ckpt), device="cuda"))
|
39 |
+
pipe = SDEmb.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
|
40 |
+
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16)
|
41 |
+
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
|
42 |
+
pipe.to(device='cuda')
|
43 |
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
|
44 |
+
|
45 |
output_hidden_state = False
|
46 |
#######################
|
47 |
|
|
|
60 |
ip_adapter_emb=im_emb.to('cuda'),
|
61 |
height=1024,
|
62 |
width=1024,
|
63 |
+
num_inference_steps=2,
|
64 |
guidance_scale=0,
|
65 |
).images[0]
|
66 |
im_emb, _ = pipe.encode_image(
|
|
|
68 |
)
|
69 |
return image, im_emb.to(DEVICE)
|
70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
# TODO add to state instead of shared across all
|
72 |
glob_idx = 0
|
73 |
|
|
|
134 |
def start(_, embs, ys, calibrate_prompts):
|
135 |
image, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
|
136 |
return [
|
137 |
+
gr.Button(value='Like (L)', interactive=True),
|
138 |
+
gr.Button(value='Neither (Space)', interactive=True),
|
139 |
+
gr.Button(value='Dislike (A)', interactive=True),
|
140 |
gr.Button(value='Start', interactive=False),
|
141 |
image,
|
142 |
embs,
|
|
|
158 |
img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
|
159 |
return img, embs, ys, calibrate_prompts
|
160 |
|
161 |
+
css = '''.gradio-container{max-width: 700px !important}
|
162 |
+
#description{text-align: center}
|
163 |
+
#description h1{display: block}
|
164 |
+
#description p{margin-top: 0}
|
165 |
+
'''
|
166 |
+
js = '''
|
167 |
+
<script>
|
168 |
+
document.addEventListener('keydown', function(event) {
|
169 |
+
if (event.key === 'a' || event.key === 'A') {
|
170 |
+
// Trigger click on 'dislike' if 'A' is pressed
|
171 |
+
document.getElementById('dislike').click();
|
172 |
+
} else if (event.key === ' ' || event.keyCode === 32) {
|
173 |
+
// Trigger click on 'neither' if Spacebar is pressed
|
174 |
+
document.getElementById('neither').click();
|
175 |
+
} else if (event.key === 'l' || event.key === 'L') {
|
176 |
+
// Trigger click on 'like' if 'L' is pressed
|
177 |
+
document.getElementById('like').click();
|
178 |
+
}
|
179 |
+
});
|
180 |
+
</script>
|
181 |
+
'''
|
182 |
+
|
183 |
+
with gr.Blocks(css=css, head=js) as demo:
|
184 |
+
gr.Markdown('''# Generative Recommenders
|
185 |
+
Explore the latent space without text prompts, based on your preferences. [Learn more on the blog](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/)
|
186 |
+
''', elem_id="description")
|
187 |
embs = gr.State([])
|
188 |
ys = gr.State([])
|
189 |
calibrate_prompts = gr.State([
|
|
|
201 |
with gr.Row(elem_id='output-image'):
|
202 |
img = gr.Image(interactive=False, elem_id='output-image',width=700)
|
203 |
with gr.Row(equal_height=True):
|
204 |
+
b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
|
205 |
+
b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
|
206 |
+
b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")
|
207 |
b1.click(
|
208 |
choose,
|
209 |
[b1, embs, ys, calibrate_prompts],
|