adamelliotfields commited on
Commit
232c234
1 Parent(s): 88a4072

Async generate wrapper

Browse files
Files changed (7) hide show
  1. app.py +9 -9
  2. cli.py +7 -5
  3. lib/__init__.py +2 -2
  4. lib/config.py +1 -1
  5. lib/inference.py +35 -7
  6. lib/loader.py +10 -7
  7. requirements.txt +1 -0
app.py CHANGED
@@ -4,7 +4,7 @@ import random
4
 
5
  import gradio as gr
6
 
7
- from lib import Config, generate
8
 
9
  # the CSS `content` attribute expects a string so we need to wrap the number in quotes
10
  refresh_seed_js = """
@@ -79,7 +79,7 @@ def image_select_fn(images, image, i):
79
  return gr.Image(images[i][0]) if i > -1 else None
80
 
81
 
82
- def generate_fn(*args):
83
  if len(args) > 0:
84
  prompt = args[0]
85
  else:
@@ -87,7 +87,7 @@ def generate_fn(*args):
87
  if prompt is None or prompt.strip() == "":
88
  raise gr.Error("You must enter a prompt")
89
  try:
90
- images = generate(*args, Info=gr.Info, Error=gr.Error)
91
  except RuntimeError:
92
  raise gr.Error("RuntimeError: Please try again")
93
  return images
@@ -194,25 +194,25 @@ with gr.Blocks(
194
  width = gr.Slider(
195
  value=Config.WIDTH,
196
  label="Width",
197
- minimum=320,
198
  maximum=768,
199
- step=16,
200
  )
201
  height = gr.Slider(
202
  value=Config.HEIGHT,
203
  label="Height",
204
- minimum=320,
205
  maximum=768,
206
- step=16,
207
  )
208
  aspect_ratio = gr.Dropdown(
209
  choices=[
210
  ("Custom", None),
 
211
  ("7:9 (448x576)", "448,576"),
212
- ("3:4 (432x576)", "432,576"),
213
  ("1:1 (512x512)", "512,512"),
214
- ("4:3 (576x432)", "576,432"),
215
  ("9:7 (576x448)", "576,448"),
 
216
  ],
217
  value="448,576",
218
  filterable=False,
 
4
 
5
  import gradio as gr
6
 
7
+ from lib import Config, async_call, generate
8
 
9
  # the CSS `content` attribute expects a string so we need to wrap the number in quotes
10
  refresh_seed_js = """
 
79
  return gr.Image(images[i][0]) if i > -1 else None
80
 
81
 
82
+ async def generate_fn(*args):
83
  if len(args) > 0:
84
  prompt = args[0]
85
  else:
 
87
  if prompt is None or prompt.strip() == "":
88
  raise gr.Error("You must enter a prompt")
89
  try:
90
+ images = await async_call(generate, *args, Info=gr.Info, Error=gr.Error)
91
  except RuntimeError:
92
  raise gr.Error("RuntimeError: Please try again")
93
  return images
 
194
  width = gr.Slider(
195
  value=Config.WIDTH,
196
  label="Width",
197
+ minimum=256,
198
  maximum=768,
199
+ step=32,
200
  )
201
  height = gr.Slider(
202
  value=Config.HEIGHT,
203
  label="Height",
204
+ minimum=256,
205
  maximum=768,
206
+ step=32,
207
  )
208
  aspect_ratio = gr.Dropdown(
209
  choices=[
210
  ("Custom", None),
211
+ ("4:7 (384x672)", "384,672"),
212
  ("7:9 (448x576)", "448,576"),
 
213
  ("1:1 (512x512)", "512,512"),
 
214
  ("9:7 (576x448)", "576,448"),
215
+ ("7:4 (672x384)", "672,384"),
216
  ],
217
  value="448,576",
218
  filterable=False,
cli.py CHANGED
@@ -1,8 +1,9 @@
1
  # CLI
2
  # usage: python cli.py 'colorful calico cat artstation'
3
  import argparse
 
4
 
5
- from lib import Config, generate
6
 
7
 
8
  def save_images(images, filename="image.png"):
@@ -11,7 +12,7 @@ def save_images(images, filename="image.png"):
11
  img.save(f"{name}.{ext}" if len(images) == 1 else f"{name}_{i}.{ext}")
12
 
13
 
14
- def main():
15
  # fmt: off
16
  parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
17
  parser.add_argument("prompt", type=str, metavar="PROMPT")
@@ -42,7 +43,8 @@ def main():
42
  # fmt: on
43
 
44
  args = parser.parse_args()
45
- images = generate(
 
46
  args.prompt,
47
  args.negative,
48
  args.image,
@@ -68,8 +70,8 @@ def main():
68
  args.deepcache,
69
  args.scale,
70
  )
71
- save_images(images, args.filename)
72
 
73
 
74
  if __name__ == "__main__":
75
- main()
 
1
  # CLI
2
  # usage: python cli.py 'colorful calico cat artstation'
3
  import argparse
4
+ import asyncio
5
 
6
+ from lib import Config, async_call, generate
7
 
8
 
9
  def save_images(images, filename="image.png"):
 
12
  img.save(f"{name}.{ext}" if len(images) == 1 else f"{name}_{i}.{ext}")
13
 
14
 
15
+ async def main():
16
  # fmt: off
17
  parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
18
  parser.add_argument("prompt", type=str, metavar="PROMPT")
 
43
  # fmt: on
44
 
45
  args = parser.parse_args()
46
+ images = await async_call(
47
+ generate,
48
  args.prompt,
49
  args.negative,
50
  args.image,
 
70
  args.deepcache,
71
  args.scale,
72
  )
73
+ await async_call(save_images, images, args.filename)
74
 
75
 
76
  if __name__ == "__main__":
77
+ asyncio.run(main())
lib/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
  from .config import Config
2
- from .inference import generate
3
  from .loader import Loader
4
  from .upscaler import RealESRGAN
5
 
6
- __all__ = ["Config", "Loader", "RealESRGAN", "generate"]
 
1
  from .config import Config
2
+ from .inference import async_call, generate
3
  from .loader import Loader
4
  from .upscaler import RealESRGAN
5
 
6
+ __all__ = ["Config", "Loader", "RealESRGAN", "async_call", "generate"]
lib/config.py CHANGED
@@ -41,7 +41,7 @@ Config = SimpleNamespace(
41
  GUIDANCE_SCALE=6,
42
  INFERENCE_STEPS=35,
43
  DENOISING_STRENGTH=0.6,
44
- DEEPCACHE_INTERVAL=2,
45
  SCALE=1,
46
  SCALES=[1, 2, 4],
47
  )
 
41
  GUIDANCE_SCALE=6,
42
  INFERENCE_STEPS=35,
43
  DENOISING_STRENGTH=0.6,
44
+ DEEPCACHE_INTERVAL=1,
45
  SCALE=1,
46
  SCALES=[1, 2, 4],
47
  )
lib/inference.py CHANGED
@@ -1,26 +1,48 @@
 
 
1
  import json
2
  import os
3
  import re
4
  import time
5
  from datetime import datetime
6
  from itertools import product
7
- from typing import Callable
8
 
 
9
  import numpy as np
10
  import spaces
11
  import torch
 
12
  from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
13
  from compel.prompt_parser import PromptParser
14
  from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError
15
  from PIL import Image
 
16
 
17
  from .loader import Loader
18
 
19
  __import__("warnings").filterwarnings("ignore", category=FutureWarning, module="transformers")
20
  __import__("transformers").logging.set_verbosity_error()
21
 
 
 
 
 
 
 
22
  with open("./data/styles.json") as f:
23
- styles = json.load(f)
 
 
 
 
 
 
 
 
 
 
 
24
 
25
 
26
  # parse prompts with arrays
@@ -43,10 +65,10 @@ def parse_prompt(prompt: str) -> list[str]:
43
 
44
 
45
  def apply_style(prompt, style_id, negative=False):
46
- global styles
47
  if not style_id or style_id == "None":
48
  return prompt
49
- for style in styles:
50
  if style["id"] == style_id:
51
  if negative:
52
  return prompt + " . " + style["negative_prompt"]
@@ -55,7 +77,7 @@ def apply_style(prompt, style_id, negative=False):
55
  return prompt
56
 
57
 
58
- def prepare_image(input, size=(512, 512)):
59
  image = None
60
  if isinstance(input, Image.Image):
61
  image = input
@@ -65,7 +87,11 @@ def prepare_image(input, size=(512, 512)):
65
  if os.path.isfile(input):
66
  image = Image.open(input)
67
  if image is not None:
68
- return image.convert("RGB").resize(size, Image.Resampling.LANCZOS)
 
 
 
 
69
  else:
70
  raise ValueError("Invalid image prompt")
71
 
@@ -213,7 +239,9 @@ def generate(
213
  kwargs["image"] = prepare_image(image_prompt, (width, height))
214
 
215
  if IP_ADAPTER:
216
- kwargs["ip_adapter_image"] = prepare_image(ip_image, (width, height))
 
 
217
 
218
  try:
219
  image = pipe(**kwargs).images[0]
 
1
+ import functools
2
+ import inspect
3
  import json
4
  import os
5
  import re
6
  import time
7
  from datetime import datetime
8
  from itertools import product
9
+ from typing import Callable, TypeVar
10
 
11
+ import anyio
12
  import numpy as np
13
  import spaces
14
  import torch
15
+ from anyio import Semaphore
16
  from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
17
  from compel.prompt_parser import PromptParser
18
  from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError
19
  from PIL import Image
20
+ from typing_extensions import ParamSpec
21
 
22
  from .loader import Loader
23
 
24
  __import__("warnings").filterwarnings("ignore", category=FutureWarning, module="transformers")
25
  __import__("transformers").logging.set_verbosity_error()
26
 
27
+ T = TypeVar("T")
28
+ P = ParamSpec("P")
29
+
30
+ MAX_CONCURRENT_THREADS = 1
31
+ MAX_THREADS_GUARD = Semaphore(MAX_CONCURRENT_THREADS)
32
+
33
  with open("./data/styles.json") as f:
34
+ STYLES = json.load(f)
35
+
36
+
37
+ # like the original but supports args and kwargs instead of a dict
38
+ # https://github.com/huggingface/huggingface-inference-toolkit/blob/0.2.0/src/huggingface_inference_toolkit/async_utils.py
39
+ async def async_call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
40
+ async with MAX_THREADS_GUARD:
41
+ sig = inspect.signature(fn)
42
+ bound_args = sig.bind(*args, **kwargs)
43
+ bound_args.apply_defaults()
44
+ partial_fn = functools.partial(fn, **bound_args.arguments)
45
+ return await anyio.to_thread.run_sync(partial_fn)
46
 
47
 
48
  # parse prompts with arrays
 
65
 
66
 
67
  def apply_style(prompt, style_id, negative=False):
68
+ global STYLES
69
  if not style_id or style_id == "None":
70
  return prompt
71
+ for style in STYLES:
72
  if style["id"] == style_id:
73
  if negative:
74
  return prompt + " . " + style["negative_prompt"]
 
77
  return prompt
78
 
79
 
80
+ def prepare_image(input, size=None):
81
  image = None
82
  if isinstance(input, Image.Image):
83
  image = input
 
87
  if os.path.isfile(input):
88
  image = Image.open(input)
89
  if image is not None:
90
+ image = image.convert("RGB")
91
+ if size is not None:
92
+ image = image.resize(size, Image.Resampling.LANCZOS)
93
+ if image is not None:
94
+ return image
95
  else:
96
  raise ValueError("Invalid image prompt")
97
 
 
239
  kwargs["image"] = prepare_image(image_prompt, (width, height))
240
 
241
  if IP_ADAPTER:
242
+ # don't resize full-face images
243
+ size = None if ip_face else (width, height)
244
+ kwargs["ip_adapter_image"] = prepare_image(ip_image, size)
245
 
246
  try:
247
  image = pipe(**kwargs).images[0]
lib/loader.py CHANGED
@@ -104,31 +104,33 @@ class Loader:
104
  print("Switching to Tiny VAE...")
105
  self.pipe.vae = AutoencoderTiny.from_pretrained(
106
  pretrained_model_name_or_path="madebyollin/taesd",
107
- ).to(self.pipe.device, self.pipe.dtype)
 
108
  return
109
 
110
  if is_tiny and not taesd:
111
  print("Switching to KL VAE...")
112
  model = AutoencoderKL.from_pretrained(
113
  pretrained_model_name_or_path=model_name,
 
114
  subfolder="vae",
115
  variant=variant,
116
- ).to(self.pipe.device, self.pipe.dtype)
117
  self.pipe.vae = torch.compile(
118
  mode="reduce-overhead",
119
  fullgraph=True,
120
  model=model,
121
  )
122
 
123
- def _load_pipeline(self, kind, model, device, dtype, **kwargs):
124
  pipelines = {
125
  "txt2img": StableDiffusionPipeline,
126
  "img2img": StableDiffusionImg2ImgPipeline,
127
  }
128
  if self.pipe is None:
129
- self.pipe = pipelines[kind].from_pretrained(model, **kwargs).to(device, dtype)
130
  if not isinstance(self.pipe, pipelines[kind]):
131
- self.pipe = pipelines[kind].from_pipe(self.pipe).to(device, dtype)
132
  self.ip_adapter = None
133
 
134
  def load(
@@ -186,13 +188,14 @@ class Loader:
186
  "scheduler": schedulers[scheduler](**scheduler_kwargs),
187
  "requires_safety_checker": False,
188
  "safety_checker": None,
 
189
  "variant": variant,
190
  }
191
 
192
  if self.pipe is None:
193
  print(f"Loading {model_lower} with {'Tiny' if taesd else 'KL'} VAE...")
194
 
195
- self._load_pipeline(kind, model_lower, device, dtype, **pipe_kwargs)
196
  model_name = self.pipe.config._name_or_path
197
  same_model = model_name.lower() == model_lower
198
  same_scheduler = isinstance(self.pipe.scheduler, schedulers[scheduler])
@@ -210,7 +213,7 @@ class Loader:
210
  self.pipe.scheduler = schedulers[scheduler](**scheduler_kwargs)
211
  else:
212
  self.pipe = None
213
- self._load_pipeline(kind, model_lower, device, dtype, **pipe_kwargs)
214
 
215
  self._load_ip_adapter(ip_adapter)
216
  self._load_vae(taesd, model_lower, variant)
 
104
  print("Switching to Tiny VAE...")
105
  self.pipe.vae = AutoencoderTiny.from_pretrained(
106
  pretrained_model_name_or_path="madebyollin/taesd",
107
+ torch_dtype=self.pipe.dtype,
108
+ ).to(self.pipe.device)
109
  return
110
 
111
  if is_tiny and not taesd:
112
  print("Switching to KL VAE...")
113
  model = AutoencoderKL.from_pretrained(
114
  pretrained_model_name_or_path=model_name,
115
+ torch_dtype=self.pipe.dtype,
116
  subfolder="vae",
117
  variant=variant,
118
+ ).to(self.pipe.device)
119
  self.pipe.vae = torch.compile(
120
  mode="reduce-overhead",
121
  fullgraph=True,
122
  model=model,
123
  )
124
 
125
+ def _load_pipeline(self, kind, model, device, **kwargs):
126
  pipelines = {
127
  "txt2img": StableDiffusionPipeline,
128
  "img2img": StableDiffusionImg2ImgPipeline,
129
  }
130
  if self.pipe is None:
131
+ self.pipe = pipelines[kind].from_pretrained(model, **kwargs).to(device)
132
  if not isinstance(self.pipe, pipelines[kind]):
133
+ self.pipe = pipelines[kind].from_pipe(self.pipe).to(device)
134
  self.ip_adapter = None
135
 
136
  def load(
 
188
  "scheduler": schedulers[scheduler](**scheduler_kwargs),
189
  "requires_safety_checker": False,
190
  "safety_checker": None,
191
+ "torch_dtype": dtype,
192
  "variant": variant,
193
  }
194
 
195
  if self.pipe is None:
196
  print(f"Loading {model_lower} with {'Tiny' if taesd else 'KL'} VAE...")
197
 
198
+ self._load_pipeline(kind, model_lower, device, **pipe_kwargs)
199
  model_name = self.pipe.config._name_or_path
200
  same_model = model_name.lower() == model_lower
201
  same_scheduler = isinstance(self.pipe.scheduler, schedulers[scheduler])
 
213
  self.pipe.scheduler = schedulers[scheduler](**scheduler_kwargs)
214
  else:
215
  self.pipe = None
216
+ self._load_pipeline(kind, model_lower, device, **pipe_kwargs)
217
 
218
  self._load_ip_adapter(ip_adapter)
219
  self._load_vae(taesd, model_lower, variant)
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 
1
  accelerate
2
  einops==0.8.0
3
  compel==2.0.3
 
1
+ anyio==4.4.0
2
  accelerate
3
  einops==0.8.0
4
  compel==2.0.3