Spaces: adamelliotfields/diffusion (Running on Zero)

Commit 7a7cda5 (parent: 98afd85), committed by adamelliotfields

Move ControlNet to Image tab

Files changed:
- README.md +3 -12
- app.py +95 -94
- lib/__init__.py +2 -2
- lib/annotators.py +3 -1
- lib/config.py +10 -4
- lib/inference.py +14 -33
- lib/loader.py +1 -0
- lib/utils.py +65 -51
README.md
CHANGED
@@ -60,25 +60,16 @@ preload_from_hub: # up to 10
 # diffusion
 
 Gradio app for Stable Diffusion 1.5 featuring:
-* txt2img and img2img pipelines with IP-Adapter
+* txt2img and img2img pipelines with ControlNet and IP-Adapter
+* Canny edge detection (more preprocessors coming soon)
 * Curated models, LoRAs, and TI embeddings
-* ControlNet with annotators
 * Compel prompt weighting
-* …
+* Hand-written style templates
 * Multiple samplers with Karras scheduling
 * DeepCache, FreeU, and Clip Skip available
 * Real-ESRGAN upscaling
 * Optional tiny autoencoder
 
-## Motivation
-
-* host a free and easy-to-use Stable Diffusion UI on ZeroGPU
-* provide the necessary tools for common workflows
-* curate useful models, adapters, and embeddings
-* prefer Diffusers over custom PyTorch
-* be fast on 8GB with no offloading
-* only support CUDA on Linux/WSL
-
 ## Usage
 
 See [`DOCS.md`](https://huggingface.co/spaces/adamelliotfields/diffusion/blob/main/DOCS.md).
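A note on the "Compel prompt weighting" bullet above: Compel parses weighting syntax out of the prompt and hands the pipeline precomputed embeddings. A minimal sketch, assuming a diffusers-format checkpoint from this Space's model list; the prompt and weights are made up for illustration:

```python
# Minimal Compel sketch; the model ID is from this Space's HF_MODELS list,
# the prompt and weights are illustrative only.
from compel import Compel
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("Lykon/dreamshaper-8")
compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)

# "word+" upweights, "word-" downweights, "(phrase)1.3" sets an explicit weight
embeds = compel("a (cinematic)1.3 photo of a cat+ in the rain--")
image = pipe(prompt_embeds=embeds, num_inference_steps=20).images[0]
```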
app.py
CHANGED
@@ -6,16 +6,13 @@ import random
 import gradio as gr
 
 from lib import (
-    CannyAnnotator,
     Config,
     async_call,
     disable_progress_bars,
     download_civit_file,
     download_repo_files,
     generate,
-    get_valid_size,
     read_file,
-    resize_image,
 )
 
 # the CSS `content` attribute expects a string so we need to wrap the number in quotes
@@ -45,7 +42,7 @@ aspect_ratio_js = """
 """
 
 
-def create_image_dropdown(images, locked=False):
+def image_prompt_fn(images, locked=False):
     if locked:
         return gr.Dropdown(
             choices=[("🔒", -2)],
@@ -60,19 +57,17 @@ def create_image_dropdown(images, locked=False):
         )
 
 
-async def gallery_fn(images, image, ip_image):
+async def gallery_fn(images, image, control_image, ip_image):
     return (
-        create_image_dropdown(images, locked=image is not None),
-        create_image_dropdown(images, locked=ip_image is not None),
+        image_prompt_fn(images, locked=image is not None),
+        image_prompt_fn(images, locked=control_image is not None),
+        image_prompt_fn(images, locked=ip_image is not None),
     )
 
 
-
-
-
-
-# handle selecting an image from the gallery
-# -2 is the lock icon, -1 is None
+# Handle selecting an image from the gallery:
+# * -2 is the lock icon
+# * -1 is None
 async def image_select_fn(images, image, i):
     if i == -2:
         return gr.Image(image)
@@ -87,15 +82,6 @@ async def random_fn():
     return gr.Textbox(value=random.choice(prompts))
 
 
-# TODO: move this to another file once more annotators are added; will need @GPU decorator
-async def annotate_fn(image, annotator):
-    size = get_valid_size(image)
-    image = resize_image(image, size)
-    if annotator == "canny":
-        canny = CannyAnnotator()
-        return canny(image, size)
-
-
 async def generate_fn(*args, progress=gr.Progress(track_tqdm=True)):
     if len(args) > 0:
         prompt = args[0]
@@ -105,17 +91,22 @@ async def generate_fn(*args, progress=gr.Progress(track_tqdm=True)):
         raise gr.Error("You must enter a prompt")
 
     # always the last arguments
-    DISABLE_IMAGE_PROMPT, DISABLE_IP_IMAGE_PROMPT = args[-2:]
-    gen_args = list(args[:-2])
+    DISABLE_IMAGE_PROMPT, DISABLE_CONTROL_IMAGE_PROMPT, DISABLE_IP_IMAGE_PROMPT = args[-3:]
+    gen_args = list(args[:-3])
+
+    # the first two arguments are the prompt and negative prompt
    if DISABLE_IMAGE_PROMPT:
        gen_args[2] = None
-    if DISABLE_IP_IMAGE_PROMPT:
+    if DISABLE_CONTROL_IMAGE_PROMPT:
        gen_args[3] = None
+    if DISABLE_IP_IMAGE_PROMPT:
+        gen_args[4] = None
 
     try:
         if Config.ZERO_GPU:
             progress((0, 100), desc="ZeroGPU init")
 
+        # the remaining arguments are the alert handlers and progress bar
         images = await async_call(
             generate,
             *gen_args,
@@ -125,6 +116,7 @@ async def generate_fn(*args, progress=gr.Progress(track_tqdm=True)):
         )
     except RuntimeError:
         raise gr.Error("Error: Please try again")
+
     return images
 
 
@@ -155,6 +147,7 @@ with gr.Blocks(
     # override image inputs without clearing them
    DISABLE_IMAGE_PROMPT = gr.State(False)
    DISABLE_IP_IMAGE_PROMPT = gr.State(False)
+    DISABLE_CONTROL_IMAGE_PROMPT = gr.State(False)
 
     gr.HTML(read_file("./partials/intro.html"))
 
@@ -212,6 +205,14 @@ with gr.Blocks(
                 image_prompt = gr.Image(
                     show_share_button=False,
                     label="Initial Image",
+                    min_width=640,
+                    format="png",
+                    type="pil",
+                )
+            with gr.Row():
+                control_image_prompt = gr.Image(
+                    show_share_button=False,
+                    label="Control Image",
                     min_width=320,
                     format="png",
                     type="pil",
@@ -226,100 +227,84 @@ with gr.Blocks(
 
             with gr.Row():
                 image_select = gr.Dropdown(
-                    info="Use …
+                    info="Use a gallery image for initial latents",
                     choices=[("None", -1)],
-                    label="…
+                    label="Initial Image",
                     interactive=True,
                     filterable=False,
+                    min_width=100,
+                    value=-1,
+                )
+                control_image_select = gr.Dropdown(
+                    info="Use a gallery image for ControlNet",
+                    label="ControlNet Image",
+                    choices=[("None", -1)],
+                    interactive=True,
+                    filterable=False,
+                    min_width=100,
                     value=-1,
                 )
                 ip_image_select = gr.Dropdown(
-                    info="Use …
-                    label="…
+                    info="Use a gallery image for IP-Adapter",
+                    label="IP-Adapter Image",
                     choices=[("None", -1)],
                     interactive=True,
                     filterable=False,
+                    min_width=100,
                     value=-1,
                 )
 
             with gr.Row():
                 denoising_strength = gr.Slider(
+                    label="Initial Image Strength",
                     value=Config.DENOISING_STRENGTH,
-                    label="Denoising Strength",
                     minimum=0.0,
                     maximum=1.0,
                     step=0.1,
                 )
+                control_annotator = gr.Dropdown(
+                    label="ControlNet Annotator",
+                    # TODO: annotators should be in config with names
+                    choices=[("Canny", "canny")],
+                    value=Config.ANNOTATOR,
+                    filterable=False,
+                )
 
             with gr.Row():
                 disable_image = gr.Checkbox(
-                    elem_classes=["checkbox"],
                     label="Disable Initial Image",
+                    elem_classes=["checkbox"],
                     value=False,
                 )
-                disable_ip_image = gr.Checkbox(
+                disable_control_image = gr.Checkbox(
+                    label="Disable ControlNet Image",
                     elem_classes=["checkbox"],
+                    value=False,
+                )
+                disable_ip_image = gr.Checkbox(
                     label="Disable IP-Adapter Image",
+                    elem_classes=["checkbox"],
                     value=False,
                 )
                 use_ip_face = gr.Checkbox(
-                    elem_classes=["checkbox"],
                     label="Use IP-Adapter Face",
+                    elem_classes=["checkbox"],
                     value=False,
                 )
 
-            # controlnet tab
-            with gr.TabItem("🎮 Control"):
-                with gr.Row():
-                    control_image_input = gr.Image(
-                        show_share_button=False,
-                        label="Control Image",
-                        min_width=320,
-                        format="png",
-                        type="pil",
-                    )
-                    control_image_prompt = gr.Image(
-                        interactive=False,
-                        show_share_button=False,
-                        label="Control Image Output",
-                        show_label=False,
-                        min_width=320,
-                        format="png",
-                        type="pil",
-                    )
-
-                with gr.Row():
-                    control_annotator = gr.Dropdown(
-                        choices=[("Canny", "canny")],
-                        label="Annotator",
-                        filterable=False,
-                        value="canny",
-                    )
-
-                with gr.Row():
-                    annotate_btn = gr.Button("Annotate", variant="primary")
-                    clear_control_btn = gr.ClearButton(
-                        elem_classes=["icon-button", "popover"],
-                        components=[control_image_prompt],
-                        variant="secondary",
-                        elem_id="clear-control",
-                        min_width=0,
-                        value="🗑️",
-                    )
-
             with gr.TabItem("⚙️ Menu"):
                 with gr.Group():
                     negative_prompt = gr.Textbox(
-                        value="nsfw+",
                         label="Negative Prompt",
+                        value="nsfw+",
                         lines=2,
                     )
 
                     with gr.Row():
                         model = gr.Dropdown(
                             choices=Config.MODELS,
-                            filterable=False,
                             value=Config.MODEL,
+                            filterable=False,
                             label="Model",
                             min_width=240,
                         )
@@ -489,25 +474,12 @@ with gr.Blocks(
                     value=False,
                 )
 
-    annotate_btn.click(
-        annotate_fn,
-        inputs=[control_image_input, control_annotator],
-        outputs=[control_image_prompt],
-    )
-
     random_btn.click(random_fn, inputs=[], outputs=[prompt], show_api=False)
 
     refresh_btn.click(None, inputs=[], outputs=[seed], js=refresh_seed_js)
 
     seed.change(None, inputs=[seed], outputs=[], js=seed_js)
 
-    file_format.change(
-        lambda f: (gr.Gallery(format=f), gr.Image(format=f), gr.Image(format=f)),
-        inputs=[file_format],
-        outputs=[output_images, image_prompt, ip_image_prompt],
-        show_api=False,
-    )
-
     # input events are only user input; change events are both user and programmatic
     aspect_ratio.input(
         None,
@@ -516,11 +488,23 @@ with gr.Blocks(
         js=aspect_ratio_js,
     )
 
+    file_format.change(
+        lambda f: (
+            gr.Gallery(format=f),
+            gr.Image(format=f),
+            gr.Image(format=f),
+            gr.Image(format=f),
+        ),
+        inputs=[file_format],
+        outputs=[output_images, image_prompt, control_image_prompt, ip_image_prompt],
+        show_api=False,
+    )
+
     # lock the input images so you don't lose them when the gallery updates
     output_images.change(
         gallery_fn,
-        inputs=[output_images, image_prompt, ip_image_prompt],
-        outputs=[image_select, ip_image_select],
+        inputs=[output_images, image_prompt, control_image_prompt, ip_image_prompt],
+        outputs=[image_select, control_image_select, ip_image_select],
         show_api=False,
     )
 
@@ -531,6 +515,12 @@ with gr.Blocks(
         outputs=[image_prompt],
         show_api=False,
     )
+    control_image_select.change(
+        image_select_fn,
+        inputs=[output_images, control_image_prompt, control_image_select],
+        outputs=[control_image_prompt],
+        show_api=False,
+    )
     ip_image_select.change(
         image_select_fn,
         inputs=[output_images, ip_image_prompt, ip_image_select],
@@ -545,6 +535,12 @@ with gr.Blocks(
         outputs=[image_select],
         show_api=False,
     )
+    control_image_prompt.clear(
+        image_prompt_fn,
+        inputs=[output_images],
+        outputs=[control_image_select],
+        show_api=False,
+    )
     ip_image_prompt.clear(
         image_prompt_fn,
         inputs=[output_images],
@@ -563,10 +559,14 @@ with gr.Blocks(
 
     # toggle image prompts by updating session state
     gr.on(
-        triggers=[disable_image.input, disable_ip_image.input],
-        fn=lambda disable_image, disable_ip_image: (…),
-        inputs=[disable_image, disable_ip_image],
-        outputs=[DISABLE_IMAGE_PROMPT, DISABLE_IP_IMAGE_PROMPT],
+        triggers=[disable_image.input, disable_control_image.input, disable_ip_image.input],
+        fn=lambda disable_image, disable_control_image, disable_ip_image: (
+            disable_image,
+            disable_control_image,
+            disable_ip_image,
+        ),
+        inputs=[disable_image, disable_control_image, disable_ip_image],
+        outputs=[DISABLE_IMAGE_PROMPT, DISABLE_CONTROL_IMAGE_PROMPT, DISABLE_IP_IMAGE_PROMPT],
    )
 
     # generate images
@@ -579,8 +579,8 @@ with gr.Blocks(
             prompt,
             negative_prompt,
             image_prompt,
-            ip_image_prompt,
             control_image_prompt,
+            ip_image_prompt,
             lora_1,
             lora_1_weight,
             lora_2,
@@ -605,6 +605,7 @@ with gr.Blocks(
             use_clip_skip,
             use_ip_face,
             DISABLE_IMAGE_PROMPT,
+            DISABLE_CONTROL_IMAGE_PROMPT,
             DISABLE_IP_IMAGE_PROMPT,
         ],
     )
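A note on the positional-argument contract in generate_fn: the Gradio inputs list begins with prompt, negative_prompt, image_prompt, control_image_prompt, and ip_image_prompt, and ends with the three DISABLE_* session states, which is why the function slices args[-3:] and nulls out indexes 2, 3, and 4. A standalone sketch with dummy values (not app code):

```python
# Dummy values standing in for the Gradio inputs list in app.py:
# [prompt, negative_prompt, image_prompt, control_image_prompt, ip_image_prompt,
#  ...other settings..., DISABLE_IMAGE, DISABLE_CONTROL_IMAGE, DISABLE_IP_IMAGE]
args = ("a cat", "nsfw+", "init.png", "canny.png", "style.png", 30, 7.5, True, True, False)

disable_image, disable_control_image, disable_ip_image = args[-3:]
gen_args = list(args[:-3])

if disable_image:
    gen_args[2] = None  # image_prompt
if disable_control_image:
    gen_args[3] = None  # control_image_prompt
if disable_ip_image:
    gen_args[4] = None  # ip_image_prompt

assert gen_args[:5] == ["a cat", "nsfw+", None, None, "style.png"]
```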
lib/__init__.py
CHANGED
@@ -5,12 +5,12 @@ from .loader import Loader
 from .logger import Logger
 from .upscaler import RealESRGAN
 from .utils import (
+    annotate_image,
     async_call,
     disable_progress_bars,
     download_civit_file,
     download_repo_files,
     enable_progress_bars,
-    get_valid_size,
     load_json,
     read_file,
     resize_image,
@@ -24,13 +24,13 @@ __all__ = [
     "Loader",
     "Logger",
     "RealESRGAN",
+    "annotate_image",
     "async_call",
     "disable_progress_bars",
     "download_civit_file",
     "download_repo_files",
     "enable_progress_bars",
     "generate",
-    "get_valid_size",
     "load_json",
     "read_file",
     "resize_image",
lib/annotators.py
CHANGED
@@ -1,6 +1,8 @@
 from threading import Lock
+from typing import Tuple
 
 from controlnet_aux import CannyDetector
+from PIL import Image
 
 
 class CannyAnnotator:
@@ -14,7 +16,7 @@ class CannyAnnotator:
         cls._instance.model = CannyDetector()
         return cls._instance
 
-    def __call__(self, img, size):
+    def __call__(self, img: Image.Image, size: Tuple[int, int]) -> Image.Image:
         resolution = min(*size)
         return self.model(
             img,
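The hunks above only show fragments of CannyAnnotator. Here is a hedged reconstruction of the thread-safe singleton those fragments imply: the Lock import, the CannyDetector assignment, the return of cls._instance, and the new __call__ signature are from the file, while the locking structure and the detector keyword arguments are assumptions.

```python
from threading import Lock
from typing import Tuple

from controlnet_aux import CannyDetector
from PIL import Image


class CannyAnnotator:
    _instance = None
    _lock = Lock()

    def __new__(cls):
        # double-checked locking so concurrent requests share one detector
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance.model = CannyDetector()
        return cls._instance

    def __call__(self, img: Image.Image, size: Tuple[int, int]) -> Image.Image:
        # the keyword arguments below are assumptions, not shown in the diff
        resolution = min(*size)
        return self.model(
            img,
            low_threshold=100,
            high_threshold=200,
            detect_resolution=resolution,
            image_resolution=resolution,
        )
```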
lib/config.py
CHANGED
@@ -23,9 +23,10 @@ from .pipelines import (
     CustomStableDiffusionPipeline,
 )
 
-# …
+# Improved GPU handling and progress bars; set before importing spaces
 os.environ["ZEROGPU_V2"] = "1"
 
+# Errors if enabled and not installed
 if find_spec("hf_transfer"):
     os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 
@@ -35,7 +36,8 @@ filterwarnings("ignore", category=FutureWarning, module="transformers")
 diffusers_logging.set_verbosity_error()
 transformers_logging.set_verbosity_error()
 
-_sd_files = [
+# Standard Stable Diffusion 1.5 file structure
+sd_files = [
     "feature_extractor/preprocessor_config.json",
     "safety_checker/config.json",
     "scheduler/scheduler_config.json",
@@ -52,10 +54,12 @@ _sd_files = [
     "model_index.json",
 ]
 
+# Using namespace instead of dataclass for simplicity
 Config = SimpleNamespace(
     HF_TOKEN=os.environ.get("HF_TOKEN", None),
     CIVIT_TOKEN=os.environ.get("CIVIT_TOKEN", None),
     ZERO_GPU=import_module("spaces").config.Config.zero_gpu,
+    # TODO: fix model config redundancy
     HF_MODELS={
         # downloaded on startup
         "ai-forever/Real-ESRGAN": ["RealESRGAN_x2.pth", "RealESRGAN_x4.pth"],
@@ -64,7 +68,7 @@ Config = SimpleNamespace(
         "fluently/Fluently-v4": ["Fluently-v4.safetensors"],
         "Linaqruf/anything-v3-1": ["anything-v3-2.safetensors"],
         "lllyasviel/control_v11p_sd15_canny": ["diffusion_pytorch_model.fp16.safetensors"],
-        "Lykon/dreamshaper-8": [*_sd_files],
+        "Lykon/dreamshaper-8": [*sd_files],
         "madebyollin/taesd": ["diffusion_pytorch_model.safetensors"],
         "prompthero/openjourney-v4": ["openjourney-v4.ckpt"],
         "SG161222/Realistic_Vision_V5.1_noVAE": ["Realistic_Vision_V5.1_fp16-no-ema.safetensors"],
@@ -111,8 +115,9 @@ Config = SimpleNamespace(
         "SG161222/Realistic_Vision_V5.1_noVAE",
         "XpucT/Deliberate",
     ],
+    # Single-file model weights
     MODEL_CHECKPOINTS={
-        # keep keys lowercase
+        # keep keys lowercase for case-insensitive matching in the loader
         "comfy-org/stable-diffusion-v1-5-archive": "v1-5-pruned-emaonly-fp16.safetensors",
         "cyberdelia/cyberrealistic": "CyberRealistic_V5_FP16.safetensors",
         "fluently/fluently-v4": "Fluently-v4.safetensors",
@@ -131,6 +136,7 @@ Config = SimpleNamespace(
         "PNDM": PNDMScheduler,
         "UniPC 2M": UniPCMultistepScheduler,
     },
+    ANNOTATOR="canny",
     ANNOTATORS={
         "canny": "lllyasviel/control_v11p_sd15_canny",
     },
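The new ANNOTATOR key is the UI default, while ANNOTATORS maps an annotator ID to the ControlNet repo the loader downloads. A short sketch of how the keys are consumed, mirroring the lib/loader.py change below:

```python
import torch
from diffusers import ControlNetModel

from lib.config import Config

annotator = Config.ANNOTATOR            # "canny", the UI default
repo_id = Config.ANNOTATORS[annotator]  # "lllyasviel/control_v11p_sd15_canny"
controlnet = ControlNetModel.from_pretrained(repo_id, torch_dtype=torch.float16)
```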
lib/inference.py
CHANGED
@@ -5,18 +5,22 @@ import time
 from datetime import datetime
 from itertools import product
 
-import numpy as np
 import torch
 from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
 from compel.prompt_parser import PromptParser
 from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError
-from PIL import Image
 from spaces import GPU
 
 from .config import Config
 from .loader import Loader
 from .logger import Logger
-from .utils import …
+from .utils import (
+    annotate_image,
+    load_json,
+    resize_image,
+    safe_progress,
+    timer,
+)
 
 
 def parse_prompt_with_arrays(prompt: str) -> list[str]:
@@ -58,25 +62,7 @@ def apply_style(positive_prompt, negative_prompt, style_id="none"):
     )
 
 
-def prepare_image(input, size=None):
-    image = None
-    if isinstance(input, Image.Image):
-        image = input
-    if isinstance(input, np.ndarray):
-        image = Image.fromarray(input)
-    if isinstance(input, str):
-        if os.path.isfile(input):
-            image = Image.open(input)
-    if image is not None:
-        image = image.convert("RGB")
-        if size is not None:
-            image = image.resize(size, Image.Resampling.LANCZOS)
-    if image is not None:
-        return image
-    else:
-        raise ValueError("Invalid image prompt")
-
-
+# Dynamic signature for the GPU duration function
 def gpu_duration(**kwargs):
     loading = 20
     duration = 10
@@ -97,8 +83,8 @@ def generate(
     positive_prompt,
     negative_prompt="",
     image_prompt=None,
-    ip_image_prompt=None,
     control_image_prompt=None,
+    ip_image_prompt=None,
     lora_1=None,
     lora_1_weight=0.0,
     lora_2=None,
@@ -146,9 +132,6 @@ def generate(
     KIND = "img2img" if image_prompt is not None else "txt2img"
     KIND = f"controlnet_{KIND}" if control_image_prompt is not None else KIND
 
-    if KIND.startswith("controlnet_") and annotator.lower() not in Config.ANNOTATORS.keys():
-        raise Error(f"Invalid annotator: {annotator}")
-
     EMBEDDINGS_TYPE = (
         ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED
         if clip_skip
@@ -296,21 +279,19 @@ def generate(
     if progress is not None:
         kwargs["callback_on_step_end"] = callback_on_step_end
 
+    # Resizing so the initial latents are the same size as the generated image
     if KIND == "img2img":
         kwargs["strength"] = denoising_strength
-        kwargs["image"] = prepare_image(image_prompt, (width, height))
+        kwargs["image"] = resize_image(image_prompt, (width, height))
 
     if KIND == "controlnet_txt2img":
-        …
-        kwargs["image"] = prepare_image(control_image_prompt, None)
+        kwargs["image"] = annotate_image(control_image_prompt, annotator)
 
     if KIND == "controlnet_img2img":
-        kwargs["control_image"] = …
+        kwargs["control_image"] = annotate_image(control_image_prompt, annotator)
 
     if IP_ADAPTER:
-        …
-        size = None if ip_face else (width, height)
-        kwargs["ip_adapter_image"] = prepare_image(ip_image_prompt, size)
+        kwargs["ip_adapter_image"] = resize_image(ip_image_prompt)
 
     try:
         image = pipe(**kwargs).images[0]
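Why the kwarg differs per KIND: in Diffusers, the ControlNet txt2img pipeline takes the conditioning map as `image`, while the ControlNet img2img pipeline takes the init image as `image` and the conditioning map as `control_image`. A sketch using the stock pipelines; the model IDs come from this Space's config and the file paths are placeholders:

```python
import torch
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetImg2ImgPipeline,
    StableDiffusionControlNetPipeline,
)
from PIL import Image

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16
)
txt2img = StableDiffusionControlNetPipeline.from_pretrained(
    "Lykon/dreamshaper-8", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")
img2img = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "Lykon/dreamshaper-8", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

canny_map = Image.open("canny.png")  # placeholder: a precomputed edge map
init_image = Image.open("init.png")  # placeholder: the initial image

# ControlNet txt2img: the conditioning map is passed as `image`
out = txt2img(prompt="a cat", image=canny_map).images[0]

# ControlNet img2img: the init image is `image`, the map is `control_image`
out = img2img(
    prompt="a cat",
    image=init_image,
    control_image=canny_map,
    strength=0.6,
).images[0]
```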
lib/loader.py
CHANGED
@@ -372,6 +372,7 @@ class Loader:
         # defaults to float32
         pipe_kwargs["torch_dtype"] = torch.float16
 
+        # config maps the repo to the ID: canny -> lllyasviel/control_sd15_canny
         if kind.startswith("controlnet_"):
             pipe_kwargs["controlnet"] = ControlNetModel.from_pretrained(
                 Config.ANNOTATORS[annotator],
lib/utils.py
CHANGED
@@ -4,10 +4,9 @@ import json
 import os
 import time
 from contextlib import contextmanager
-from typing import Callable, TypeVar
+from typing import Callable, Tuple, TypeVar
 
 import anyio
-import cv2
 import httpx
 import numpy as np
 from anyio import Semaphore
@@ -18,6 +17,7 @@ from PIL import Image
 from transformers import logging as transformers_logging
 from typing_extensions import ParamSpec
 
+from .annotators import CannyAnnotator
 from .logger import Logger
 
 T = TypeVar("T")
@@ -110,64 +110,78 @@ def download_civit_file(lora_id, version_id, file_path=".", token=None):
     log.error(f"RequestError: {e}")
 
 
-def resize_image(image, size):
-    …
+def image_to_pil(image: Image.Image):
+    """Converts various image inputs to RGB PIL Image."""
+    if isinstance(image, str) and os.path.isfile(image):
+        image = Image.open(image)
+    if isinstance(image, np.ndarray):
+        image = Image.fromarray(image)
     if isinstance(image, Image.Image):
-        image = …
-    H, W, _ = image.shape
-    W = float(W)
-    H = float(H)
-    target_W, target_H = size
-
-    # Use the smaller scaling factor to maintain the aspect ratio.
-    k_w = float(target_W) / W
-    k_h = float(target_H) / H
-    k = min(k_w, k_h)
-
-    new_W = int(np.round(W * k / 64.0)) * 64
-    new_H = int(np.round(H * k / 64.0)) * 64
-    img = cv2.resize(
-        image,
-        (new_W, new_H),
-        interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA,
-    )
-    return img
+        return image.convert("RGB")
+    raise ValueError("Invalid image input")
 
 
-def get_valid_size(…):
-    …
-    ar = W / H
+def get_valid_image_size(
+    width: int,
+    height: int,
+    step=64,
+    min_size=512,
+    max_size=4096,
+):
+    """Get new image dimensions while preserving aspect ratio."""
+
+    def round_down(x):
+        return int((x // step) * step)
+
+    def clamp(x):
+        return max(min_size, min(x, max_size))
+
+    aspect_ratio = width / height
 
     # try width first
-    if …:
-        …
+    if width > height:
+        new_width = round_down(clamp(width))
+        new_height = round_down(new_width / aspect_ratio)
     else:
-        …
-
-    # if …
-    if …:
-        …
-    if …:
-        …
-
-    return …
+        new_height = round_down(clamp(height))
+        new_width = round_down(new_height * aspect_ratio)
+
+    # if new dimensions are out of bounds, try height
+    if not min_size <= new_width <= max_size:
+        new_width = round_down(clamp(width))
+        new_height = round_down(new_width / aspect_ratio)
+    if not min_size <= new_height <= max_size:
+        new_height = round_down(clamp(height))
+        new_width = round_down(new_height * aspect_ratio)
+
+    return (new_width, new_height)
+
+
+def resize_image(
+    image: Image.Image,
+    size: Tuple[int, int] = None,
+    resampling: Image.Resampling = None,
+):
+    """Resize image with proper interpolation and dimension constraints."""
+    image = image_to_pil(image)
+    if size is None:
+        size = get_valid_image_size(*image.size)
+    if resampling is None:
+        resampling = Image.Resampling.LANCZOS
+    return image.resize(size, resampling)
+
+
+def annotate_image(image: Image.Image, annotator="canny"):
+    """Get the feature map of an image using the specified annotator."""
+    size = get_valid_image_size(*image.size)
+    image = resize_image(image, size)
+    if annotator.lower() == "canny":
+        canny = CannyAnnotator()
+        return canny(image, size)
+    raise ValueError(f"Invalid annotator: {annotator}")
+
+
+# Like the original but supports args and kwargs instead of a dict
 # https://github.com/huggingface/huggingface-inference-toolkit/blob/0.2.0/src/huggingface_inference_toolkit/async_utils.py
 async def async_call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
     async with MAX_THREADS_GUARD: