adamelliotfields commited on
Commit
0acf94b
·
verified ·
1 Parent(s): 035c0aa

Fix deepcache loading

Browse files
Files changed (1) hide show
  1. lib/loader.py +23 -7
lib/loader.py CHANGED
@@ -28,6 +28,14 @@ class Loader:
28
  cls._instance.log = Logger("Loader")
29
  return cls._instance
30
 
 
 
 
 
 
 
 
 
31
  def _should_unload_ip_adapter(self, model="", ip_adapter=""):
32
  # unload if model changed
33
  if self.model and self.model.lower() != model.lower():
@@ -47,6 +55,13 @@ class Loader:
47
  return True # img2img -> txt2img
48
  return False
49
 
 
 
 
 
 
 
 
50
  # https://github.com/huggingface/diffusers/blob/v0.28.0/src/diffusers/loaders/ip_adapter.py#L300
51
  def _unload_ip_adapter(self):
52
  if self.ip_adapter is None:
@@ -79,8 +94,10 @@ class Loader:
79
  torch.cuda.reset_peak_memory_stats()
80
  torch.cuda.synchronize()
81
 
82
- def _unload(self, kind="", model="", ip_adapter=""):
83
  to_unload = []
 
 
84
  if self._should_unload_ip_adapter(model, ip_adapter):
85
  self._unload_ip_adapter()
86
  to_unload.append("ip_adapter")
@@ -178,13 +195,12 @@ class Loader:
178
 
179
  def _load_deepcache(self, interval=1):
180
  has_deepcache = hasattr(self.pipe, "deepcache")
 
 
181
  if has_deepcache and self.pipe.deepcache.params["cache_interval"] == interval:
182
  return
183
- if has_deepcache:
184
- self.pipe.deepcache.disable()
185
- else:
186
- self.log.info("Loading DeepCache")
187
- self.pipe.deepcache = DeepCacheSDHelper(pipe=self.pipe)
188
  self.pipe.deepcache.set_params(cache_interval=interval)
189
  self.pipe.deepcache.enable()
190
 
@@ -254,7 +270,7 @@ class Loader:
254
  # defaults to float32
255
  pipe_kwargs["torch_dtype"] = torch.float16
256
 
257
- self._unload(kind, model, ip_adapter)
258
  self._load_pipeline(kind, model, tqdm, **pipe_kwargs)
259
 
260
  # error loading model
 
28
  cls._instance.log = Logger("Loader")
29
  return cls._instance
30
 
31
+ def _should_unload_deepcache(self, interval=1):
32
+ has_deepcache = hasattr(self.pipe, "deepcache")
33
+ if has_deepcache and interval == 1:
34
+ return True
35
+ if has_deepcache and self.pipe.deepcache.params["cache_interval"] != interval:
36
+ return True
37
+ return False
38
+
39
  def _should_unload_ip_adapter(self, model="", ip_adapter=""):
40
  # unload if model changed
41
  if self.model and self.model.lower() != model.lower():
 
55
  return True # img2img -> txt2img
56
  return False
57
 
58
+ def _unload_deepcache(self):
59
+ if self.pipe.deepcache is None:
60
+ return
61
+ self.log.info("Unloading DeepCache")
62
+ self.pipe.deepcache.disable()
63
+ delattr(self.pipe, "deepcache")
64
+
65
  # https://github.com/huggingface/diffusers/blob/v0.28.0/src/diffusers/loaders/ip_adapter.py#L300
66
  def _unload_ip_adapter(self):
67
  if self.ip_adapter is None:
 
94
  torch.cuda.reset_peak_memory_stats()
95
  torch.cuda.synchronize()
96
 
97
+ def _unload(self, kind="", model="", ip_adapter="", deepcache=1):
98
  to_unload = []
99
+ if self._should_unload_deepcache(deepcache):
100
+ self._unload_deepcache()
101
  if self._should_unload_ip_adapter(model, ip_adapter):
102
  self._unload_ip_adapter()
103
  to_unload.append("ip_adapter")
 
195
 
196
  def _load_deepcache(self, interval=1):
197
  has_deepcache = hasattr(self.pipe, "deepcache")
198
+ if not has_deepcache and interval == 1:
199
+ return
200
  if has_deepcache and self.pipe.deepcache.params["cache_interval"] == interval:
201
  return
202
+ self.log.info("Loading DeepCache")
203
+ self.pipe.deepcache = DeepCacheSDHelper(self.pipe)
 
 
 
204
  self.pipe.deepcache.set_params(cache_interval=interval)
205
  self.pipe.deepcache.enable()
206
 
 
270
  # defaults to float32
271
  pipe_kwargs["torch_dtype"] = torch.float16
272
 
273
+ self._unload(kind, model, ip_adapter, deepcache)
274
  self._load_pipeline(kind, model, tqdm, **pipe_kwargs)
275
 
276
  # error loading model