adamelliotfields committed (verified)
Commit 4470520 · 1 Parent(s): 75805bd

Loader and inference improvements

Files changed (2)
  1. lib/inference.py +58 -40
  2. lib/loader.py +33 -32
lib/inference.py CHANGED
@@ -21,8 +21,8 @@ from typing_extensions import ParamSpec
 
 from .loader import Loader
 
-__import__("transformers").logging.set_verbosity_error()
 __import__("warnings").filterwarnings("ignore", category=FutureWarning, module="transformers")
+__import__("transformers").logging.set_verbosity_error()
 
 T = TypeVar("T")
 P = ParamSpec("P")
@@ -45,17 +45,17 @@ async def async_call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T
     return await anyio.to_thread.run_sync(partial_fn)
 
 
-# parse prompts with arrays
-def parse_prompt(prompt: str) -> list[str]:
+def parse_prompt_with_arrays(prompt: str) -> list[str]:
     arrays = re.findall(r"\[\[(.*?)\]\]", prompt)
 
     if not arrays:
         return [prompt]
 
-    tokens = [item.split(",") for item in arrays]
-    combinations = list(product(*tokens))
-    prompts = []
+    tokens = [item.split(",") for item in arrays]  # [("a", "b"), ("1", "2")]
+    combinations = list(product(*tokens))  # [("a", "1"), ("a", "2"), ("b", "1"), ("b", "2")]
 
+    # find all the arrays in the prompt and replace them with tokens
+    prompts = []
     for combo in combinations:
         current_prompt = prompt
         for i, token in enumerate(combo):
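
Reviewer note: the renamed helper expands each [[a,b]] group into the Cartesian product of its tokens, one prompt per combination. A minimal usage sketch, assuming the loop body shown above substitutes each token back into its bracket slot (the example strings are illustrative):

    prompts = parse_prompt_with_arrays("a [[cat,dog]] wearing a [[red,blue]] hat")
    # 2 x 2 combinations -> 4 prompts:
    # ["a cat wearing a red hat", "a cat wearing a blue hat",
    #  "a dog wearing a red hat", "a dog wearing a blue hat"]
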
@@ -71,8 +71,12 @@ def apply_style(prompt, style_id, negative=False):
     for style in STYLES:
         if style["id"] == style_id:
             if negative:
-                return prompt + " . " + style["negative_prompt"]
+                return (
+                    # prepend our negative prompt to the style's negative prompt
+                    f"{prompt}, {style['negative_prompt']}" if prompt else style["negative_prompt"]
+                )
             else:
+                # inject our positive prompt into the style prompt
                 return style["prompt"].format(prompt=prompt)
     return prompt
 
@@ -97,12 +101,18 @@ def prepare_image(input, size=None):
 
 
 def gpu_duration(**kwargs):
-    duration = 15
+    loading = 20
+    duration = 10
+    width = kwargs.get("width", 512)
+    height = kwargs.get("height", 512)
     scale = kwargs.get("scale", 1)
     num_images = kwargs.get("num_images", 1)
+    size = width * height
+    if size > 500_000:
+        duration += 5
     if scale == 4:
         duration += 5
-    return duration * num_images
+    return loading + (duration * num_images)
 
 
 @spaces.GPU(duration=gpu_duration)
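
Reviewer note: gpu_duration now budgets a flat 20s for loading plus a per-image cost that grows with resolution and upscaling. A worked example using the constants above:

    # width=1024, height=1024 -> size = 1_048_576 > 500_000, so duration = 10 + 5 = 15
    # scale == 4 adds another 5 -> 20 per image
    gpu_duration(width=1024, height=1024, scale=4, num_images=2)  # 20 + (20 * 2) = 60 seconds
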
@@ -116,7 +126,7 @@ def generate(
     style=None,
     seed=None,
     model="Lykon/dreamshaper-8",
-    scheduler="DEIS 2M",
+    scheduler="DDIM",
     width=512,
     height=512,
     guidance_scale=7.5,
@@ -140,16 +150,17 @@ def generate(
     if seed is None or seed < 0:
         seed = int(datetime.now().timestamp() * 1_000_000) % (2**64)
 
+    CURRENT_STEP = 0
+    CURRENT_IMAGE = 1
+
+    KIND = "img2img" if image_prompt is not None else "txt2img"
+
     EMBEDDINGS_TYPE = (
         ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED
        if clip_skip
        else ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED
     )
 
-    KIND = "img2img" if image_prompt is not None else "txt2img"
-
-    CURRENT_IMAGE = 1
-
     if ip_image:
         IP_ADAPTER = "full-face" if ip_face else "plus"
     else:
@@ -162,23 +173,22 @@ def generate(
     TQDM = True
 
     def callback_on_step_end(pipeline, step, timestep, latents):
-        nonlocal CURRENT_IMAGE
+        nonlocal CURRENT_STEP, CURRENT_IMAGE
         if progress is None:
             return latents
         strength = denoising_strength if KIND == "img2img" else 1
         total_steps = min(int(inference_steps * strength), inference_steps)
-        current_step = step + 1
+
+        CURRENT_STEP = step + 1
         progress(
-            (current_step, total_steps),
+            (CURRENT_STEP, total_steps),
             desc=f"Generating image {CURRENT_IMAGE}/{num_images}",
         )
-        if current_step == total_steps:
-            CURRENT_IMAGE += 1
         return latents
 
     start = time.perf_counter()
     loader = Loader()
-    pipe, upscaler = loader.load(
+    loader.load(
         KIND,
         IP_ADAPTER,
         model,
@@ -191,6 +201,17 @@ def generate(
         TQDM,
     )
 
+    if loader.pipe is None:
+        raise Error(f"RuntimeError: Error loading {model}")
+
+    pipe = loader.pipe
+    upscaler = None
+
+    if scale == 2:
+        upscaler = loader.upscaler_2x
+    if scale == 4:
+        upscaler = loader.upscaler_4x
+
     # load embeddings and append to negative prompt
     embeddings_dir = os.path.join(os.path.dirname(__file__), "..", "embeddings")
     embeddings_dir = os.path.abspath(embeddings_dir)
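
Reviewer note: load() no longer returns a (pipe, upscaler) tuple; callers now read the components off the singleton and must check pipe themselves, as the hunk above does. A sketch of the new calling convention (argument list elided):

    loader = Loader()
    loader.load(KIND, IP_ADAPTER, model, ...)  # mutates the singleton; no return value
    if loader.pipe is None:  # load() logs the error and bails early
        raise Error(f"RuntimeError: Error loading {model}")
    pipe = loader.pipe
    upscaler = {2: loader.upscaler_2x, 4: loader.upscaler_4x}.get(scale)  # None otherwise
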
@@ -201,11 +222,8 @@ def generate(
                 pretrained_model_name_or_path=f"{embeddings_dir}/{embedding}.pt",
                 token=f"<{embedding}>",
             )
-            # boost embeddings slightly
             negative_prompt = (
-                f"{negative_prompt}, (<{embedding}>)1.1"
-                if negative_prompt
-                else f"(<{embedding}>)1.1"
+                f"{negative_prompt}, <{embedding}>" if negative_prompt else f"<{embedding}>"
             )
         except (EnvironmentError, HFValidationError, RepositoryNotFoundError):
             raise Error(f"Invalid embedding: <{embedding}>")
@@ -225,33 +243,33 @@ def generate(
 
     try:
         styled_negative_prompt = apply_style(negative_prompt, style, negative=True)
-        neg_embeds = compel(styled_negative_prompt)
+        negative_embeds = compel(styled_negative_prompt)
     except PromptParser.ParsingException:
-        raise Error("ParsingException: Invalid negative prompt")
+        raise Error("ValueError: Invalid negative prompt")
 
     for i in range(num_images):
         # seeded generator for each iteration
         generator = torch.Generator(device=pipe.device).manual_seed(current_seed)
 
         try:
-            all_positive_prompts = parse_prompt(positive_prompt)
+            all_positive_prompts = parse_prompt_with_arrays(positive_prompt)
             prompt_index = i % len(all_positive_prompts)
-            pos_prompt = all_positive_prompts[prompt_index]
-            styled_pos_prompt = apply_style(pos_prompt, style)
-            pos_embeds = compel(styled_pos_prompt)
-            pos_embeds, neg_embeds = compel.pad_conditioning_tensors_to_same_length(
-                [pos_embeds, neg_embeds]
+            prompt = all_positive_prompts[prompt_index]
+            prompt = apply_style(prompt, style)
+            positive_embeds = compel(prompt)
+            positive_embeds, negative_embeds = compel.pad_conditioning_tensors_to_same_length(
+                [positive_embeds, negative_embeds]
             )
         except PromptParser.ParsingException:
-            raise Error("ParsingException: Invalid prompt")
+            raise Error("ValueError: Invalid prompt")
 
         kwargs = {
             "width": width,
             "height": height,
             "generator": generator,
-            "prompt_embeds": pos_embeds,
+            "prompt_embeds": positive_embeds,
             "guidance_scale": guidance_scale,
-            "negative_prompt_embeds": neg_embeds,
+            "negative_prompt_embeds": negative_embeds,
             "num_inference_steps": inference_steps,
             "output_type": "np" if scale > 1 else "pil",
         }
@@ -273,13 +291,13 @@ def generate(
             if scale > 1:
                 image = upscaler.predict(image)
             images.append((image, str(current_seed)))
+            current_seed += 1
+        except Exception as e:
+            raise Error(f"RuntimeError: {e}")
         finally:
             pipe.unload_textual_inversion()
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
-
-            # increment seed for next image
-            current_seed += 1
+            CURRENT_STEP = 0
+            CURRENT_IMAGE += 1
 
     diff = time.perf_counter() - start
     if Info:
 
lib/loader.py CHANGED
@@ -27,13 +27,11 @@ class Loader:
             cls._instance = super().__new__(cls)
             cls._instance.pipe = None
             cls._instance.model = None
-            cls._instance.upscaler = None
             cls._instance.ip_adapter = None
+            cls._instance.upscaler_2x = None
+            cls._instance.upscaler_4x = None
         return cls._instance
 
-    def _should_unload_upscaler(self, scale=1):
-        return self.upscaler is not None and scale == 1
-
     def _should_unload_ip_adapter(self, ip_adapter=""):
         return self.ip_adapter is not None and not ip_adapter
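
Reviewer note: Loader caches everything on a __new__-based singleton, which is why the new upscaler slots are initialized here and survive across generate() calls. The pattern, reduced to its core:

    class Singleton:
        _instance = None

        def __new__(cls):
            if cls._instance is None:
                cls._instance = super().__new__(cls)  # construct and cache once
                cls._instance.pipe = None  # initialize cached slots on first call
            return cls._instance

    assert Singleton() is Singleton()  # every call returns the same instance
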
 
@@ -78,25 +76,17 @@ class Loader:
         torch.cuda.reset_peak_memory_stats()
         torch.cuda.synchronize()
 
-    def _unload(self, kind="", model="", ip_adapter="", scale=1):
+    def _unload(self, kind="", model="", ip_adapter=""):
         to_unload = []
-
-        if self._should_unload_upscaler(scale):
-            to_unload.append("upscaler")
-
         if self._should_unload_ip_adapter(ip_adapter):
             self._unload_ip_adapter()
             to_unload.append("ip_adapter")
-
         if self._should_unload_pipeline(kind, model):
             to_unload.append("model")
             to_unload.append("pipe")
-
         for component in to_unload:
             delattr(self, component)
-
         self._flush()
-
         for component in to_unload:
             setattr(self, component, None)
 
@@ -112,35 +102,46 @@ class Loader:
         self.pipe.set_ip_adapter_scale(0.5)
         self.ip_adapter = ip_adapter
 
-    def _load_upscaler(self, scale=1, device=None):
-        if scale > 1 and self.upscaler is None:
-            print(f"Loading {scale}x upscaler...")
-            self.upscaler = RealESRGAN(scale, device)
-            self.upscaler.load_weights()
+    def _load_upscaler(self, scale=1):
+        if scale == 2 and self.upscaler_2x is None:
+            try:
+                print("Loading 2x upscaler...")
+                self.upscaler_2x = RealESRGAN(2, "cuda")
+                self.upscaler_2x.load_weights()
+            except Exception as e:
+                print(f"Error loading 2x upscaler: {e}")
+                self.upscaler_2x = None
+        if scale == 4 and self.upscaler_4x is None:
+            try:
+                print("Loading 4x upscaler...")
+                self.upscaler_4x = RealESRGAN(4, "cuda")
+                self.upscaler_4x.load_weights()
+            except Exception as e:
+                print(f"Error loading 4x upscaler: {e}")
+                self.upscaler_4x = None
 
-    def _load_pipeline(self, kind, model, tqdm, device, **kwargs):
+    def _load_pipeline(self, kind, model, tqdm, **kwargs):
         pipeline = Config.PIPELINES[kind]
         if self.pipe is None:
-            print(f"Loading {model}...")
             try:
+                print(f"Loading {model}...")
+                self.model = model
                 if model.lower() in Config.MODEL_CHECKPOINTS.keys():
                     self.pipe = pipeline.from_single_file(
                         f"https://huggingface.co/{model}/{Config.MODEL_CHECKPOINTS[model.lower()]}",
                         **kwargs,
-                    ).to(device)
+                    ).to("cuda")
                 else:
-                    self.pipe = pipeline.from_pretrained(model, **kwargs).to(device)
-                self.model = model
+                    self.pipe = pipeline.from_pretrained(model, **kwargs).to("cuda")
             except Exception as e:
                 print(f"Error loading {model}: {e}")
                 self.model = None
                 self.pipe = None
                 return
-
         if not isinstance(self.pipe, pipeline):
-            self.pipe = pipeline.from_pipe(self.pipe).to(device)
-
-        self.pipe.set_progress_bar_config(disable=not tqdm)
+            self.pipe = pipeline.from_pipe(self.pipe).to("cuda")
+        if self.pipe is not None:
+            self.pipe.set_progress_bar_config(disable=not tqdm)
 
     def _load_vae(self, taesd=False, model=""):
         vae_type = type(self.pipe.vae)
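
Reviewer note: the 2x and 4x RealESRGAN instances now live in separate slots, load lazily, and are never evicted (the old _should_unload_upscaler path is gone), so switching scales no longer reloads weights. Usage sketch with names from this diff:

    loader = Loader()
    loader._load_upscaler(scale=2)  # loads weights on the first call only
    loader._load_upscaler(scale=2)  # no-op: self.upscaler_2x is already set
    upscaled = loader.upscaler_2x.predict(image)  # as generate() does when scale > 1
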
@@ -251,14 +252,15 @@ class Loader:
                 else torch.float16
             )
         else:
+            # defaults to float32
             pipe_kwargs["torch_dtype"] = torch.float16
 
-        self._unload(kind, model, ip_adapter, scale)
-        self._load_pipeline(kind, model, tqdm, device, **pipe_kwargs)
+        self._unload(kind, model, ip_adapter)
+        self._load_pipeline(kind, model, tqdm, **pipe_kwargs)
 
         # error loading model
         if self.pipe is None:
-            return None, None
+            return
 
         same_scheduler = isinstance(self.pipe.scheduler, Config.SCHEDULERS[scheduler])
         same_karras = (
@@ -279,5 +281,4 @@ class Loader:
         self._load_vae(taesd, model)
         self._load_deepcache(deepcache)
         self._load_ip_adapter(ip_adapter)
-        self._load_upscaler(scale, device)
-        return self.pipe, self.upscaler
+        self._load_upscaler(scale)
 