adamelliotfields committed: Loader and inference improvements

Files changed:
- lib/inference.py  +58 -40
- lib/loader.py     +33 -32
lib/inference.py
CHANGED

@@ -21,8 +21,8 @@ from typing_extensions import ParamSpec
 
 from .loader import Loader
 
-__import__("transformers").logging.set_verbosity_error()
 __import__("warnings").filterwarnings("ignore", category=FutureWarning, module="transformers")
+__import__("transformers").logging.set_verbosity_error()
 
 T = TypeVar("T")
 P = ParamSpec("P")

@@ -45,17 +45,17 @@ async def async_call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T
     return await anyio.to_thread.run_sync(partial_fn)
 
 
-
-def parse_prompt(prompt: str) -> list[str]:
+def parse_prompt_with_arrays(prompt: str) -> list[str]:
     arrays = re.findall(r"\[\[(.*?)\]\]", prompt)
 
     if not arrays:
         return [prompt]
 
-    tokens = [item.split(",") for item in arrays]
-    combinations = list(product(*tokens))
-    prompts = []
+    tokens = [item.split(",") for item in arrays]  # [("a", "b"), ("1", "2")]
+    combinations = list(product(*tokens))  # [("a", "1"), ("a", "2"), ("b", "1"), ("b", "2")]
 
+    # find all the arrays in the prompt and replace them with tokens
+    prompts = []
     for combo in combinations:
         current_prompt = prompt
         for i, token in enumerate(combo):

@@ -71,8 +71,12 @@ def apply_style(prompt, style_id, negative=False):
     for style in STYLES:
         if style["id"] == style_id:
             if negative:
-                return
+                return (
+                    # prepend our negative prompt to the style's negative prompt
+                    f"{prompt}, {style['negative_prompt']}" if prompt else style["negative_prompt"]
+                )
             else:
+                # inject our positive prompt into the style prompt
                 return style["prompt"].format(prompt=prompt)
     return prompt
 

@@ -97,12 +101,18 @@ def prepare_image(input, size=None):
 
 
 def gpu_duration(**kwargs):
-
+    loading = 20
+    duration = 10
+    width = kwargs.get("width", 512)
+    height = kwargs.get("height", 512)
     scale = kwargs.get("scale", 1)
     num_images = kwargs.get("num_images", 1)
+    size = width * height
+    if size > 500_000:
+        duration += 5
     if scale == 4:
         duration += 5
-    return duration * num_images
+    return loading + (duration * num_images)
 
 
 @spaces.GPU(duration=gpu_duration)

@@ -116,7 +126,7 @@ def generate(
     style=None,
     seed=None,
     model="Lykon/dreamshaper-8",
-    scheduler="
+    scheduler="DDIM",
     width=512,
     height=512,
     guidance_scale=7.5,

@@ -140,16 +150,17 @@ def generate(
     if seed is None or seed < 0:
         seed = int(datetime.now().timestamp() * 1_000_000) % (2**64)
 
+    CURRENT_STEP = 0
+    CURRENT_IMAGE = 1
+
+    KIND = "img2img" if image_prompt is not None else "txt2img"
+
     EMBEDDINGS_TYPE = (
         ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED
         if clip_skip
         else ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED
     )
 
-    KIND = "img2img" if image_prompt is not None else "txt2img"
-
-    CURRENT_IMAGE = 1
-
     if ip_image:
         IP_ADAPTER = "full-face" if ip_face else "plus"
     else:

@@ -162,23 +173,22 @@ def generate(
     TQDM = True
 
     def callback_on_step_end(pipeline, step, timestep, latents):
-        nonlocal CURRENT_IMAGE
+        nonlocal CURRENT_STEP, CURRENT_IMAGE
         if progress is None:
             return latents
         strength = denoising_strength if KIND == "img2img" else 1
         total_steps = min(int(inference_steps * strength), inference_steps)
-
+
+        CURRENT_STEP = step + 1
         progress(
-            (
+            (CURRENT_STEP, total_steps),
             desc=f"Generating image {CURRENT_IMAGE}/{num_images}",
         )
-        if current_step == total_steps:
-            CURRENT_IMAGE += 1
         return latents
 
     start = time.perf_counter()
     loader = Loader()
-
+    loader.load(
         KIND,
         IP_ADAPTER,
         model,

@@ -191,6 +201,17 @@ def generate(
         TQDM,
     )
 
+    if loader.pipe is None:
+        raise Error(f"RuntimeError: Error loading {model}")
+
+    pipe = loader.pipe
+    upscaler = None
+
+    if scale == 2:
+        upscaler = loader.upscaler_2x
+    if scale == 4:
+        upscaler = loader.upscaler_4x
+
     # load embeddings and append to negative prompt
     embeddings_dir = os.path.join(os.path.dirname(__file__), "..", "embeddings")
     embeddings_dir = os.path.abspath(embeddings_dir)

@@ -201,11 +222,8 @@ def generate(
                 pretrained_model_name_or_path=f"{embeddings_dir}/{embedding}.pt",
                 token=f"<{embedding}>",
             )
-            # boost embeddings slightly
             negative_prompt = (
-                f"{negative_prompt}, 
-                if negative_prompt
-                else f"(<{embedding}>)1.1"
+                f"{negative_prompt}, <{embedding}>" if negative_prompt else f"<{embedding}>"
             )
         except (EnvironmentError, HFValidationError, RepositoryNotFoundError):
             raise Error(f"Invalid embedding: <{embedding}>")

@@ -225,33 +243,33 @@ def generate(
 
     try:
         styled_negative_prompt = apply_style(negative_prompt, style, negative=True)
-
+        negative_embeds = compel(styled_negative_prompt)
     except PromptParser.ParsingException:
-        raise Error("
+        raise Error("ValueError: Invalid negative prompt")
 
     for i in range(num_images):
         # seeded generator for each iteration
         generator = torch.Generator(device=pipe.device).manual_seed(current_seed)
 
         try:
-            all_positive_prompts = 
+            all_positive_prompts = parse_prompt_with_arrays(positive_prompt)
             prompt_index = i % len(all_positive_prompts)
-
-
-
-
-            [
+            prompt = all_positive_prompts[prompt_index]
+            prompt = apply_style(prompt, style)
+            positive_embeds = compel(prompt)
+            positive_embeds, negative_embeds = compel.pad_conditioning_tensors_to_same_length(
+                [positive_embeds, negative_embeds]
             )
         except PromptParser.ParsingException:
-            raise Error("
+            raise Error("ValueError: Invalid prompt")
 
         kwargs = {
             "width": width,
             "height": height,
             "generator": generator,
-            "prompt_embeds": 
+            "prompt_embeds": positive_embeds,
             "guidance_scale": guidance_scale,
-            "negative_prompt_embeds": 
+            "negative_prompt_embeds": negative_embeds,
             "num_inference_steps": inference_steps,
             "output_type": "np" if scale > 1 else "pil",
         }

@@ -273,13 +291,13 @@ def generate(
             if scale > 1:
                 image = upscaler.predict(image)
             images.append((image, str(current_seed)))
+            current_seed += 1
+        except Exception as e:
+            raise Error(f"RuntimeError: {e}")
         finally:
             pipe.unload_textual_inversion()
-
-
-
-            # increment seed for next image
-            current_seed += 1
+            CURRENT_STEP = 0
+            CURRENT_IMAGE += 1
 
     diff = time.perf_counter() - start
     if Info:
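Note on the inference changes above: prompts can now contain [[a,b]]-style arrays that parse_prompt_with_arrays expands into every combination, and progress is reported per step through CURRENT_STEP/CURRENT_IMAGE. Below is a minimal, self-contained sketch of the array expansion; the hunk does not show the loop body that substitutes each combination back into the prompt, so the re.sub step here is an assumption rather than the Space's exact code.

import re
from itertools import product

def parse_prompt_with_arrays(prompt: str) -> list[str]:
    # find [[...]] groups and split each into its comma-separated tokens
    arrays = re.findall(r"\[\[(.*?)\]\]", prompt)
    if not arrays:
        return [prompt]
    tokens = [item.split(",") for item in arrays]   # [["a", "b"], ["1", "2"]]
    combinations = list(product(*tokens))           # [("a", "1"), ("a", "2"), ...]
    prompts = []
    for combo in combinations:
        current_prompt = prompt
        for token in combo:
            # assumed substitution: replace the next [[...]] with this combination's token
            current_prompt = re.sub(r"\[\[.*?\]\]", token.strip(), current_prompt, count=1)
        prompts.append(current_prompt)
    return prompts

print(parse_prompt_with_arrays("a [[red,blue]] car at [[dawn,dusk]]"))
# ['a red car at dawn', 'a red car at dusk', 'a blue car at dawn', 'a blue car at dusk']

For the new gpu_duration heuristic, a request for two 1024x1024 images at 4x upscale would reserve 20 + (10 + 5 + 5) * 2 = 60 seconds of ZeroGPU time: 1024 * 1024 = 1,048,576 exceeds 500,000 (+5), and scale == 4 adds another 5.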
lib/loader.py
CHANGED

@@ -27,13 +27,11 @@ class Loader:
             cls._instance = super().__new__(cls)
             cls._instance.pipe = None
             cls._instance.model = None
-            cls._instance.upscaler = None
             cls._instance.ip_adapter = None
+            cls._instance.upscaler_2x = None
+            cls._instance.upscaler_4x = None
         return cls._instance
 
-    def _should_unload_upscaler(self, scale=1):
-        return self.upscaler is not None and scale == 1
-
     def _should_unload_ip_adapter(self, ip_adapter=""):
         return self.ip_adapter is not None and not ip_adapter
 

@@ -78,25 +76,17 @@ class Loader:
         torch.cuda.reset_peak_memory_stats()
         torch.cuda.synchronize()
 
-    def _unload(self, kind="", model="", ip_adapter=""
+    def _unload(self, kind="", model="", ip_adapter=""):
         to_unload = []
-
-        if self._should_unload_upscaler(scale):
-            to_unload.append("upscaler")
-
         if self._should_unload_ip_adapter(ip_adapter):
             self._unload_ip_adapter()
             to_unload.append("ip_adapter")
-
         if self._should_unload_pipeline(kind, model):
             to_unload.append("model")
             to_unload.append("pipe")
-
         for component in to_unload:
             delattr(self, component)
-
         self._flush()
-
         for component in to_unload:
             setattr(self, component, None)
 

@@ -112,35 +102,46 @@ class Loader:
         self.pipe.set_ip_adapter_scale(0.5)
         self.ip_adapter = ip_adapter
 
-    def _load_upscaler(self, scale=1
-        if scale
-
-
-
+    def _load_upscaler(self, scale=1):
+        if scale == 2 and self.upscaler_2x is None:
+            try:
+                print("Loading 2x upscaler...")
+                self.upscaler_2x = RealESRGAN(2, "cuda")
+                self.upscaler_2x.load_weights()
+            except Exception as e:
+                print(f"Error loading 2x upscaler: {e}")
+                self.upscaler_2x = None
+        if scale == 4 and self.upscaler_4x is None:
+            try:
+                print("Loading 4x upscaler...")
+                self.upscaler_4x = RealESRGAN(4, "cuda")
+                self.upscaler_4x.load_weights()
+            except Exception as e:
+                print(f"Error loading 4x upscaler: {e}")
+                self.upscaler_4x = None
 
-    def _load_pipeline(self, kind, model, tqdm,
+    def _load_pipeline(self, kind, model, tqdm, **kwargs):
         pipeline = Config.PIPELINES[kind]
         if self.pipe is None:
-            print(f"Loading {model}...")
             try:
+                print(f"Loading {model}...")
+                self.model = model
                 if model.lower() in Config.MODEL_CHECKPOINTS.keys():
                     self.pipe = pipeline.from_single_file(
                         f"https://huggingface.co/{model}/{Config.MODEL_CHECKPOINTS[model.lower()]}",
                         **kwargs,
-                    ).to(
+                    ).to("cuda")
                 else:
-                    self.pipe = pipeline.from_pretrained(model, **kwargs).to(
-                self.model = model
+                    self.pipe = pipeline.from_pretrained(model, **kwargs).to("cuda")
             except Exception as e:
                 print(f"Error loading {model}: {e}")
                 self.model = None
                 self.pipe = None
                 return
-
         if not isinstance(self.pipe, pipeline):
-            self.pipe = pipeline.from_pipe(self.pipe).to(
-
-
+            self.pipe = pipeline.from_pipe(self.pipe).to("cuda")
+        if self.pipe is not None:
+            self.pipe.set_progress_bar_config(disable=not tqdm)
 
     def _load_vae(self, taesd=False, model=""):
         vae_type = type(self.pipe.vae)

@@ -251,14 +252,15 @@ class Loader:
             else torch.float16
         )
         else:
+            # defaults to float32
            pipe_kwargs["torch_dtype"] = torch.float16
 
-        self._unload(kind, model, ip_adapter
-        self._load_pipeline(kind, model, tqdm,
+        self._unload(kind, model, ip_adapter)
+        self._load_pipeline(kind, model, tqdm, **pipe_kwargs)
 
         # error loading model
         if self.pipe is None:
-            return
+            return
 
         same_scheduler = isinstance(self.pipe.scheduler, Config.SCHEDULERS[scheduler])
         same_karras = (

@@ -279,5 +281,4 @@ class Loader:
         self._load_vae(taesd, model)
         self._load_deepcache(deepcache)
         self._load_ip_adapter(ip_adapter)
-        self._load_upscaler(scale
-        return self.pipe, self.upscaler
+        self._load_upscaler(scale)
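The loader now caches one upscaler per scale on the singleton (upscaler_2x, upscaler_4x) and loads each RealESRGAN model lazily the first time its scale is requested, instead of keeping a single upscaler slot that was unloaded whenever scale dropped back to 1. A minimal sketch of just that caching and selection logic, with make_upscaler as a hypothetical stand-in so the example runs without the RealESRGAN dependency:

def make_upscaler(scale: int):
    # hypothetical placeholder for RealESRGAN(scale, "cuda") plus load_weights()
    return f"upscaler-{scale}x"

class UpscalerCache:
    def __init__(self):
        self.upscaler_2x = None
        self.upscaler_4x = None

    def load(self, scale: int):
        # load each model at most once and keep it on the instance
        if scale == 2 and self.upscaler_2x is None:
            self.upscaler_2x = make_upscaler(2)
        if scale == 4 and self.upscaler_4x is None:
            self.upscaler_4x = make_upscaler(4)

    def get(self, scale: int):
        self.load(scale)
        if scale == 2:
            return self.upscaler_2x
        if scale == 4:
            return self.upscaler_4x
        return None  # scale == 1 means no upscaling

cache = UpscalerCache()
assert cache.get(4) == "upscaler-4x"
assert cache.get(1) is None

Because both scales stay cached on the singleton, switching between 2x and 4x upscaling should no longer force a reload, and generate() now reads loader.upscaler_2x or loader.upscaler_4x based on scale instead of the old single loader.upscaler attribute.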