adamelliotfields commited on
Commit
7a7cda5
1 Parent(s): 98afd85

Move ControlNet to Image tab

Browse files
Files changed (8) hide show
  1. README.md +3 -12
  2. app.py +95 -94
  3. lib/__init__.py +2 -2
  4. lib/annotators.py +3 -1
  5. lib/config.py +10 -4
  6. lib/inference.py +14 -33
  7. lib/loader.py +1 -0
  8. lib/utils.py +65 -51
README.md CHANGED
@@ -60,25 +60,16 @@ preload_from_hub: # up to 10
60
  # diffusion
61
 
62
  Gradio app for Stable Diffusion 1.5 featuring:
63
- * txt2img and img2img pipelines with IP-Adapter
 
64
  * Curated models, LoRAs, and TI embeddings
65
- * ControlNet with annotators
66
  * Compel prompt weighting
67
- * dozens of styles and starter prompts
68
  * Multiple samplers with Karras scheduling
69
  * DeepCache, FreeU, and Clip Skip available
70
  * Real-ESRGAN upscaling
71
  * Optional tiny autoencoder
72
 
73
- ## Motivation
74
-
75
- * host a free and easy-to-use Stable Diffusion UI on ZeroGPU
76
- * provide the necessary tools for common workflows
77
- * curate useful models, adapters, and embeddings
78
- * prefer Diffusers over custom PyTorch
79
- * be fast on 8GB with no offloading
80
- * only support CUDA on Linux/WSL
81
-
82
  ## Usage
83
 
84
  See [`DOCS.md`](https://huggingface.co/spaces/adamelliotfields/diffusion/blob/main/DOCS.md).
 
60
  # diffusion
61
 
62
  Gradio app for Stable Diffusion 1.5 featuring:
63
+ * txt2img and img2img pipelines with ControlNet and IP-Adapter
64
+ * Canny edge detection (more preprocessors coming soon)
65
  * Curated models, LoRAs, and TI embeddings
 
66
  * Compel prompt weighting
67
+ * Hand-written style templates
68
  * Multiple samplers with Karras scheduling
69
  * DeepCache, FreeU, and Clip Skip available
70
  * Real-ESRGAN upscaling
71
  * Optional tiny autoencoder
72
 
 
 
 
 
 
 
 
 
 
73
  ## Usage
74
 
75
  See [`DOCS.md`](https://huggingface.co/spaces/adamelliotfields/diffusion/blob/main/DOCS.md).
app.py CHANGED
@@ -6,16 +6,13 @@ import random
6
  import gradio as gr
7
 
8
  from lib import (
9
- CannyAnnotator,
10
  Config,
11
  async_call,
12
  disable_progress_bars,
13
  download_civit_file,
14
  download_repo_files,
15
  generate,
16
- get_valid_size,
17
  read_file,
18
- resize_image,
19
  )
20
 
21
  # the CSS `content` attribute expects a string so we need to wrap the number in quotes
@@ -45,7 +42,7 @@ aspect_ratio_js = """
45
  """
46
 
47
 
48
- def create_image_dropdown(images, locked=False):
49
  if locked:
50
  return gr.Dropdown(
51
  choices=[("🔒", -2)],
@@ -60,19 +57,17 @@ def create_image_dropdown(images, locked=False):
60
  )
61
 
62
 
63
- async def gallery_fn(images, image, ip_image):
64
  return (
65
- create_image_dropdown(images, locked=image is not None),
66
- create_image_dropdown(images, locked=ip_image is not None),
 
67
  )
68
 
69
 
70
- async def image_prompt_fn(images):
71
- return create_image_dropdown(images)
72
-
73
-
74
- # handle selecting an image from the gallery
75
- # -2 is the lock icon, -1 is None
76
  async def image_select_fn(images, image, i):
77
  if i == -2:
78
  return gr.Image(image)
@@ -87,15 +82,6 @@ async def random_fn():
87
  return gr.Textbox(value=random.choice(prompts))
88
 
89
 
90
- # TODO: move this to another file once more annotators are added; will need @GPU decorator
91
- async def annotate_fn(image, annotator):
92
- size = get_valid_size(image)
93
- image = resize_image(image, size)
94
- if annotator == "canny":
95
- canny = CannyAnnotator()
96
- return canny(image, size)
97
-
98
-
99
  async def generate_fn(*args, progress=gr.Progress(track_tqdm=True)):
100
  if len(args) > 0:
101
  prompt = args[0]
@@ -105,17 +91,22 @@ async def generate_fn(*args, progress=gr.Progress(track_tqdm=True)):
105
  raise gr.Error("You must enter a prompt")
106
 
107
  # always the last arguments
108
- DISABLE_IMAGE_PROMPT, DISABLE_IP_IMAGE_PROMPT = args[-2:]
109
- gen_args = list(args[:-2])
 
 
110
  if DISABLE_IMAGE_PROMPT:
111
  gen_args[2] = None
112
- if DISABLE_IP_IMAGE_PROMPT:
113
  gen_args[3] = None
 
 
114
 
115
  try:
116
  if Config.ZERO_GPU:
117
  progress((0, 100), desc="ZeroGPU init")
118
 
 
119
  images = await async_call(
120
  generate,
121
  *gen_args,
@@ -125,6 +116,7 @@ async def generate_fn(*args, progress=gr.Progress(track_tqdm=True)):
125
  )
126
  except RuntimeError:
127
  raise gr.Error("Error: Please try again")
 
128
  return images
129
 
130
 
@@ -155,6 +147,7 @@ with gr.Blocks(
155
  # override image inputs without clearing them
156
  DISABLE_IMAGE_PROMPT = gr.State(False)
157
  DISABLE_IP_IMAGE_PROMPT = gr.State(False)
 
158
 
159
  gr.HTML(read_file("./partials/intro.html"))
160
 
@@ -212,6 +205,14 @@ with gr.Blocks(
212
  image_prompt = gr.Image(
213
  show_share_button=False,
214
  label="Initial Image",
 
 
 
 
 
 
 
 
215
  min_width=320,
216
  format="png",
217
  type="pil",
@@ -226,100 +227,84 @@ with gr.Blocks(
226
 
227
  with gr.Row():
228
  image_select = gr.Dropdown(
229
- info="Use an initial image from the gallery",
230
  choices=[("None", -1)],
231
- label="Gallery Image",
232
  interactive=True,
233
  filterable=False,
 
 
 
 
 
 
 
 
 
 
234
  value=-1,
235
  )
236
  ip_image_select = gr.Dropdown(
237
- info="Use an IP-Adapter image from the gallery",
238
- label="Gallery Image",
239
  choices=[("None", -1)],
240
  interactive=True,
241
  filterable=False,
 
242
  value=-1,
243
  )
244
 
245
  with gr.Row():
246
  denoising_strength = gr.Slider(
 
247
  value=Config.DENOISING_STRENGTH,
248
- label="Denoising Strength",
249
  minimum=0.0,
250
  maximum=1.0,
251
  step=0.1,
252
  )
 
 
 
 
 
 
 
253
 
254
  with gr.Row():
255
  disable_image = gr.Checkbox(
256
- elem_classes=["checkbox"],
257
  label="Disable Initial Image",
 
258
  value=False,
259
  )
260
- disable_ip_image = gr.Checkbox(
 
261
  elem_classes=["checkbox"],
 
 
 
262
  label="Disable IP-Adapter Image",
 
263
  value=False,
264
  )
265
  use_ip_face = gr.Checkbox(
266
- elem_classes=["checkbox"],
267
  label="Use IP-Adapter Face",
 
268
  value=False,
269
  )
270
 
271
- # controlnet tab
272
- with gr.TabItem("🎮 Control"):
273
- with gr.Row():
274
- control_image_input = gr.Image(
275
- show_share_button=False,
276
- label="Control Image",
277
- min_width=320,
278
- format="png",
279
- type="pil",
280
- )
281
- control_image_prompt = gr.Image(
282
- interactive=False,
283
- show_share_button=False,
284
- label="Control Image Output",
285
- show_label=False,
286
- min_width=320,
287
- format="png",
288
- type="pil",
289
- )
290
-
291
- with gr.Row():
292
- control_annotator = gr.Dropdown(
293
- choices=[("Canny", "canny")],
294
- label="Annotator",
295
- filterable=False,
296
- value="canny",
297
- )
298
-
299
- with gr.Row():
300
- annotate_btn = gr.Button("Annotate", variant="primary")
301
- clear_control_btn = gr.ClearButton(
302
- elem_classes=["icon-button", "popover"],
303
- components=[control_image_prompt],
304
- variant="secondary",
305
- elem_id="clear-control",
306
- min_width=0,
307
- value="🗑️",
308
- )
309
-
310
  with gr.TabItem("⚙️ Menu"):
311
  with gr.Group():
312
  negative_prompt = gr.Textbox(
313
- value="nsfw+",
314
  label="Negative Prompt",
 
315
  lines=2,
316
  )
317
 
318
  with gr.Row():
319
  model = gr.Dropdown(
320
  choices=Config.MODELS,
321
- filterable=False,
322
  value=Config.MODEL,
 
323
  label="Model",
324
  min_width=240,
325
  )
@@ -489,25 +474,12 @@ with gr.Blocks(
489
  value=False,
490
  )
491
 
492
- annotate_btn.click(
493
- annotate_fn,
494
- inputs=[control_image_input, control_annotator],
495
- outputs=[control_image_prompt],
496
- )
497
-
498
  random_btn.click(random_fn, inputs=[], outputs=[prompt], show_api=False)
499
 
500
  refresh_btn.click(None, inputs=[], outputs=[seed], js=refresh_seed_js)
501
 
502
  seed.change(None, inputs=[seed], outputs=[], js=seed_js)
503
 
504
- file_format.change(
505
- lambda f: (gr.Gallery(format=f), gr.Image(format=f), gr.Image(format=f)),
506
- inputs=[file_format],
507
- outputs=[output_images, image_prompt, ip_image_prompt],
508
- show_api=False,
509
- )
510
-
511
  # input events are only user input; change events are both user and programmatic
512
  aspect_ratio.input(
513
  None,
@@ -516,11 +488,23 @@ with gr.Blocks(
516
  js=aspect_ratio_js,
517
  )
518
 
 
 
 
 
 
 
 
 
 
 
 
 
519
  # lock the input images so you don't lose them when the gallery updates
520
  output_images.change(
521
  gallery_fn,
522
- inputs=[output_images, image_prompt, ip_image_prompt],
523
- outputs=[image_select, ip_image_select],
524
  show_api=False,
525
  )
526
 
@@ -531,6 +515,12 @@ with gr.Blocks(
531
  outputs=[image_prompt],
532
  show_api=False,
533
  )
 
 
 
 
 
 
534
  ip_image_select.change(
535
  image_select_fn,
536
  inputs=[output_images, ip_image_prompt, ip_image_select],
@@ -545,6 +535,12 @@ with gr.Blocks(
545
  outputs=[image_select],
546
  show_api=False,
547
  )
 
 
 
 
 
 
548
  ip_image_prompt.clear(
549
  image_prompt_fn,
550
  inputs=[output_images],
@@ -563,10 +559,14 @@ with gr.Blocks(
563
 
564
  # toggle image prompts by updating session state
565
  gr.on(
566
- triggers=[disable_image.input, disable_ip_image.input],
567
- fn=lambda disable_image, disable_ip_image: (disable_image, disable_ip_image),
568
- inputs=[disable_image, disable_ip_image],
569
- outputs=[DISABLE_IMAGE_PROMPT, DISABLE_IP_IMAGE_PROMPT],
 
 
 
 
570
  )
571
 
572
  # generate images
@@ -579,8 +579,8 @@ with gr.Blocks(
579
  prompt,
580
  negative_prompt,
581
  image_prompt,
582
- ip_image_prompt,
583
  control_image_prompt,
 
584
  lora_1,
585
  lora_1_weight,
586
  lora_2,
@@ -605,6 +605,7 @@ with gr.Blocks(
605
  use_clip_skip,
606
  use_ip_face,
607
  DISABLE_IMAGE_PROMPT,
 
608
  DISABLE_IP_IMAGE_PROMPT,
609
  ],
610
  )
 
6
  import gradio as gr
7
 
8
  from lib import (
 
9
  Config,
10
  async_call,
11
  disable_progress_bars,
12
  download_civit_file,
13
  download_repo_files,
14
  generate,
 
15
  read_file,
 
16
  )
17
 
18
  # the CSS `content` attribute expects a string so we need to wrap the number in quotes
 
42
  """
43
 
44
 
45
+ def image_prompt_fn(images, locked=False):
46
  if locked:
47
  return gr.Dropdown(
48
  choices=[("🔒", -2)],
 
57
  )
58
 
59
 
60
+ async def gallery_fn(images, image, control_image, ip_image):
61
  return (
62
+ image_prompt_fn(images, locked=image is not None),
63
+ image_prompt_fn(images, locked=control_image is not None),
64
+ image_prompt_fn(images, locked=ip_image is not None),
65
  )
66
 
67
 
68
+ # Handle selecting an image from the gallery:
69
+ # * -2 is the lock icon
70
+ # * -1 is None
 
 
 
71
  async def image_select_fn(images, image, i):
72
  if i == -2:
73
  return gr.Image(image)
 
82
  return gr.Textbox(value=random.choice(prompts))
83
 
84
 
 
 
 
 
 
 
 
 
 
85
  async def generate_fn(*args, progress=gr.Progress(track_tqdm=True)):
86
  if len(args) > 0:
87
  prompt = args[0]
 
91
  raise gr.Error("You must enter a prompt")
92
 
93
  # always the last arguments
94
+ DISABLE_IMAGE_PROMPT, DISABLE_CONTROL_IMAGE_PROMPT, DISABLE_IP_IMAGE_PROMPT = args[-3:]
95
+ gen_args = list(args[:-3])
96
+
97
+ # the first two arguments are the prompt and negative prompt
98
  if DISABLE_IMAGE_PROMPT:
99
  gen_args[2] = None
100
+ if DISABLE_CONTROL_IMAGE_PROMPT:
101
  gen_args[3] = None
102
+ if DISABLE_IP_IMAGE_PROMPT:
103
+ gen_args[4] = None
104
 
105
  try:
106
  if Config.ZERO_GPU:
107
  progress((0, 100), desc="ZeroGPU init")
108
 
109
+ # the remaining arguments are the alert handlers and progress bar
110
  images = await async_call(
111
  generate,
112
  *gen_args,
 
116
  )
117
  except RuntimeError:
118
  raise gr.Error("Error: Please try again")
119
+
120
  return images
121
 
122
 
 
147
  # override image inputs without clearing them
148
  DISABLE_IMAGE_PROMPT = gr.State(False)
149
  DISABLE_IP_IMAGE_PROMPT = gr.State(False)
150
+ DISABLE_CONTROL_IMAGE_PROMPT = gr.State(False)
151
 
152
  gr.HTML(read_file("./partials/intro.html"))
153
 
 
205
  image_prompt = gr.Image(
206
  show_share_button=False,
207
  label="Initial Image",
208
+ min_width=640,
209
+ format="png",
210
+ type="pil",
211
+ )
212
+ with gr.Row():
213
+ control_image_prompt = gr.Image(
214
+ show_share_button=False,
215
+ label="Control Image",
216
  min_width=320,
217
  format="png",
218
  type="pil",
 
227
 
228
  with gr.Row():
229
  image_select = gr.Dropdown(
230
+ info="Use a gallery image for initial latents",
231
  choices=[("None", -1)],
232
+ label="Initial Image",
233
  interactive=True,
234
  filterable=False,
235
+ min_width=100,
236
+ value=-1,
237
+ )
238
+ control_image_select = gr.Dropdown(
239
+ info="Use a gallery image for ControlNet",
240
+ label="ControlNet Image",
241
+ choices=[("None", -1)],
242
+ interactive=True,
243
+ filterable=False,
244
+ min_width=100,
245
  value=-1,
246
  )
247
  ip_image_select = gr.Dropdown(
248
+ info="Use a gallery image for IP-Adapter",
249
+ label="IP-Adapter Image",
250
  choices=[("None", -1)],
251
  interactive=True,
252
  filterable=False,
253
+ min_width=100,
254
  value=-1,
255
  )
256
 
257
  with gr.Row():
258
  denoising_strength = gr.Slider(
259
+ label="Initial Image Strength",
260
  value=Config.DENOISING_STRENGTH,
 
261
  minimum=0.0,
262
  maximum=1.0,
263
  step=0.1,
264
  )
265
+ control_annotator = gr.Dropdown(
266
+ label="ControlNet Annotator",
267
+ # TODO: annotators should be in config with names
268
+ choices=[("Canny", "canny")],
269
+ value=Config.ANNOTATOR,
270
+ filterable=False,
271
+ )
272
 
273
  with gr.Row():
274
  disable_image = gr.Checkbox(
 
275
  label="Disable Initial Image",
276
+ elem_classes=["checkbox"],
277
  value=False,
278
  )
279
+ disable_control_image = gr.Checkbox(
280
+ label="Disable ControlNet Image",
281
  elem_classes=["checkbox"],
282
+ value=False,
283
+ )
284
+ disable_ip_image = gr.Checkbox(
285
  label="Disable IP-Adapter Image",
286
+ elem_classes=["checkbox"],
287
  value=False,
288
  )
289
  use_ip_face = gr.Checkbox(
 
290
  label="Use IP-Adapter Face",
291
+ elem_classes=["checkbox"],
292
  value=False,
293
  )
294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
  with gr.TabItem("⚙️ Menu"):
296
  with gr.Group():
297
  negative_prompt = gr.Textbox(
 
298
  label="Negative Prompt",
299
+ value="nsfw+",
300
  lines=2,
301
  )
302
 
303
  with gr.Row():
304
  model = gr.Dropdown(
305
  choices=Config.MODELS,
 
306
  value=Config.MODEL,
307
+ filterable=False,
308
  label="Model",
309
  min_width=240,
310
  )
 
474
  value=False,
475
  )
476
 
 
 
 
 
 
 
477
  random_btn.click(random_fn, inputs=[], outputs=[prompt], show_api=False)
478
 
479
  refresh_btn.click(None, inputs=[], outputs=[seed], js=refresh_seed_js)
480
 
481
  seed.change(None, inputs=[seed], outputs=[], js=seed_js)
482
 
 
 
 
 
 
 
 
483
  # input events are only user input; change events are both user and programmatic
484
  aspect_ratio.input(
485
  None,
 
488
  js=aspect_ratio_js,
489
  )
490
 
491
+ file_format.change(
492
+ lambda f: (
493
+ gr.Gallery(format=f),
494
+ gr.Image(format=f),
495
+ gr.Image(format=f),
496
+ gr.Image(format=f),
497
+ ),
498
+ inputs=[file_format],
499
+ outputs=[output_images, image_prompt, control_image_prompt, ip_image_prompt],
500
+ show_api=False,
501
+ )
502
+
503
  # lock the input images so you don't lose them when the gallery updates
504
  output_images.change(
505
  gallery_fn,
506
+ inputs=[output_images, image_prompt, control_image_prompt, ip_image_prompt],
507
+ outputs=[image_select, control_image_select, ip_image_select],
508
  show_api=False,
509
  )
510
 
 
515
  outputs=[image_prompt],
516
  show_api=False,
517
  )
518
+ control_image_select.change(
519
+ image_select_fn,
520
+ inputs=[output_images, control_image_prompt, control_image_select],
521
+ outputs=[control_image_prompt],
522
+ show_api=False,
523
+ )
524
  ip_image_select.change(
525
  image_select_fn,
526
  inputs=[output_images, ip_image_prompt, ip_image_select],
 
535
  outputs=[image_select],
536
  show_api=False,
537
  )
538
+ control_image_prompt.clear(
539
+ image_prompt_fn,
540
+ inputs=[output_images],
541
+ outputs=[control_image_select],
542
+ show_api=False,
543
+ )
544
  ip_image_prompt.clear(
545
  image_prompt_fn,
546
  inputs=[output_images],
 
559
 
560
  # toggle image prompts by updating session state
561
  gr.on(
562
+ triggers=[disable_image.input, disable_control_image.input, disable_ip_image.input],
563
+ fn=lambda disable_image, disable_control_image, disable_ip_image: (
564
+ disable_image,
565
+ disable_control_image,
566
+ disable_ip_image,
567
+ ),
568
+ inputs=[disable_image, disable_control_image, disable_ip_image],
569
+ outputs=[DISABLE_IMAGE_PROMPT, DISABLE_CONTROL_IMAGE_PROMPT, DISABLE_IP_IMAGE_PROMPT],
570
  )
571
 
572
  # generate images
 
579
  prompt,
580
  negative_prompt,
581
  image_prompt,
 
582
  control_image_prompt,
583
+ ip_image_prompt,
584
  lora_1,
585
  lora_1_weight,
586
  lora_2,
 
605
  use_clip_skip,
606
  use_ip_face,
607
  DISABLE_IMAGE_PROMPT,
608
+ DISABLE_CONTROL_IMAGE_PROMPT,
609
  DISABLE_IP_IMAGE_PROMPT,
610
  ],
611
  )
lib/__init__.py CHANGED
@@ -5,12 +5,12 @@ from .loader import Loader
5
  from .logger import Logger
6
  from .upscaler import RealESRGAN
7
  from .utils import (
 
8
  async_call,
9
  disable_progress_bars,
10
  download_civit_file,
11
  download_repo_files,
12
  enable_progress_bars,
13
- get_valid_size,
14
  load_json,
15
  read_file,
16
  resize_image,
@@ -24,13 +24,13 @@ __all__ = [
24
  "Loader",
25
  "Logger",
26
  "RealESRGAN",
 
27
  "async_call",
28
  "disable_progress_bars",
29
  "download_civit_file",
30
  "download_repo_files",
31
  "enable_progress_bars",
32
  "generate",
33
- "get_valid_size",
34
  "load_json",
35
  "read_file",
36
  "resize_image",
 
5
  from .logger import Logger
6
  from .upscaler import RealESRGAN
7
  from .utils import (
8
+ annotate_image,
9
  async_call,
10
  disable_progress_bars,
11
  download_civit_file,
12
  download_repo_files,
13
  enable_progress_bars,
 
14
  load_json,
15
  read_file,
16
  resize_image,
 
24
  "Loader",
25
  "Logger",
26
  "RealESRGAN",
27
+ "annotate_image",
28
  "async_call",
29
  "disable_progress_bars",
30
  "download_civit_file",
31
  "download_repo_files",
32
  "enable_progress_bars",
33
  "generate",
 
34
  "load_json",
35
  "read_file",
36
  "resize_image",
lib/annotators.py CHANGED
@@ -1,6 +1,8 @@
1
  from threading import Lock
 
2
 
3
  from controlnet_aux import CannyDetector
 
4
 
5
 
6
  class CannyAnnotator:
@@ -14,7 +16,7 @@ class CannyAnnotator:
14
  cls._instance.model = CannyDetector()
15
  return cls._instance
16
 
17
- def __call__(self, img, size):
18
  resolution = min(*size)
19
  return self.model(
20
  img,
 
1
  from threading import Lock
2
+ from typing import Tuple
3
 
4
  from controlnet_aux import CannyDetector
5
+ from PIL import Image
6
 
7
 
8
  class CannyAnnotator:
 
16
  cls._instance.model = CannyDetector()
17
  return cls._instance
18
 
19
+ def __call__(self, img: Image.Image, size: Tuple[int, int]) -> Image.Image:
20
  resolution = min(*size)
21
  return self.model(
22
  img,
lib/config.py CHANGED
@@ -23,9 +23,10 @@ from .pipelines import (
23
  CustomStableDiffusionPipeline,
24
  )
25
 
26
- # improved GPU handling and progress bars; set before importing spaces
27
  os.environ["ZEROGPU_V2"] = "1"
28
 
 
29
  if find_spec("hf_transfer"):
30
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
31
 
@@ -35,7 +36,8 @@ filterwarnings("ignore", category=FutureWarning, module="transformers")
35
  diffusers_logging.set_verbosity_error()
36
  transformers_logging.set_verbosity_error()
37
 
38
- _sd_files = [
 
39
  "feature_extractor/preprocessor_config.json",
40
  "safety_checker/config.json",
41
  "scheduler/scheduler_config.json",
@@ -52,10 +54,12 @@ _sd_files = [
52
  "model_index.json",
53
  ]
54
 
 
55
  Config = SimpleNamespace(
56
  HF_TOKEN=os.environ.get("HF_TOKEN", None),
57
  CIVIT_TOKEN=os.environ.get("CIVIT_TOKEN", None),
58
  ZERO_GPU=import_module("spaces").config.Config.zero_gpu,
 
59
  HF_MODELS={
60
  # downloaded on startup
61
  "ai-forever/Real-ESRGAN": ["RealESRGAN_x2.pth", "RealESRGAN_x4.pth"],
@@ -64,7 +68,7 @@ Config = SimpleNamespace(
64
  "fluently/Fluently-v4": ["Fluently-v4.safetensors"],
65
  "Linaqruf/anything-v3-1": ["anything-v3-2.safetensors"],
66
  "lllyasviel/control_v11p_sd15_canny": ["diffusion_pytorch_model.fp16.safetensors"],
67
- "Lykon/dreamshaper-8": [*_sd_files],
68
  "madebyollin/taesd": ["diffusion_pytorch_model.safetensors"],
69
  "prompthero/openjourney-v4": ["openjourney-v4.ckpt"],
70
  "SG161222/Realistic_Vision_V5.1_noVAE": ["Realistic_Vision_V5.1_fp16-no-ema.safetensors"],
@@ -111,8 +115,9 @@ Config = SimpleNamespace(
111
  "SG161222/Realistic_Vision_V5.1_noVAE",
112
  "XpucT/Deliberate",
113
  ],
 
114
  MODEL_CHECKPOINTS={
115
- # keep keys lowercase
116
  "comfy-org/stable-diffusion-v1-5-archive": "v1-5-pruned-emaonly-fp16.safetensors",
117
  "cyberdelia/cyberrealistic": "CyberRealistic_V5_FP16.safetensors",
118
  "fluently/fluently-v4": "Fluently-v4.safetensors",
@@ -131,6 +136,7 @@ Config = SimpleNamespace(
131
  "PNDM": PNDMScheduler,
132
  "UniPC 2M": UniPCMultistepScheduler,
133
  },
 
134
  ANNOTATORS={
135
  "canny": "lllyasviel/control_v11p_sd15_canny",
136
  },
 
23
  CustomStableDiffusionPipeline,
24
  )
25
 
26
+ # Improved GPU handling and progress bars; set before importing spaces
27
  os.environ["ZEROGPU_V2"] = "1"
28
 
29
+ # Errors if enabled and not installed
30
  if find_spec("hf_transfer"):
31
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
32
 
 
36
  diffusers_logging.set_verbosity_error()
37
  transformers_logging.set_verbosity_error()
38
 
39
+ # Standard Stable Diffusion 1.5 file structure
40
+ sd_files = [
41
  "feature_extractor/preprocessor_config.json",
42
  "safety_checker/config.json",
43
  "scheduler/scheduler_config.json",
 
54
  "model_index.json",
55
  ]
56
 
57
+ # Using namespace instead of dataclass for simplicity
58
  Config = SimpleNamespace(
59
  HF_TOKEN=os.environ.get("HF_TOKEN", None),
60
  CIVIT_TOKEN=os.environ.get("CIVIT_TOKEN", None),
61
  ZERO_GPU=import_module("spaces").config.Config.zero_gpu,
62
+ # TODO: fix model config redundancy
63
  HF_MODELS={
64
  # downloaded on startup
65
  "ai-forever/Real-ESRGAN": ["RealESRGAN_x2.pth", "RealESRGAN_x4.pth"],
 
68
  "fluently/Fluently-v4": ["Fluently-v4.safetensors"],
69
  "Linaqruf/anything-v3-1": ["anything-v3-2.safetensors"],
70
  "lllyasviel/control_v11p_sd15_canny": ["diffusion_pytorch_model.fp16.safetensors"],
71
+ "Lykon/dreamshaper-8": [*sd_files],
72
  "madebyollin/taesd": ["diffusion_pytorch_model.safetensors"],
73
  "prompthero/openjourney-v4": ["openjourney-v4.ckpt"],
74
  "SG161222/Realistic_Vision_V5.1_noVAE": ["Realistic_Vision_V5.1_fp16-no-ema.safetensors"],
 
115
  "SG161222/Realistic_Vision_V5.1_noVAE",
116
  "XpucT/Deliberate",
117
  ],
118
+ # Single-file model weights
119
  MODEL_CHECKPOINTS={
120
+ # keep keys lowercase for case-insensitive matching in the loader
121
  "comfy-org/stable-diffusion-v1-5-archive": "v1-5-pruned-emaonly-fp16.safetensors",
122
  "cyberdelia/cyberrealistic": "CyberRealistic_V5_FP16.safetensors",
123
  "fluently/fluently-v4": "Fluently-v4.safetensors",
 
136
  "PNDM": PNDMScheduler,
137
  "UniPC 2M": UniPCMultistepScheduler,
138
  },
139
+ ANNOTATOR="canny",
140
  ANNOTATORS={
141
  "canny": "lllyasviel/control_v11p_sd15_canny",
142
  },
lib/inference.py CHANGED
@@ -5,18 +5,22 @@ import time
5
  from datetime import datetime
6
  from itertools import product
7
 
8
- import numpy as np
9
  import torch
10
  from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
11
  from compel.prompt_parser import PromptParser
12
  from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError
13
- from PIL import Image
14
  from spaces import GPU
15
 
16
  from .config import Config
17
  from .loader import Loader
18
  from .logger import Logger
19
- from .utils import load_json, safe_progress, timer
 
 
 
 
 
 
20
 
21
 
22
  def parse_prompt_with_arrays(prompt: str) -> list[str]:
@@ -58,25 +62,7 @@ def apply_style(positive_prompt, negative_prompt, style_id="none"):
58
  )
59
 
60
 
61
- def prepare_image(input, size=None):
62
- image = None
63
- if isinstance(input, Image.Image):
64
- image = input
65
- if isinstance(input, np.ndarray):
66
- image = Image.fromarray(input)
67
- if isinstance(input, str):
68
- if os.path.isfile(input):
69
- image = Image.open(input)
70
- if image is not None:
71
- image = image.convert("RGB")
72
- if size is not None:
73
- image = image.resize(size, Image.Resampling.LANCZOS)
74
- if image is not None:
75
- return image
76
- else:
77
- raise ValueError("Invalid image prompt")
78
-
79
-
80
  def gpu_duration(**kwargs):
81
  loading = 20
82
  duration = 10
@@ -97,8 +83,8 @@ def generate(
97
  positive_prompt,
98
  negative_prompt="",
99
  image_prompt=None,
100
- ip_image_prompt=None,
101
  control_image_prompt=None,
 
102
  lora_1=None,
103
  lora_1_weight=0.0,
104
  lora_2=None,
@@ -146,9 +132,6 @@ def generate(
146
  KIND = "img2img" if image_prompt is not None else "txt2img"
147
  KIND = f"controlnet_{KIND}" if control_image_prompt is not None else KIND
148
 
149
- if KIND.startswith("controlnet_") and annotator.lower() not in Config.ANNOTATORS.keys():
150
- raise Error(f"Invalid annotator: {annotator}")
151
-
152
  EMBEDDINGS_TYPE = (
153
  ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED
154
  if clip_skip
@@ -296,21 +279,19 @@ def generate(
296
  if progress is not None:
297
  kwargs["callback_on_step_end"] = callback_on_step_end
298
 
 
299
  if KIND == "img2img":
300
  kwargs["strength"] = denoising_strength
301
- kwargs["image"] = prepare_image(image_prompt, (width, height))
302
 
303
  if KIND == "controlnet_txt2img":
304
- # don't resize controlnet images
305
- kwargs["image"] = prepare_image(control_image_prompt, None)
306
 
307
  if KIND == "controlnet_img2img":
308
- kwargs["control_image"] = prepare_image(control_image_prompt, None)
309
 
310
  if IP_ADAPTER:
311
- # don't resize full-face images since they are usually square crops
312
- size = None if ip_face else (width, height)
313
- kwargs["ip_adapter_image"] = prepare_image(ip_image_prompt, size)
314
 
315
  try:
316
  image = pipe(**kwargs).images[0]
 
5
  from datetime import datetime
6
  from itertools import product
7
 
 
8
  import torch
9
  from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
10
  from compel.prompt_parser import PromptParser
11
  from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError
 
12
  from spaces import GPU
13
 
14
  from .config import Config
15
  from .loader import Loader
16
  from .logger import Logger
17
+ from .utils import (
18
+ annotate_image,
19
+ load_json,
20
+ resize_image,
21
+ safe_progress,
22
+ timer,
23
+ )
24
 
25
 
26
  def parse_prompt_with_arrays(prompt: str) -> list[str]:
 
62
  )
63
 
64
 
65
+ # Dynamic signature for the GPU duration function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  def gpu_duration(**kwargs):
67
  loading = 20
68
  duration = 10
 
83
  positive_prompt,
84
  negative_prompt="",
85
  image_prompt=None,
 
86
  control_image_prompt=None,
87
+ ip_image_prompt=None,
88
  lora_1=None,
89
  lora_1_weight=0.0,
90
  lora_2=None,
 
132
  KIND = "img2img" if image_prompt is not None else "txt2img"
133
  KIND = f"controlnet_{KIND}" if control_image_prompt is not None else KIND
134
 
 
 
 
135
  EMBEDDINGS_TYPE = (
136
  ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED
137
  if clip_skip
 
279
  if progress is not None:
280
  kwargs["callback_on_step_end"] = callback_on_step_end
281
 
282
+ # Resizing so the initial latents are the same size as the generated image
283
  if KIND == "img2img":
284
  kwargs["strength"] = denoising_strength
285
+ kwargs["image"] = resize_image(image_prompt, (width, height))
286
 
287
  if KIND == "controlnet_txt2img":
288
+ kwargs["image"] = annotate_image(control_image_prompt, annotator)
 
289
 
290
  if KIND == "controlnet_img2img":
291
+ kwargs["control_image"] = annotate_image(control_image_prompt, annotator)
292
 
293
  if IP_ADAPTER:
294
+ kwargs["ip_adapter_image"] = resize_image(ip_image_prompt)
 
 
295
 
296
  try:
297
  image = pipe(**kwargs).images[0]
lib/loader.py CHANGED
@@ -372,6 +372,7 @@ class Loader:
372
  # defaults to float32
373
  pipe_kwargs["torch_dtype"] = torch.float16
374
 
 
375
  if kind.startswith("controlnet_"):
376
  pipe_kwargs["controlnet"] = ControlNetModel.from_pretrained(
377
  Config.ANNOTATORS[annotator],
 
372
  # defaults to float32
373
  pipe_kwargs["torch_dtype"] = torch.float16
374
 
375
+ # config maps the repo to the ID: canny -> lllyasviel/control_sd15_canny
376
  if kind.startswith("controlnet_"):
377
  pipe_kwargs["controlnet"] = ControlNetModel.from_pretrained(
378
  Config.ANNOTATORS[annotator],
lib/utils.py CHANGED
@@ -4,10 +4,9 @@ import json
4
  import os
5
  import time
6
  from contextlib import contextmanager
7
- from typing import Callable, TypeVar
8
 
9
  import anyio
10
- import cv2
11
  import httpx
12
  import numpy as np
13
  from anyio import Semaphore
@@ -18,6 +17,7 @@ from PIL import Image
18
  from transformers import logging as transformers_logging
19
  from typing_extensions import ParamSpec
20
 
 
21
  from .logger import Logger
22
 
23
  T = TypeVar("T")
@@ -110,64 +110,78 @@ def download_civit_file(lora_id, version_id, file_path=".", token=None):
110
  log.error(f"RequestError: {e}")
111
 
112
 
113
- # resize an image while preserving the aspect ratio (size is width-first)
114
- def resize_image(image, size):
 
 
 
 
115
  if isinstance(image, Image.Image):
116
- image = np.array(image)
117
-
118
- H, W, _ = image.shape
119
- W = float(W)
120
- H = float(H)
121
- target_W, target_H = size
122
-
123
- # Use the smaller scaling factor to maintain the aspect ratio.
124
- k_w = float(target_W) / W
125
- k_h = float(target_H) / H
126
- k = min(k_w, k_h)
127
-
128
- new_W = int(np.round(W * k / 64.0)) * 64
129
- new_H = int(np.round(H * k / 64.0)) * 64
130
- img = cv2.resize(
131
- image,
132
- (new_W, new_H),
133
- interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA,
134
- )
135
- return img
136
 
137
 
138
- # ensure image is within bounds
139
- def get_valid_size(image, step=64, low=512, high=4096):
140
- def round_down(x, step=step):
141
- return int((x // step) * step)
 
 
 
 
142
 
143
- def clamp_range(x, low=low, high=high):
144
- return max(low, min(x, high))
145
 
146
- if isinstance(image, Image.Image):
147
- image = np.array(image)
148
 
149
- H, W = image.shape[:2]
150
- ar = W / H
151
 
152
  # try width first
153
- if W > H:
154
- new_W = round_down(clamp_range(W))
155
- new_H = round_down(new_W / ar)
156
  else:
157
- new_H = round_down(clamp_range(H))
158
- new_W = round_down(new_H * ar)
159
-
160
- # if the new size is out of bounds, try the other dimension
161
- if new_W < low or new_W > high:
162
- new_W = round_down(clamp_range(W))
163
- new_H = round_down(new_W / ar)
164
- if new_H < low or new_H > high:
165
- new_H = round_down(clamp_range(H))
166
- new_W = round_down(new_H * ar)
167
- return (new_W, new_H)
168
-
169
-
170
- # like the original but supports args and kwargs instead of a dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  # https://github.com/huggingface/huggingface-inference-toolkit/blob/0.2.0/src/huggingface_inference_toolkit/async_utils.py
172
  async def async_call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
173
  async with MAX_THREADS_GUARD:
 
4
  import os
5
  import time
6
  from contextlib import contextmanager
7
+ from typing import Callable, Tuple, TypeVar
8
 
9
  import anyio
 
10
  import httpx
11
  import numpy as np
12
  from anyio import Semaphore
 
17
  from transformers import logging as transformers_logging
18
  from typing_extensions import ParamSpec
19
 
20
+ from .annotators import CannyAnnotator
21
  from .logger import Logger
22
 
23
  T = TypeVar("T")
 
110
  log.error(f"RequestError: {e}")
111
 
112
 
113
+ def image_to_pil(image: Image.Image):
114
+ """Converts various image inputs to RGB PIL Image."""
115
+ if isinstance(image, str) and os.path.isfile(image):
116
+ image = Image.open(image)
117
+ if isinstance(image, np.ndarray):
118
+ image = Image.fromarray(image)
119
  if isinstance(image, Image.Image):
120
+ return image.convert("RGB")
121
+ raise ValueError("Invalid image input")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
 
124
+ def get_valid_image_size(
125
+ width: int,
126
+ height: int,
127
+ step=64,
128
+ min_size=512,
129
+ max_size=4096,
130
+ ):
131
+ """Get new image dimensions while preserving aspect ratio."""
132
 
133
+ def round_down(x):
134
+ return int((x // step) * step)
135
 
136
+ def clamp(x):
137
+ return max(min_size, min(x, max_size))
138
 
139
+ aspect_ratio = width / height
 
140
 
141
  # try width first
142
+ if width > height:
143
+ new_width = round_down(clamp(width))
144
+ new_height = round_down(new_width / aspect_ratio)
145
  else:
146
+ new_height = round_down(clamp(height))
147
+ new_width = round_down(new_height * aspect_ratio)
148
+
149
+ # if new dimensions are out of bounds, try height
150
+ if not min_size <= new_width <= max_size:
151
+ new_width = round_down(clamp(width))
152
+ new_height = round_down(new_width / aspect_ratio)
153
+ if not min_size <= new_height <= max_size:
154
+ new_height = round_down(clamp(height))
155
+ new_width = round_down(new_height * aspect_ratio)
156
+
157
+ return (new_width, new_height)
158
+
159
+
160
+ def resize_image(
161
+ image: Image.Image,
162
+ size: Tuple[int, int] = None,
163
+ resampling: Image.Resampling = None,
164
+ ):
165
+ """Resize image with proper interpolation and dimension constraints."""
166
+ image = image_to_pil(image)
167
+ if size is None:
168
+ size = get_valid_image_size(*image.size)
169
+ if resampling is None:
170
+ resampling = Image.Resampling.LANCZOS
171
+ return image.resize(size, resampling)
172
+
173
+
174
+ def annotate_image(image: Image.Image, annotator="canny"):
175
+ """Get the feature map of an image using the specified annotator."""
176
+ size = get_valid_image_size(*image.size)
177
+ image = resize_image(image, size)
178
+ if annotator.lower() == "canny":
179
+ canny = CannyAnnotator()
180
+ return canny(image, size)
181
+ raise ValueError(f"Invalid annotator: {annotator}")
182
+
183
+
184
+ # Like the original but supports args and kwargs instead of a dict
185
  # https://github.com/huggingface/huggingface-inference-toolkit/blob/0.2.0/src/huggingface_inference_toolkit/async_utils.py
186
  async def async_call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
187
  async with MAX_THREADS_GUARD: