adamelliotfields committed on
Commit 6829539
1 Parent(s): 9edebae

Custom progress bar

Files changed (1)
  1. lib/inference.py +127 -112
lib/inference.py CHANGED
@@ -9,6 +9,7 @@ from itertools import product
 from typing import Callable, TypeVar
 
 import anyio
+import gradio as gr
 import numpy as np
 import spaces
 import torch
@@ -113,17 +114,16 @@ def generate(
     guidance_scale=7.5,
     inference_steps=50,
     denoising_strength=0.8,
+    deepcache=1,
+    scale=1,
     num_images=1,
     karras=False,
     taesd=False,
     freeu=False,
     clip_skip=False,
-    truncate_prompts=False,
-    increment_seed=True,
-    deepcache=1,
-    scale=1,
     Info: Callable[[str], None] = None,
     Error=Exception,
+    progress=gr.Progress(),
 ):
     if not torch.cuda.is_available():
         raise Error("CUDA not available")
@@ -134,12 +134,6 @@ def generate(
 
     DEVICE = torch.device("cuda")
 
-    DTYPE = (
-        torch.bfloat16
-        if torch.cuda.is_available() and torch.cuda.get_device_properties(DEVICE).major >= 8
-        else torch.float16
-    )
-
     EMBEDDINGS_TYPE = (
         ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED
         if clip_skip
@@ -148,114 +142,135 @@ def generate(
 
     KIND = "img2img" if image_prompt is not None else "txt2img"
 
-    IP_ADAPTER = None
+    CURRENT_IMAGE = 1
 
     if ip_image:
         IP_ADAPTER = "full-face" if ip_face else "plus"
-
-    with torch.inference_mode():
-        start = time.perf_counter()
-        loader = Loader()
-        pipe, upscaler = loader.load(
-            KIND,
-            IP_ADAPTER,
-            model,
-            scheduler,
-            karras,
-            taesd,
-            freeu,
-            deepcache,
-            scale,
-            DEVICE,
-            DTYPE,
-        )
-
-        # load embeddings and append to negative prompt
-        embeddings_dir = os.path.join(os.path.dirname(__file__), "..", "embeddings")
-        embeddings_dir = os.path.abspath(embeddings_dir)
-        for embedding in embeddings:
-            try:
-                pipe.load_textual_inversion(
-                    pretrained_model_name_or_path=f"{embeddings_dir}/{embedding}.pt",
-                    token=f"<{embedding}>",
-                )
-                negative_prompt = (
-                    f"{negative_prompt}, (<{embedding}>)1.1"
-                    if negative_prompt
-                    else f"(<{embedding}>)1.1"
-                )
-            except (EnvironmentError, HFValidationError, RepositoryNotFoundError):
-                raise Error(f"Invalid embedding: <{embedding}>")
-
-        # prompt embeds
-        compel = Compel(
-            device=pipe.device,
-            tokenizer=pipe.tokenizer,
-            text_encoder=pipe.text_encoder,
-            truncate_long_prompts=truncate_prompts,
-            dtype_for_device_getter=lambda _: DTYPE,
-            returned_embeddings_type=EMBEDDINGS_TYPE,
-            textual_inversion_manager=DiffusersTextualInversionManager(pipe),
-        )
-
-        images = []
-        current_seed = seed
-
-        try:
-            styled_negative_prompt = apply_style(negative_prompt, style, negative=True)
-            neg_embeds = compel(styled_negative_prompt)
-        except PromptParser.ParsingException:
-            raise Error("ParsingException: Invalid negative prompt")
-
-        for i in range(num_images):
-            # seeded generator for each iteration
-            generator = torch.Generator(device=pipe.device).manual_seed(current_seed)
-
-            try:
-                all_positive_prompts = parse_prompt(positive_prompt)
-                prompt_index = i % len(all_positive_prompts)
-                pos_prompt = all_positive_prompts[prompt_index]
-                styled_pos_prompt = apply_style(pos_prompt, style)
-                pos_embeds = compel(styled_pos_prompt)
-                pos_embeds, neg_embeds = compel.pad_conditioning_tensors_to_same_length(
-                    [pos_embeds, neg_embeds]
-                )
-            except PromptParser.ParsingException:
-                raise Error("ParsingException: Invalid prompt")
-
-            kwargs = {
-                "width": width,
-                "height": height,
-                "generator": generator,
-                "prompt_embeds": pos_embeds,
-                "guidance_scale": guidance_scale,
-                "negative_prompt_embeds": neg_embeds,
-                "num_inference_steps": inference_steps,
-                "output_type": "np" if scale > 1 else "pil",
-            }
-
-            if KIND == "img2img":
-                kwargs["strength"] = denoising_strength
-                kwargs["image"] = prepare_image(image_prompt, (width, height))
-
-            if IP_ADAPTER:
-                # don't resize full-face images
-                size = None if ip_face else (width, height)
-                kwargs["ip_adapter_image"] = prepare_image(ip_image, size)
-
-            try:
-                image = pipe(**kwargs).images[0]
-                if scale > 1:
-                    image = upscaler.predict(image)
-                images.append((image, str(current_seed)))
-            finally:
-                pipe.unload_textual_inversion()
-                torch.cuda.empty_cache()
-
-            if increment_seed:
-                current_seed += 1
-
-        diff = time.perf_counter() - start
-        if Info:
-            Info(f"Generated {len(images)} image{'s' if len(images) > 1 else ''} in {diff:.2f}s")
-        return images
+    else:
+        IP_ADAPTER = ""
+
+    if progress is not None:
+        progress((0, inference_steps), desc=f"Generating image {CURRENT_IMAGE}/{num_images}")
+
+    def callback_on_step_end(pipeline, step, timestep, latents):
+        nonlocal CURRENT_IMAGE
+        strength = denoising_strength if KIND == "img2img" else 1
+        total_steps = min(int(inference_steps * strength), inference_steps)
+        current_step = step + 1
+        progress(
+            (current_step, total_steps),
+            desc=f"Generating image {CURRENT_IMAGE}/{num_images}",
+        )
+        if current_step == total_steps:
+            CURRENT_IMAGE += 1
+        return latents
+
+    start = time.perf_counter()
+    loader = Loader()
+    pipe, upscaler = loader.load(
+        KIND,
+        IP_ADAPTER,
+        model,
+        scheduler,
+        karras,
+        taesd,
+        freeu,
+        deepcache,
+        scale,
+        DEVICE,
+    )
+
+    # load embeddings and append to negative prompt
+    embeddings_dir = os.path.join(os.path.dirname(__file__), "..", "embeddings")
+    embeddings_dir = os.path.abspath(embeddings_dir)
+    for embedding in embeddings:
+        try:
+            # wrap embeddings in angle brackets
+            pipe.load_textual_inversion(
+                pretrained_model_name_or_path=f"{embeddings_dir}/{embedding}.pt",
+                token=f"<{embedding}>",
+            )
+            # boost embeddings slightly
+            negative_prompt = (
+                f"{negative_prompt}, (<{embedding}>)1.1"
+                if negative_prompt
+                else f"(<{embedding}>)1.1"
+            )
+        except (EnvironmentError, HFValidationError, RepositoryNotFoundError):
+            raise Error(f"Invalid embedding: <{embedding}>")
+
+    # prompt embeds
+    compel = Compel(
+        device=pipe.device,
+        tokenizer=pipe.tokenizer,
+        text_encoder=pipe.text_encoder,
+        returned_embeddings_type=EMBEDDINGS_TYPE,
+        dtype_for_device_getter=lambda _: pipe.dtype,
+        textual_inversion_manager=DiffusersTextualInversionManager(pipe),
+    )
+
+    images = []
+    current_seed = seed
+
+    try:
+        styled_negative_prompt = apply_style(negative_prompt, style, negative=True)
+        neg_embeds = compel(styled_negative_prompt)
+    except PromptParser.ParsingException:
+        raise Error("ParsingException: Invalid negative prompt")
+
+    for i in range(num_images):
+        # seeded generator for each iteration
+        generator = torch.Generator(device=pipe.device).manual_seed(current_seed)
+
+        try:
+            all_positive_prompts = parse_prompt(positive_prompt)
+            prompt_index = i % len(all_positive_prompts)
+            pos_prompt = all_positive_prompts[prompt_index]
+            styled_pos_prompt = apply_style(pos_prompt, style)
+            pos_embeds = compel(styled_pos_prompt)
+            pos_embeds, neg_embeds = compel.pad_conditioning_tensors_to_same_length(
+                [pos_embeds, neg_embeds]
+            )
+        except PromptParser.ParsingException:
+            raise Error("ParsingException: Invalid prompt")
+
+        kwargs = {
+            "width": width,
+            "height": height,
+            "generator": generator,
+            "prompt_embeds": pos_embeds,
+            "guidance_scale": guidance_scale,
+            "negative_prompt_embeds": neg_embeds,
+            "num_inference_steps": inference_steps,
+            "output_type": "np" if scale > 1 else "pil",
+        }
+
+        if progress is not None:
+            kwargs["callback_on_step_end"] = callback_on_step_end
+
+        if KIND == "img2img":
+            kwargs["strength"] = denoising_strength
+            kwargs["image"] = prepare_image(image_prompt, (width, height))
+
+        if IP_ADAPTER:
+            # don't resize full-face images
+            size = None if ip_face else (width, height)
+            kwargs["ip_adapter_image"] = prepare_image(ip_image, size)
+
+        try:
+            image = pipe(**kwargs).images[0]
+            if scale > 1:
+                image = upscaler.predict(image)
+            images.append((image, str(current_seed)))
+        finally:
+            pipe.unload_textual_inversion()
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+
+        # increment seed for next image
+        current_seed += 1
+
+    diff = time.perf_counter() - start
+    if Info:
+        Info(f"Generated {len(images)} image{'s' if len(images) > 1 else ''} in {diff:.2f}s")
+    return images
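
For reference, a minimal standalone sketch of the pattern this commit wires up: a per-step diffusers callback driving Gradio's progress bar. The checkpoint id, step count, and on_step_end helper below are illustrative, not part of this repo, and the hook follows the callback_on_step_end signature documented by diffusers (it receives and returns a callback_kwargs dict, rather than the raw latents used by the older callback argument).

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline

# hypothetical checkpoint; any SD 1.5-style model id works here
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
).to("cuda")

def generate(prompt, progress=gr.Progress()):
    steps = 20  # illustrative step count

    # diffusers invokes this hook after every scheduler step;
    # the (current, total) tuple renders as "n/20" in the Gradio UI
    def on_step_end(pipeline, step, timestep, callback_kwargs):
        progress((step + 1, steps), desc=f"Step {step + 1}/{steps}")
        return callback_kwargs

    return pipe(
        prompt,
        num_inference_steps=steps,
        callback_on_step_end=on_step_end,
    ).images[0]

# gradio sees the gr.Progress() default and injects a live tracker
gr.Interface(fn=generate, inputs=gr.Textbox(), outputs=gr.Image()).launch()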