Spaces:
Running
on
Zero
Running
on
Zero
Fix deepcache loading
Browse files- lib/loader.py +23 -7
lib/loader.py
CHANGED
@@ -28,6 +28,14 @@ class Loader:
|
|
28 |
cls._instance.log = Logger("Loader")
|
29 |
return cls._instance
|
30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
def _should_unload_ip_adapter(self, model="", ip_adapter=""):
|
32 |
# unload if model changed
|
33 |
if self.model and self.model.lower() != model.lower():
|
@@ -47,6 +55,13 @@ class Loader:
|
|
47 |
return True # img2img -> txt2img
|
48 |
return False
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
# https://github.com/huggingface/diffusers/blob/v0.28.0/src/diffusers/loaders/ip_adapter.py#L300
|
51 |
def _unload_ip_adapter(self):
|
52 |
if self.ip_adapter is None:
|
@@ -79,8 +94,10 @@ class Loader:
|
|
79 |
torch.cuda.reset_peak_memory_stats()
|
80 |
torch.cuda.synchronize()
|
81 |
|
82 |
-
def _unload(self, kind="", model="", ip_adapter=""):
|
83 |
to_unload = []
|
|
|
|
|
84 |
if self._should_unload_ip_adapter(model, ip_adapter):
|
85 |
self._unload_ip_adapter()
|
86 |
to_unload.append("ip_adapter")
|
@@ -178,13 +195,12 @@ class Loader:
|
|
178 |
|
179 |
def _load_deepcache(self, interval=1):
|
180 |
has_deepcache = hasattr(self.pipe, "deepcache")
|
|
|
|
|
181 |
if has_deepcache and self.pipe.deepcache.params["cache_interval"] == interval:
|
182 |
return
|
183 |
-
|
184 |
-
|
185 |
-
else:
|
186 |
-
self.log.info("Loading DeepCache")
|
187 |
-
self.pipe.deepcache = DeepCacheSDHelper(pipe=self.pipe)
|
188 |
self.pipe.deepcache.set_params(cache_interval=interval)
|
189 |
self.pipe.deepcache.enable()
|
190 |
|
@@ -254,7 +270,7 @@ class Loader:
|
|
254 |
# defaults to float32
|
255 |
pipe_kwargs["torch_dtype"] = torch.float16
|
256 |
|
257 |
-
self._unload(kind, model, ip_adapter)
|
258 |
self._load_pipeline(kind, model, tqdm, **pipe_kwargs)
|
259 |
|
260 |
# error loading model
|
|
|
28 |
cls._instance.log = Logger("Loader")
|
29 |
return cls._instance
|
30 |
|
31 |
+
def _should_unload_deepcache(self, interval=1):
|
32 |
+
has_deepcache = hasattr(self.pipe, "deepcache")
|
33 |
+
if has_deepcache and interval == 1:
|
34 |
+
return True
|
35 |
+
if has_deepcache and self.pipe.deepcache.params["cache_interval"] != interval:
|
36 |
+
return True
|
37 |
+
return False
|
38 |
+
|
39 |
def _should_unload_ip_adapter(self, model="", ip_adapter=""):
|
40 |
# unload if model changed
|
41 |
if self.model and self.model.lower() != model.lower():
|
|
|
55 |
return True # img2img -> txt2img
|
56 |
return False
|
57 |
|
58 |
+
def _unload_deepcache(self):
|
59 |
+
if self.pipe.deepcache is None:
|
60 |
+
return
|
61 |
+
self.log.info("Unloading DeepCache")
|
62 |
+
self.pipe.deepcache.disable()
|
63 |
+
delattr(self.pipe, "deepcache")
|
64 |
+
|
65 |
# https://github.com/huggingface/diffusers/blob/v0.28.0/src/diffusers/loaders/ip_adapter.py#L300
|
66 |
def _unload_ip_adapter(self):
|
67 |
if self.ip_adapter is None:
|
|
|
94 |
torch.cuda.reset_peak_memory_stats()
|
95 |
torch.cuda.synchronize()
|
96 |
|
97 |
+
def _unload(self, kind="", model="", ip_adapter="", deepcache=1):
|
98 |
to_unload = []
|
99 |
+
if self._should_unload_deepcache(deepcache):
|
100 |
+
self._unload_deepcache()
|
101 |
if self._should_unload_ip_adapter(model, ip_adapter):
|
102 |
self._unload_ip_adapter()
|
103 |
to_unload.append("ip_adapter")
|
|
|
195 |
|
196 |
def _load_deepcache(self, interval=1):
|
197 |
has_deepcache = hasattr(self.pipe, "deepcache")
|
198 |
+
if not has_deepcache and interval == 1:
|
199 |
+
return
|
200 |
if has_deepcache and self.pipe.deepcache.params["cache_interval"] == interval:
|
201 |
return
|
202 |
+
self.log.info("Loading DeepCache")
|
203 |
+
self.pipe.deepcache = DeepCacheSDHelper(self.pipe)
|
|
|
|
|
|
|
204 |
self.pipe.deepcache.set_params(cache_interval=interval)
|
205 |
self.pipe.deepcache.enable()
|
206 |
|
|
|
270 |
# defaults to float32
|
271 |
pipe_kwargs["torch_dtype"] = torch.float16
|
272 |
|
273 |
+
self._unload(kind, model, ip_adapter, deepcache)
|
274 |
self._load_pipeline(kind, model, tqdm, **pipe_kwargs)
|
275 |
|
276 |
# error loading model
|