adamelliotfields committed on
Commit
039ff6d
1 Parent(s): 31ff262

Memory improvements

Browse files
Files changed (3) hide show
  1. lib/inference.py +8 -8
  2. lib/loader.py +104 -84
  3. lib/upscaler.py +6 -2
lib/inference.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import re
3
  import time
@@ -173,11 +174,11 @@ def generate(
173
  IP_ADAPTER,
174
  model,
175
  scheduler,
 
 
176
  karras,
177
  taesd,
178
  freeu,
179
- deepcache,
180
- scale,
181
  progress,
182
  )
183
 
@@ -185,12 +186,7 @@ def generate(
185
  raise Error(f"Error loading {model}")
186
 
187
  pipe = loader.pipe
188
- upscaler = None
189
-
190
- if scale == 2:
191
- upscaler = loader.upscaler_2x
192
- if scale == 4:
193
- upscaler = loader.upscaler_4x
194
 
195
  # load loras
196
  loras = []
@@ -311,6 +307,10 @@ def generate(
311
  CURRENT_STEP = 0
312
  CURRENT_IMAGE += 1
313
 
 
 
 
 
314
  diff = time.perf_counter() - start
315
  msg = f"Generating {len(images)} image{'s' if len(images) > 1 else ''} done in {diff:.2f}s"
316
  log.info(msg)
 
1
+ import gc
2
  import os
3
  import re
4
  import time
 
174
  IP_ADAPTER,
175
  model,
176
  scheduler,
177
+ deepcache,
178
+ scale,
179
  karras,
180
  taesd,
181
  freeu,
 
 
182
  progress,
183
  )
184
 
 
186
  raise Error(f"Error loading {model}")
187
 
188
  pipe = loader.pipe
189
+ upscaler = loader.upscaler
 
 
 
 
 
190
 
191
  # load loras
192
  loras = []
 
307
  CURRENT_STEP = 0
308
  CURRENT_IMAGE += 1
309
 
310
+ # cleanup
311
+ loader.collect()
312
+ gc.collect()
313
+
314
  diff = time.perf_counter() - start
315
  msg = f"Generating {len(images)} image{'s' if len(images) > 1 else ''} done in {diff:.2f}s"
316
  log.info(msg)
lib/loader.py CHANGED
@@ -22,12 +22,16 @@ class Loader:
22
  cls._instance = super().__new__(cls)
23
  cls._instance.pipe = None
24
  cls._instance.model = None
 
25
  cls._instance.ip_adapter = None
26
- cls._instance.upscaler_2x = None
27
- cls._instance.upscaler_4x = None
28
  cls._instance.log = Logger("Loader")
29
  return cls._instance
30
 
 
 
 
 
 
31
  def _should_unload_deepcache(self, interval=1):
32
  has_deepcache = hasattr(self.pipe, "deepcache")
33
  if has_deepcache and interval == 1:
@@ -55,60 +59,112 @@ class Loader:
55
  return True # img2img -> txt2img
56
  return False
57
 
 
 
 
 
 
 
 
 
58
  def _unload_deepcache(self):
59
- if self.pipe.deepcache is None:
60
- return
61
- self.log.info("Unloading DeepCache")
62
- self.pipe.deepcache.disable()
63
- delattr(self.pipe, "deepcache")
64
 
65
  # https://github.com/huggingface/diffusers/blob/v0.28.0/src/diffusers/loaders/ip_adapter.py#L300
66
  def _unload_ip_adapter(self):
67
- if self.ip_adapter is None:
68
- return
69
-
70
- self.log.info("Unloading IP-Adapter")
71
- if not isinstance(self.pipe, Config.PIPELINES["img2img"]):
72
- self.pipe.image_encoder = None
73
- self.pipe.register_to_config(image_encoder=[None, None])
74
-
75
- self.pipe.feature_extractor = None
76
- self.pipe.unet.encoder_hid_proj = None
77
- self.pipe.unet.config.encoder_hid_dim_type = None
78
- self.pipe.register_to_config(feature_extractor=[None, None])
79
-
80
- attn_procs = {}
81
- for name, value in self.pipe.unet.attn_processors.items():
82
- attn_processor_class = AttnProcessor2_0() # raises if not torch 2
83
- attn_procs[name] = (
84
- attn_processor_class
85
- if isinstance(value, IPAdapterAttnProcessor2_0)
86
- else value.__class__()
87
- )
88
- self.pipe.unet.set_attn_processor(attn_procs)
89
-
90
- def _flush(self):
91
- gc.collect()
92
- torch.cuda.empty_cache()
93
- torch.cuda.ipc_collect()
94
- torch.cuda.reset_peak_memory_stats()
95
- torch.cuda.synchronize()
 
 
96
 
97
- def _unload(self, kind="", model="", ip_adapter="", deepcache=1):
98
  to_unload = []
99
- if self._should_unload_deepcache(deepcache):
100
  self._unload_deepcache()
 
 
 
 
 
101
  if self._should_unload_ip_adapter(model, ip_adapter):
102
  self._unload_ip_adapter()
103
  to_unload.append("ip_adapter")
 
104
  if self._should_unload_pipeline(kind, model):
 
105
  to_unload.append("model")
106
  to_unload.append("pipe")
107
- for component in to_unload:
108
- delattr(self, component)
109
- self._flush()
110
  for component in to_unload:
111
  setattr(self, component, None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  def _load_ip_adapter(self, ip_adapter=""):
114
  if not self.ip_adapter and ip_adapter:
@@ -122,25 +178,6 @@ class Loader:
122
  self.pipe.set_ip_adapter_scale(0.5)
123
  self.ip_adapter = ip_adapter
124
 
125
- # upscalers don't need to be unloaded
126
- def _load_upscaler(self, scale=1):
127
- if scale == 2 and self.upscaler_2x is None:
128
- try:
129
- self.log.info("Loading 2x upscaler")
130
- self.upscaler_2x = RealESRGAN(2, "cuda")
131
- self.upscaler_2x.load_weights()
132
- except Exception as e:
133
- self.log.error(f"Error loading 2x upscaler: {e}")
134
- self.upscaler_2x = None
135
- if scale == 4 and self.upscaler_4x is None:
136
- try:
137
- self.log.info("Loading 4x upscaler")
138
- self.upscaler_4x = RealESRGAN(4, "cuda")
139
- self.upscaler_4x.load_weights()
140
- except Exception as e:
141
- self.log.error(f"Error loading 4x upscaler: {e}")
142
- self.upscaler_4x = None
143
-
144
  def _load_pipeline(
145
  self,
146
  kind,
@@ -203,28 +240,11 @@ class Loader:
203
  variant="fp16",
204
  ).to(self.pipe.device)
205
 
206
- def _load_deepcache(self, interval=1):
207
- has_deepcache = hasattr(self.pipe, "deepcache")
208
- if not has_deepcache and interval == 1:
209
- return
210
- if has_deepcache and self.pipe.deepcache.params["cache_interval"] == interval:
211
- return
212
- self.log.info("Loading DeepCache")
213
- self.pipe.deepcache = DeepCacheSDHelper(self.pipe)
214
- self.pipe.deepcache.set_params(cache_interval=interval)
215
- self.pipe.deepcache.enable()
216
-
217
- # https://github.com/ChenyangSi/FreeU
218
- def _load_freeu(self, freeu=False):
219
- block = self.pipe.unet.up_blocks[0]
220
- attrs = ["b1", "b2", "s1", "s2"]
221
- has_freeu = all(getattr(block, attr, None) is not None for attr in attrs)
222
- if has_freeu and not freeu:
223
- self.log.info("Disabling FreeU")
224
- self.pipe.disable_freeu()
225
- elif not has_freeu and freeu:
226
- self.log.info("Enabling FreeU")
227
- self.pipe.enable_freeu(b1=1.5, b2=1.6, s1=0.9, s2=0.2)
228
 
229
  def load(
230
  self,
@@ -232,11 +252,11 @@ class Loader:
232
  ip_adapter,
233
  model,
234
  scheduler,
 
 
235
  karras,
236
  taesd,
237
  freeu,
238
- deepcache,
239
- scale,
240
  progress,
241
  ):
242
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
22
  cls._instance = super().__new__(cls)
23
  cls._instance.pipe = None
24
  cls._instance.model = None
25
+ cls._instance.upscaler = None
26
  cls._instance.ip_adapter = None
 
 
27
  cls._instance.log = Logger("Loader")
28
  return cls._instance
29
 
30
+ def _should_unload_upscaler(self, scale=1):
31
+ if self.upscaler is not None and self.upscaler.scale != scale:
32
+ return True
33
+ return False
34
+
35
  def _should_unload_deepcache(self, interval=1):
36
  has_deepcache = hasattr(self.pipe, "deepcache")
37
  if has_deepcache and interval == 1:
 
59
  return True # img2img -> txt2img
60
  return False
61
 
62
+ def _unload_upscaler(self):
63
+ if self.upscaler is not None:
64
+ start = time.perf_counter()
65
+ self.log.info(f"Unloading {self.upscaler.scale}x upscaler")
66
+ self.upscaler.to("cpu")
67
+ diff = time.perf_counter() - start
68
+ self.log.info(f"Unloading {self.upscaler.scale}x upscaler done in {diff:.2f}s")
69
+
70
  def _unload_deepcache(self):
71
+ if self.pipe.deepcache is not None:
72
+ self.log.info("Disabling DeepCache")
73
+ self.pipe.deepcache.disable()
74
+ delattr(self.pipe, "deepcache")
 
75
 
76
  # https://github.com/huggingface/diffusers/blob/v0.28.0/src/diffusers/loaders/ip_adapter.py#L300
77
  def _unload_ip_adapter(self):
78
+ if self.ip_adapter is not None:
79
+ start = time.perf_counter()
80
+ self.log.info("Unloading IP-Adapter")
81
+ if not isinstance(self.pipe, Config.PIPELINES["img2img"]):
82
+ self.pipe.image_encoder = None
83
+ self.pipe.register_to_config(image_encoder=[None, None])
84
+
85
+ self.pipe.feature_extractor = None
86
+ self.pipe.unet.encoder_hid_proj = None
87
+ self.pipe.unet.config.encoder_hid_dim_type = None
88
+ self.pipe.register_to_config(feature_extractor=[None, None])
89
+
90
+ attn_procs = {}
91
+ for name, value in self.pipe.unet.attn_processors.items():
92
+ attn_processor_class = AttnProcessor2_0() # raises if not torch 2
93
+ attn_procs[name] = (
94
+ attn_processor_class
95
+ if isinstance(value, IPAdapterAttnProcessor2_0)
96
+ else value.__class__()
97
+ )
98
+ self.pipe.unet.set_attn_processor(attn_procs)
99
+ diff = time.perf_counter() - start
100
+ self.log.info(f"Unloading IP-Adapter done in {diff:.2f}s")
101
+
102
+ def _unload_pipeline(self):
103
+ if self.pipe is not None:
104
+ start = time.perf_counter()
105
+ self.log.info(f"Unloading {self.model}")
106
+ self.pipe.to("cpu")
107
+ diff = time.perf_counter() - start
108
+ self.log.info(f"Unloading {self.model} done in {diff:.2f}s")
109
 
110
+ def _unload(self, kind="", model="", ip_adapter="", deepcache=1, scale=1):
111
  to_unload = []
112
+ if self._should_unload_deepcache(deepcache): # remove deepcache first
113
  self._unload_deepcache()
114
+
115
+ if self._should_unload_upscaler(scale):
116
+ self._unload_upscaler()
117
+ to_unload.append("upscaler")
118
+
119
  if self._should_unload_ip_adapter(model, ip_adapter):
120
  self._unload_ip_adapter()
121
  to_unload.append("ip_adapter")
122
+
123
  if self._should_unload_pipeline(kind, model):
124
+ self._unload_pipeline()
125
  to_unload.append("model")
126
  to_unload.append("pipe")
127
+
128
+ self.collect()
 
129
  for component in to_unload:
130
  setattr(self, component, None)
131
+ gc.collect()
132
+
133
+ def _load_upscaler(self, scale=1):
134
+ if self.upscaler is None and scale > 1:
135
+ try:
136
+ start = time.perf_counter()
137
+ self.log.info(f"Loading {scale}x upscaler")
138
+ self.upscaler = RealESRGAN(scale, device=self.pipe.device)
139
+ self.upscaler.load_weights()
140
+ diff = time.perf_counter() - start
141
+ self.log.info(f"Loading {scale}x upscaler done in {diff:.2f}s")
142
+ except Exception as e:
143
+ self.log.error(f"Error loading {scale}x upscaler: {e}")
144
+ self.upscaler = None
145
+
146
+ def _load_deepcache(self, interval=1):
147
+ has_deepcache = hasattr(self.pipe, "deepcache")
148
+ if not has_deepcache and interval == 1:
149
+ return
150
+ if has_deepcache and self.pipe.deepcache.params["cache_interval"] == interval:
151
+ return
152
+ self.log.info("Enabling DeepCache")
153
+ self.pipe.deepcache = DeepCacheSDHelper(self.pipe)
154
+ self.pipe.deepcache.set_params(cache_interval=interval)
155
+ self.pipe.deepcache.enable()
156
+
157
+ # https://github.com/ChenyangSi/FreeU
158
+ def _load_freeu(self, freeu=False):
159
+ block = self.pipe.unet.up_blocks[0]
160
+ attrs = ["b1", "b2", "s1", "s2"]
161
+ has_freeu = all(getattr(block, attr, None) is not None for attr in attrs)
162
+ if has_freeu and not freeu:
163
+ self.log.info("Disabling FreeU")
164
+ self.pipe.disable_freeu()
165
+ elif not has_freeu and freeu:
166
+ self.log.info("Enabling FreeU")
167
+ self.pipe.enable_freeu(b1=1.5, b2=1.6, s1=0.9, s2=0.2)
168
 
169
  def _load_ip_adapter(self, ip_adapter=""):
170
  if not self.ip_adapter and ip_adapter:
 
178
  self.pipe.set_ip_adapter_scale(0.5)
179
  self.ip_adapter = ip_adapter
180
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  def _load_pipeline(
182
  self,
183
  kind,
 
240
  variant="fp16",
241
  ).to(self.pipe.device)
242
 
243
+ def collect(self):
244
+ torch.cuda.empty_cache()
245
+ torch.cuda.ipc_collect()
246
+ torch.cuda.reset_peak_memory_stats()
247
+ torch.cuda.synchronize()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
 
249
  def load(
250
  self,
 
252
  ip_adapter,
253
  model,
254
  scheduler,
255
+ deepcache,
256
+ scale,
257
  karras,
258
  taesd,
259
  freeu,
 
 
260
  progress,
261
  ):
262
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lib/upscaler.py CHANGED
@@ -266,6 +266,10 @@ class RealESRGAN:
266
  scale=scale,
267
  )
268
 
 
 
 
 
269
  def load_weights(self):
270
  assert self.scale in [2, 4], "You can download models only with scales: 2, 4"
271
  config = HF_MODELS[self.scale]
@@ -279,9 +283,8 @@ class RealESRGAN:
279
  self.model.load_state_dict(loadnet, strict=True)
280
  self.model.eval().to(device=self.device)
281
 
282
- @torch.cuda.amp.autocast()
283
  def predict(self, lr_image, batch_size=4, patches_size=192, padding=24, pad_size=15):
284
- scale = self.scale
285
  if not isinstance(lr_image, np.ndarray):
286
  lr_image = np.array(lr_image)
287
  if lr_image.min() < 0.0:
@@ -302,6 +305,7 @@ class RealESRGAN:
302
  for i in range(batch_size, image.shape[0], batch_size):
303
  res = torch.cat((res, self.model(image[i : i + batch_size])), 0)
304
 
 
305
  sr_image = einops.rearrange(res.clamp(0, 1), "b c h w -> b h w c").cpu().numpy()
306
  padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
307
  scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,)
 
266
  scale=scale,
267
  )
268
 
269
+ def to(self, device):
270
+ self.device = device
271
+ self.model.to(device=device)
272
+
273
  def load_weights(self):
274
  assert self.scale in [2, 4], "You can download models only with scales: 2, 4"
275
  config = HF_MODELS[self.scale]
 
283
  self.model.load_state_dict(loadnet, strict=True)
284
  self.model.eval().to(device=self.device)
285
 
286
+ @torch.autocast("cuda")
287
  def predict(self, lr_image, batch_size=4, patches_size=192, padding=24, pad_size=15):
 
288
  if not isinstance(lr_image, np.ndarray):
289
  lr_image = np.array(lr_image)
290
  if lr_image.min() < 0.0:
 
305
  for i in range(batch_size, image.shape[0], batch_size):
306
  res = torch.cat((res, self.model(image[i : i + batch_size])), 0)
307
 
308
+ scale = self.scale
309
  sr_image = einops.rearrange(res.clamp(0, 1), "b c h w -> b h w c").cpu().numpy()
310
  padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
311
  scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,)