Spaces:
Running
on
A10G
Running
on
A10G
Commit
•
db551d5
1
Parent(s):
178e606
Performance PR (#2)
Browse files- Performance PR (f33c43f609f59f8722b5928f0535007a9157da38)
- Disable SC (e6d1b5454f215a7280081510188907d11646de37)
Co-authored-by: Apolinário from multimodal AI art <multimodalart@users.noreply.huggingface.co>
- app.py +46 -22
- patch_sdxl.py +4 -30
app.py
CHANGED
@@ -6,7 +6,7 @@ from sklearn.svm import LinearSVC
|
|
6 |
from sklearn import preprocessing
|
7 |
import pandas as pd
|
8 |
|
9 |
-
from diffusers import LCMScheduler
|
10 |
from diffusers.models import ImageProjection
|
11 |
from patch_sdxl import SDEmb
|
12 |
import torch
|
@@ -22,6 +22,9 @@ from PIL import Image
|
|
22 |
import requests
|
23 |
from io import BytesIO, StringIO
|
24 |
|
|
|
|
|
|
|
25 |
prompt_list = [p for p in list(set(
|
26 |
pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
|
27 |
|
@@ -29,12 +32,16 @@ start_time = time.time()
|
|
29 |
|
30 |
####################### Setup Model
|
31 |
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
pipe.
|
|
|
|
|
|
|
37 |
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
|
|
|
38 |
output_hidden_state = False
|
39 |
#######################
|
40 |
|
@@ -53,7 +60,7 @@ def predict(
|
|
53 |
ip_adapter_emb=im_emb.to('cuda'),
|
54 |
height=1024,
|
55 |
width=1024,
|
56 |
-
num_inference_steps=
|
57 |
guidance_scale=0,
|
58 |
).images[0]
|
59 |
im_emb, _ = pipe.encode_image(
|
@@ -61,12 +68,6 @@ def predict(
|
|
61 |
)
|
62 |
return image, im_emb.to(DEVICE)
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
# TODO add to state instead of shared across all
|
71 |
glob_idx = 0
|
72 |
|
@@ -133,9 +134,9 @@ def next_image(embs, ys, calibrate_prompts):
|
|
133 |
def start(_, embs, ys, calibrate_prompts):
|
134 |
image, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
|
135 |
return [
|
136 |
-
gr.Button(value='Like', interactive=True),
|
137 |
-
gr.Button(value='Neither', interactive=True),
|
138 |
-
gr.Button(value='Dislike', interactive=True),
|
139 |
gr.Button(value='Start', interactive=False),
|
140 |
image,
|
141 |
embs,
|
@@ -157,9 +158,32 @@ def choose(choice, embs, ys, calibrate_prompts):
|
|
157 |
img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
|
158 |
return img, embs, ys, calibrate_prompts
|
159 |
|
160 |
-
css =
|
161 |
-
|
162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
embs = gr.State([])
|
164 |
ys = gr.State([])
|
165 |
calibrate_prompts = gr.State([
|
@@ -177,9 +201,9 @@ with gr.Blocks(css=css) as demo:
|
|
177 |
with gr.Row(elem_id='output-image'):
|
178 |
img = gr.Image(interactive=False, elem_id='output-image',width=700)
|
179 |
with gr.Row(equal_height=True):
|
180 |
-
b3 = gr.Button(value='Dislike', interactive=False,)
|
181 |
-
b2 = gr.Button(value='Neither', interactive=False,)
|
182 |
-
b1 = gr.Button(value='Like', interactive=False,)
|
183 |
b1.click(
|
184 |
choose,
|
185 |
[b1, embs, ys, calibrate_prompts],
|
|
|
6 |
from sklearn import preprocessing
|
7 |
import pandas as pd
|
8 |
|
9 |
+
from diffusers import LCMScheduler, AutoencoderTiny, EulerDiscreteScheduler, UNet2DConditionModel
|
10 |
from diffusers.models import ImageProjection
|
11 |
from patch_sdxl import SDEmb
|
12 |
import torch
|
|
|
22 |
import requests
|
23 |
from io import BytesIO, StringIO
|
24 |
|
25 |
+
from huggingface_hub import hf_hub_download
|
26 |
+
from safetensors.torch import load_file
|
27 |
+
|
28 |
prompt_list = [p for p in list(set(
|
29 |
pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
|
30 |
|
|
|
32 |
|
33 |
####################### Setup Model
|
34 |
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
35 |
+
sdxl_lightening = "ByteDance/SDXL-Lightning"
|
36 |
+
ckpt = "sdxl_lightning_2step_unet.safetensors"
|
37 |
+
unet = UNet2DConditionModel.from_config(model_id, subfolder="unet").to("cuda", torch.float16)
|
38 |
+
unet.load_state_dict(load_file(hf_hub_download(sdxl_lightening, ckpt), device="cuda"))
|
39 |
+
pipe = SDEmb.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
|
40 |
+
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16)
|
41 |
+
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
|
42 |
+
pipe.to(device='cuda')
|
43 |
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
|
44 |
+
|
45 |
output_hidden_state = False
|
46 |
#######################
|
47 |
|
|
|
60 |
ip_adapter_emb=im_emb.to('cuda'),
|
61 |
height=1024,
|
62 |
width=1024,
|
63 |
+
num_inference_steps=2,
|
64 |
guidance_scale=0,
|
65 |
).images[0]
|
66 |
im_emb, _ = pipe.encode_image(
|
|
|
68 |
)
|
69 |
return image, im_emb.to(DEVICE)
|
70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
# TODO add to state instead of shared across all
|
72 |
glob_idx = 0
|
73 |
|
|
|
134 |
def start(_, embs, ys, calibrate_prompts):
|
135 |
image, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
|
136 |
return [
|
137 |
+
gr.Button(value='Like (L)', interactive=True),
|
138 |
+
gr.Button(value='Neither (Space)', interactive=True),
|
139 |
+
gr.Button(value='Dislike (A)', interactive=True),
|
140 |
gr.Button(value='Start', interactive=False),
|
141 |
image,
|
142 |
embs,
|
|
|
158 |
img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
|
159 |
return img, embs, ys, calibrate_prompts
|
160 |
|
161 |
+
css = '''.gradio-container{max-width: 700px !important}
|
162 |
+
#description{text-align: center}
|
163 |
+
#description h1{display: block}
|
164 |
+
#description p{margin-top: 0}
|
165 |
+
'''
|
166 |
+
js = '''
|
167 |
+
<script>
|
168 |
+
document.addEventListener('keydown', function(event) {
|
169 |
+
if (event.key === 'a' || event.key === 'A') {
|
170 |
+
// Trigger click on 'dislike' if 'A' is pressed
|
171 |
+
document.getElementById('dislike').click();
|
172 |
+
} else if (event.key === ' ' || event.keyCode === 32) {
|
173 |
+
// Trigger click on 'neither' if Spacebar is pressed
|
174 |
+
document.getElementById('neither').click();
|
175 |
+
} else if (event.key === 'l' || event.key === 'L') {
|
176 |
+
// Trigger click on 'like' if 'L' is pressed
|
177 |
+
document.getElementById('like').click();
|
178 |
+
}
|
179 |
+
});
|
180 |
+
</script>
|
181 |
+
'''
|
182 |
+
|
183 |
+
with gr.Blocks(css=css, head=js) as demo:
|
184 |
+
gr.Markdown('''# Generative Recommenders
|
185 |
+
Explore the latent space without text prompts, based on your preferences. [Learn more on the blog](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/)
|
186 |
+
''', elem_id="description")
|
187 |
embs = gr.State([])
|
188 |
ys = gr.State([])
|
189 |
calibrate_prompts = gr.State([
|
|
|
201 |
with gr.Row(elem_id='output-image'):
|
202 |
img = gr.Image(interactive=False, elem_id='output-image',width=700)
|
203 |
with gr.Row(equal_height=True):
|
204 |
+
b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
|
205 |
+
b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
|
206 |
+
b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")
|
207 |
b1.click(
|
208 |
choose,
|
209 |
[b1, embs, ys, calibrate_prompts],
|
patch_sdxl.py
CHANGED
@@ -1,6 +1,3 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
import inspect
|
5 |
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
6 |
|
@@ -29,7 +26,6 @@ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOut
|
|
29 |
|
30 |
|
31 |
|
32 |
-
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
33 |
from transformers import CLIPFeatureExtractor
|
34 |
import numpy as np
|
35 |
import torch
|
@@ -40,27 +36,6 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
40 |
torch_device = device
|
41 |
torch_dtype = torch.float16
|
42 |
|
43 |
-
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
44 |
-
"CompVis/stable-diffusion-safety-checker"
|
45 |
-
).to(device)
|
46 |
-
feature_extractor = CLIPFeatureExtractor.from_pretrained(
|
47 |
-
"openai/clip-vit-base-patch32"
|
48 |
-
)
|
49 |
-
|
50 |
-
def check_nsfw_images(
|
51 |
-
images: list[Image.Image],
|
52 |
-
) -> list[bool]:
|
53 |
-
safety_checker_input = feature_extractor(images, return_tensors="pt").to(device)
|
54 |
-
images_np = [np.array(img) for img in images]
|
55 |
-
|
56 |
-
_, has_nsfw_concepts = safety_checker(
|
57 |
-
images=images_np,
|
58 |
-
clip_input=safety_checker_input.pixel_values.to(torch_device),
|
59 |
-
)
|
60 |
-
return has_nsfw_concepts
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
|
65 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
66 |
|
@@ -569,12 +544,11 @@ class SDEmb(StableDiffusionXLPipeline):
|
|
569 |
# apply watermark if available
|
570 |
if self.watermark is not None:
|
571 |
image = self.watermark.apply_watermark(image)
|
572 |
-
|
573 |
image = self.image_processor.postprocess(image, output_type=output_type)
|
574 |
-
maybe_nsfw = any(check_nsfw_images(image))
|
575 |
-
if maybe_nsfw:
|
576 |
-
|
577 |
-
|
578 |
|
579 |
# Offload all models
|
580 |
self.maybe_free_model_hooks()
|
|
|
|
|
|
|
|
|
1 |
import inspect
|
2 |
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
3 |
|
|
|
26 |
|
27 |
|
28 |
|
|
|
29 |
from transformers import CLIPFeatureExtractor
|
30 |
import numpy as np
|
31 |
import torch
|
|
|
36 |
torch_device = device
|
37 |
torch_dtype = torch.float16
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
41 |
|
|
|
544 |
# apply watermark if available
|
545 |
if self.watermark is not None:
|
546 |
image = self.watermark.apply_watermark(image)
|
|
|
547 |
image = self.image_processor.postprocess(image, output_type=output_type)
|
548 |
+
#maybe_nsfw = any(check_nsfw_images(image))
|
549 |
+
#if maybe_nsfw:
|
550 |
+
# print('This image could be NSFW so we return a blank image.')
|
551 |
+
# return StableDiffusionXLPipelineOutput(images=[Image.new('RGB', (1024, 1024))])
|
552 |
|
553 |
# Offload all models
|
554 |
self.maybe_free_model_hooks()
|