adamelliotfields committed • 767128b • 1 parent: 6681256

Loading and inferencing improvements

Files changed:
- lib/inference.py +1 -4
- lib/loader.py +23 -15
- lib/upscaler.py +1 -1
lib/inference.py
CHANGED
@@ -21,8 +21,8 @@ from typing_extensions import ParamSpec
 
 from .loader import Loader
 
-__import__("warnings").filterwarnings("ignore", category=FutureWarning, module="transformers")
 __import__("transformers").logging.set_verbosity_error()
+__import__("warnings").filterwarnings("ignore", category=FutureWarning, module="transformers")
 
 T = TypeVar("T")
 P = ParamSpec("P")
@@ -140,8 +140,6 @@ def generate(
     if seed is None or seed < 0:
         seed = int(datetime.now().timestamp() * 1_000_000) % (2**64)
 
-    DEVICE = torch.device("cuda")
-
     EMBEDDINGS_TYPE = (
         ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED
         if clip_skip
@@ -191,7 +189,6 @@ def generate(
         deepcache,
         scale,
         TQDM,
-        DEVICE,
     )
 
     # load embeddings and append to negative prompt
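
Note: with the hard-coded DEVICE = torch.device("cuda") removed from generate(), device resolution moves into the loader (see the lib/loader.py hunks below). A minimal sketch of that fallback pattern, assuming only that CUDA may be absent at runtime:

    import torch

    # Prefer CUDA when available; fall back to CPU so the module still
    # imports and runs on GPU-less machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")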
lib/loader.py
CHANGED
@@ -1,5 +1,6 @@
 import gc
 from threading import Lock
+from warnings import filterwarnings
 
 import torch
 from DeepCache import DeepCacheSDHelper
@@ -11,9 +12,9 @@ from torch._dynamo import OptimizedModule
 from .config import Config
 from .upscaler import RealESRGAN
 
-__import__("warnings").filterwarnings("ignore", category=FutureWarning, module="diffusers")
-__import__("warnings").filterwarnings("ignore", category=FutureWarning, module="torch")
 __import__("diffusers").logging.set_verbosity_error()
+filterwarnings("ignore", category=FutureWarning, module="torch")
+filterwarnings("ignore", category=FutureWarning, module="diffusers")
 
 
 class Loader:
@@ -69,6 +70,14 @@ class Loader:
         )
         self.pipe.unet.set_attn_processor(attn_procs)
 
+    def _flush(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+        torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+        torch.cuda.synchronize()
+
     def _unload(self, kind="", model="", ip_adapter="", scale=1):
         to_unload = []
 
@@ -86,11 +95,7 @@ class Loader:
         for component in to_unload:
             delattr(self, component)
 
-        gc.collect()
-        torch.cuda.empty_cache()
-        torch.cuda.ipc_collect()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
+        self._flush()
 
         for component in to_unload:
             setattr(self, component, None)
@@ -107,10 +112,10 @@ class Loader:
             self.pipe.set_ip_adapter_scale(0.5)
         self.ip_adapter = ip_adapter
 
-    def _load_upscaler(self, device, scale=1):
+    def _load_upscaler(self, scale=1, device=None):
         if scale > 1 and self.upscaler is None:
             print(f"Loading {scale}x upscaler...")
-            self.upscaler = RealESRGAN(device, scale)
+            self.upscaler = RealESRGAN(scale, device)
             self.upscaler.load_weights()
 
     def _load_pipeline(self, kind, model, tqdm, device, **kwargs):
@@ -207,8 +212,9 @@ class Loader:
         deepcache,
         scale,
         tqdm,
-        device,
     ):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
         scheduler_kwargs = {
             "beta_schedule": "scaled_linear",
             "timestep_spacing": "leading",
@@ -237,20 +243,22 @@ class Loader:
         else:
             pipe_kwargs["variant"] = None
 
-        # convert fp32 to bf16
+        # convert fp32 to bf16 if possible
         if model.lower() in ["linaqruf/anything-v3-1"]:
             pipe_kwargs["torch_dtype"] = (
                 torch.bfloat16
                 if torch.cuda.get_device_properties(device).major >= 8
                 else torch.float16
             )
+        else:
+            pipe_kwargs["torch_dtype"] = torch.float16
 
         self._unload(kind, model, ip_adapter, scale)
         self._load_pipeline(kind, model, tqdm, device, **pipe_kwargs)
 
         # error loading model
         if self.pipe is None:
-            return
+            return None, None
 
         same_scheduler = isinstance(self.pipe.scheduler, Config.SCHEDULERS[scheduler])
         same_karras = (
@@ -267,9 +275,9 @@ class Loader:
         if not same_scheduler or not same_karras:
             self.pipe.scheduler = Config.SCHEDULERS[scheduler](**scheduler_kwargs)
 
-        self._load_upscaler(device, scale)
-        self._load_ip_adapter(ip_adapter)
-        self._load_vae(taesd, model)
         self._load_freeu(freeu)
+        self._load_vae(taesd, model)
         self._load_deepcache(deepcache)
+        self._load_ip_adapter(ip_adapter)
+        self._load_upscaler(scale, device)
         return self.pipe, self.upscaler
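
Note: the new _flush() helper centralizes the CUDA cleanup that _unload() previously inlined between the two to_unload loops. A standalone sketch of the same pattern, with an added availability guard so it is safe on CPU-only hosts (in current PyTorch, reset_max_memory_allocated simply calls reset_peak_memory_stats, so the sketch keeps only the latter):

    import gc
    import torch

    def flush():
        # Drop Python-level references first so tensors become collectible,
        # then return cached blocks to the driver and reset the peak-memory
        # counters used for profiling.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()

The dtype branch keys off CUDA compute capability: bfloat16 has hardware support from capability 8.0 (Ampere) onward, so older GPUs fall back to float16. A generalization of that check as a helper (pick_half_dtype is an illustrative name, not part of this repo):

    import torch

    def pick_half_dtype(device="cuda"):
        # Mirrors the loader's get_device_properties(device).major >= 8 test:
        # Ampere (8.x) and newer run bfloat16 natively; older GPUs use float16.
        major, _minor = torch.cuda.get_device_capability(device)
        return torch.bfloat16 if major >= 8 else torch.float16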
lib/upscaler.py
CHANGED
@@ -254,7 +254,7 @@ class RRDBNet(nn.Module):
 
 
 class RealESRGAN:
-    def __init__(self, device, scale=2):
+    def __init__(self, scale=2, device=None):
         self.device = device
         self.scale = scale
         self.model = RRDBNet(
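
Note: the constructor now takes scale first and gives both parameters defaults, matching the loader's updated RealESRGAN(scale, device) call. A usage sketch against the new signature:

    import torch
    from lib.upscaler import RealESRGAN

    # Positional order is now (scale, device); both have defaults.
    upscaler = RealESRGAN(scale=2, device=torch.device("cuda"))
    upscaler.load_weights()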