Spaces:
Running
on
Zero
Running
on
Zero
adamelliotfields
commited on
Commit
•
6829539
1
Parent(s):
9edebae
Custom progress bar
Browse files- lib/inference.py +127 -112
lib/inference.py
CHANGED
@@ -9,6 +9,7 @@ from itertools import product
|
|
9 |
from typing import Callable, TypeVar
|
10 |
|
11 |
import anyio
|
|
|
12 |
import numpy as np
|
13 |
import spaces
|
14 |
import torch
|
@@ -113,17 +114,16 @@ def generate(
|
|
113 |
guidance_scale=7.5,
|
114 |
inference_steps=50,
|
115 |
denoising_strength=0.8,
|
|
|
|
|
116 |
num_images=1,
|
117 |
karras=False,
|
118 |
taesd=False,
|
119 |
freeu=False,
|
120 |
clip_skip=False,
|
121 |
-
truncate_prompts=False,
|
122 |
-
increment_seed=True,
|
123 |
-
deepcache=1,
|
124 |
-
scale=1,
|
125 |
Info: Callable[[str], None] = None,
|
126 |
Error=Exception,
|
|
|
127 |
):
|
128 |
if not torch.cuda.is_available():
|
129 |
raise Error("CUDA not available")
|
@@ -134,12 +134,6 @@ def generate(
|
|
134 |
|
135 |
DEVICE = torch.device("cuda")
|
136 |
|
137 |
-
DTYPE = (
|
138 |
-
torch.bfloat16
|
139 |
-
if torch.cuda.is_available() and torch.cuda.get_device_properties(DEVICE).major >= 8
|
140 |
-
else torch.float16
|
141 |
-
)
|
142 |
-
|
143 |
EMBEDDINGS_TYPE = (
|
144 |
ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED
|
145 |
if clip_skip
|
@@ -148,114 +142,135 @@ def generate(
|
|
148 |
|
149 |
KIND = "img2img" if image_prompt is not None else "txt2img"
|
150 |
|
151 |
-
|
152 |
|
153 |
if ip_image:
|
154 |
IP_ADAPTER = "full-face" if ip_face else "plus"
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
DEVICE,
|
170 |
-
DTYPE,
|
171 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
)
|
|
|
200 |
|
201 |
-
|
202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
|
204 |
try:
|
205 |
-
|
206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
except PromptParser.ParsingException:
|
208 |
-
raise Error("ParsingException: Invalid
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
kwargs =
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
images.append((image, str(current_seed)))
|
251 |
-
finally:
|
252 |
-
pipe.unload_textual_inversion()
|
253 |
-
torch.cuda.empty_cache()
|
254 |
-
|
255 |
-
if increment_seed:
|
256 |
-
current_seed += 1
|
257 |
-
|
258 |
-
diff = time.perf_counter() - start
|
259 |
-
if Info:
|
260 |
-
Info(f"Generated {len(images)} image{'s' if len(images) > 1 else ''} in {diff:.2f}s")
|
261 |
-
return images
|
|
|
9 |
from typing import Callable, TypeVar
|
10 |
|
11 |
import anyio
|
12 |
+
import gradio as gr
|
13 |
import numpy as np
|
14 |
import spaces
|
15 |
import torch
|
|
|
114 |
guidance_scale=7.5,
|
115 |
inference_steps=50,
|
116 |
denoising_strength=0.8,
|
117 |
+
deepcache=1,
|
118 |
+
scale=1,
|
119 |
num_images=1,
|
120 |
karras=False,
|
121 |
taesd=False,
|
122 |
freeu=False,
|
123 |
clip_skip=False,
|
|
|
|
|
|
|
|
|
124 |
Info: Callable[[str], None] = None,
|
125 |
Error=Exception,
|
126 |
+
progress=gr.Progress(),
|
127 |
):
|
128 |
if not torch.cuda.is_available():
|
129 |
raise Error("CUDA not available")
|
|
|
134 |
|
135 |
DEVICE = torch.device("cuda")
|
136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
EMBEDDINGS_TYPE = (
|
138 |
ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED
|
139 |
if clip_skip
|
|
|
142 |
|
143 |
KIND = "img2img" if image_prompt is not None else "txt2img"
|
144 |
|
145 |
+
CURRENT_IMAGE = 1
|
146 |
|
147 |
if ip_image:
|
148 |
IP_ADAPTER = "full-face" if ip_face else "plus"
|
149 |
+
else:
|
150 |
+
IP_ADAPTER = ""
|
151 |
+
|
152 |
+
if progress is not None:
|
153 |
+
progress((0, inference_steps), desc=f"Generating image {CURRENT_IMAGE}/{num_images}")
|
154 |
+
|
155 |
+
def callback_on_step_end(pipeline, step, timestep, latents):
|
156 |
+
nonlocal CURRENT_IMAGE
|
157 |
+
strength = denoising_strength if KIND == "img2img" else 1
|
158 |
+
total_steps = min(int(inference_steps * strength), inference_steps)
|
159 |
+
current_step = step + 1
|
160 |
+
progress(
|
161 |
+
(current_step, total_steps),
|
162 |
+
desc=f"Generating image {CURRENT_IMAGE}/{num_images}",
|
|
|
|
|
163 |
)
|
164 |
+
if current_step == total_steps:
|
165 |
+
CURRENT_IMAGE += 1
|
166 |
+
return latents
|
167 |
+
|
168 |
+
start = time.perf_counter()
|
169 |
+
loader = Loader()
|
170 |
+
pipe, upscaler = loader.load(
|
171 |
+
KIND,
|
172 |
+
IP_ADAPTER,
|
173 |
+
model,
|
174 |
+
scheduler,
|
175 |
+
karras,
|
176 |
+
taesd,
|
177 |
+
freeu,
|
178 |
+
deepcache,
|
179 |
+
scale,
|
180 |
+
DEVICE,
|
181 |
+
)
|
182 |
|
183 |
+
# load embeddings and append to negative prompt
|
184 |
+
embeddings_dir = os.path.join(os.path.dirname(__file__), "..", "embeddings")
|
185 |
+
embeddings_dir = os.path.abspath(embeddings_dir)
|
186 |
+
for embedding in embeddings:
|
187 |
+
try:
|
188 |
+
# wrap embeddings in angle brackets
|
189 |
+
pipe.load_textual_inversion(
|
190 |
+
pretrained_model_name_or_path=f"{embeddings_dir}/{embedding}.pt",
|
191 |
+
token=f"<{embedding}>",
|
192 |
+
)
|
193 |
+
# boost embeddings slightly
|
194 |
+
negative_prompt = (
|
195 |
+
f"{negative_prompt}, (<{embedding}>)1.1"
|
196 |
+
if negative_prompt
|
197 |
+
else f"(<{embedding}>)1.1"
|
198 |
+
)
|
199 |
+
except (EnvironmentError, HFValidationError, RepositoryNotFoundError):
|
200 |
+
raise Error(f"Invalid embedding: <{embedding}>")
|
201 |
+
|
202 |
+
# prompt embeds
|
203 |
+
compel = Compel(
|
204 |
+
device=pipe.device,
|
205 |
+
tokenizer=pipe.tokenizer,
|
206 |
+
text_encoder=pipe.text_encoder,
|
207 |
+
returned_embeddings_type=EMBEDDINGS_TYPE,
|
208 |
+
dtype_for_device_getter=lambda _: pipe.dtype,
|
209 |
+
textual_inversion_manager=DiffusersTextualInversionManager(pipe),
|
210 |
+
)
|
211 |
|
212 |
+
images = []
|
213 |
+
current_seed = seed
|
214 |
+
|
215 |
+
try:
|
216 |
+
styled_negative_prompt = apply_style(negative_prompt, style, negative=True)
|
217 |
+
neg_embeds = compel(styled_negative_prompt)
|
218 |
+
except PromptParser.ParsingException:
|
219 |
+
raise Error("ParsingException: Invalid negative prompt")
|
220 |
+
|
221 |
+
for i in range(num_images):
|
222 |
+
# seeded generator for each iteration
|
223 |
+
generator = torch.Generator(device=pipe.device).manual_seed(current_seed)
|
224 |
|
225 |
try:
|
226 |
+
all_positive_prompts = parse_prompt(positive_prompt)
|
227 |
+
prompt_index = i % len(all_positive_prompts)
|
228 |
+
pos_prompt = all_positive_prompts[prompt_index]
|
229 |
+
styled_pos_prompt = apply_style(pos_prompt, style)
|
230 |
+
pos_embeds = compel(styled_pos_prompt)
|
231 |
+
pos_embeds, neg_embeds = compel.pad_conditioning_tensors_to_same_length(
|
232 |
+
[pos_embeds, neg_embeds]
|
233 |
+
)
|
234 |
except PromptParser.ParsingException:
|
235 |
+
raise Error("ParsingException: Invalid prompt")
|
236 |
+
|
237 |
+
kwargs = {
|
238 |
+
"width": width,
|
239 |
+
"height": height,
|
240 |
+
"generator": generator,
|
241 |
+
"prompt_embeds": pos_embeds,
|
242 |
+
"guidance_scale": guidance_scale,
|
243 |
+
"negative_prompt_embeds": neg_embeds,
|
244 |
+
"num_inference_steps": inference_steps,
|
245 |
+
"output_type": "np" if scale > 1 else "pil",
|
246 |
+
}
|
247 |
+
|
248 |
+
if progress is not None:
|
249 |
+
kwargs["callback_on_step_end"] = callback_on_step_end
|
250 |
+
|
251 |
+
if KIND == "img2img":
|
252 |
+
kwargs["strength"] = denoising_strength
|
253 |
+
kwargs["image"] = prepare_image(image_prompt, (width, height))
|
254 |
+
|
255 |
+
if IP_ADAPTER:
|
256 |
+
# don't resize full-face images
|
257 |
+
size = None if ip_face else (width, height)
|
258 |
+
kwargs["ip_adapter_image"] = prepare_image(ip_image, size)
|
259 |
+
|
260 |
+
try:
|
261 |
+
image = pipe(**kwargs).images[0]
|
262 |
+
if scale > 1:
|
263 |
+
image = upscaler.predict(image)
|
264 |
+
images.append((image, str(current_seed)))
|
265 |
+
finally:
|
266 |
+
pipe.unload_textual_inversion()
|
267 |
+
torch.cuda.empty_cache()
|
268 |
+
torch.cuda.ipc_collect()
|
269 |
+
|
270 |
+
# increment seed for next image
|
271 |
+
current_seed += 1
|
272 |
+
|
273 |
+
diff = time.perf_counter() - start
|
274 |
+
if Info:
|
275 |
+
Info(f"Generated {len(images)} image{'s' if len(images) > 1 else ''} in {diff:.2f}s")
|
276 |
+
return images
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|