adamelliotfields committed on
Commit 80551a9
1 Parent(s): 083766b

Memory improvements

Files changed (3):
  1. lib/inference.py +6 -6
  2. lib/loader.py +104 -63
  3. lib/upscaler.py +4 -0
lib/inference.py CHANGED
@@ -1,3 +1,4 @@
+import gc
 import re
 import time
 from datetime import datetime
@@ -150,12 +151,7 @@ def generate(
 
     pipe = loader.pipe
     refiner = loader.refiner
-
-    upscaler = None
-    if scale == 2:
-        upscaler = loader.upscaler_2x
-    if scale == 4:
-        upscaler = loader.upscaler_4x
+    upscaler = loader.upscaler
 
     # prompt embeds for base and refiner
     compel_1 = Compel(
@@ -251,6 +247,10 @@ def generate(
        CURRENT_STEP = 0
        CURRENT_IMAGE += 1
 
+    # cleanup
+    loader.collect()
+    gc.collect()
+
    diff = time.perf_counter() - start
    if Info:
        Info(f"Generated {len(images)} image{'s' if len(images) > 1 else ''} in {diff:.2f}s")
lib/loader.py CHANGED
@@ -20,21 +20,25 @@ class Loader:
             cls._instance.pipe = None
             cls._instance.model = None
             cls._instance.refiner = None
-            cls._instance.upscaler_2x = None
-            cls._instance.upscaler_4x = None
+            cls._instance.upscaler = None
         return cls._instance
 
-    def _flush(self):
-        gc.collect()
-        torch.cuda.empty_cache()
-        torch.cuda.ipc_collect()
-        torch.cuda.reset_peak_memory_stats()
-        torch.cuda.synchronize()
+    def _should_offload_refiner(self, model=""):
+        if self.refiner is None:
+            return False
+        if self.model and self.model.lower() != model.lower():
+            return True
+        return False
 
-    def _should_unload_pipeline(self, model=""):
-        if self.pipe is None:
+    def _should_unload_refiner(self, refiner=False):
+        if self.refiner is None:
             return False
-        if self.model.lower() != model.lower():
+        if not refiner:
+            return True
+        return False
+
+    def _should_unload_upscaler(self, scale=1):
+        if self.upscaler is not None and self.upscaler.scale != scale:
             return True
         return False
 
@@ -46,31 +50,93 @@ class Loader:
             return True
         return False
 
-    def _unload_deepcache(self):
-        if self.pipe.deepcache is None:
-            return
-        print("Unloading DeepCache")
-        self.pipe.deepcache.disable()
-        delattr(self.pipe, "deepcache")
+    def _should_unload_pipeline(self, model=""):
+        if self.pipe is None:
+            return False
+        if self.model and self.model.lower() != model.lower():
+            return True
+        return False
+
+    def _offload_refiner(self):
         if self.refiner is not None:
-            if hasattr(self.refiner, "deepcache"):
-                print("Unloading DeepCache for refiner")
-                self.refiner.deepcache.disable()
-                delattr(self.refiner, "deepcache")
+            self.refiner.to("cpu", silence_dtype_warnings=True)
+            self.refiner.vae = None
+            self.refiner.scheduler = None
+            self.refiner.tokenizer_2 = None
+            self.refiner.text_encoder_2 = None
+
+    def _unload_refiner(self):
+        # already on CPU from offloading
+        print("Unloading refiner")
 
-    # don't unload refiner
-    def _unload(self, model, deepcache):
+    def _unload_upscaler(self):
+        print(f"Unloading {self.upscaler.scale}x upscaler")
+        self.upscaler.to("cpu")
+
+    def _unload_deepcache(self):
+        if self.pipe.deepcache is not None:
+            print("Unloading DeepCache")
+            self.pipe.deepcache.disable()
+            delattr(self.pipe, "deepcache")
+        if self.refiner is not None:
+            if hasattr(self.refiner, "deepcache"):
+                print("Unloading DeepCache for refiner")
+                self.refiner.deepcache.disable()
+                delattr(self.refiner, "deepcache")
+
+    def _unload_pipeline(self):
+        print(f"Unloading {self.model}")
+        self.pipe.to("cpu", silence_dtype_warnings=True)
+
+    def _unload(self, model, refiner, deepcache, scale):
         to_unload = []
-        if self._should_unload_deepcache(deepcache):
+        if self._should_unload_deepcache(deepcache):  # remove deepcache first
            self._unload_deepcache()
+
+        if self._should_offload_refiner(model):
+            self._offload_refiner()
+
+        if self._should_unload_refiner(refiner):
+            self._unload_refiner()
+            to_unload.append("refiner")
+
+        if self._should_unload_upscaler(scale):
+            self._unload_upscaler()
+            to_unload.append("upscaler")
+
        if self._should_unload_pipeline(model):
+            self._unload_pipeline()
            to_unload.append("model")
            to_unload.append("pipe")
-        for component in to_unload:
-            delattr(self, component)
-        self._flush()
+
+        self.collect()
        for component in to_unload:
            setattr(self, component, None)
+        gc.collect()
+
+    def _load_refiner(self, refiner, progress, **kwargs):
+        if refiner and self.refiner is None:
+            model = Config.REFINER_MODEL
+            pipeline = Config.PIPELINES["img2img"]
+            try:
+                print(f"Loading {model}")
+                self.refiner = pipeline.from_pretrained(model, **kwargs).to("cuda")
+            except Exception as e:
+                print(f"Error loading {model}: {e}")
+                self.refiner = None
+                return
+        if self.refiner is not None:
+            self.refiner.set_progress_bar_config(disable=progress is not None)
+
+    def _load_upscaler(self, scale=1):
+        if self.upscaler is None and scale > 1:
+            try:
+                print(f"Loading {scale}x upscaler")
+                self.upscaler = RealESRGAN(scale, device=self.pipe.device)
+                self.upscaler.load_weights()
+            except Exception as e:
+                print(f"Error loading {scale}x upscaler: {e}")
+                self.upscaler = None
 
     def _load_deepcache(self, interval=1):
         pipe_has_deepcache = hasattr(self.pipe, "deepcache")
@@ -98,7 +164,7 @@ class Loader:
        pipeline = Config.PIPELINES[kind]
        if self.pipe is None:
            try:
-                print(f"Loading {model}...")
+                print(f"Loading {model}")
                self.model = model
                if model.lower() in Config.MODEL_CHECKPOINTS.keys():
                    self.pipe = pipeline.from_single_file(
@@ -112,6 +178,7 @@ class Loader:
                    self.refiner.scheduler = self.pipe.scheduler
                    self.refiner.tokenizer_2 = self.pipe.tokenizer_2
                    self.refiner.text_encoder_2 = self.pipe.text_encoder_2
+                    self.refiner.to(self.pipe.device)
            except Exception as e:
                print(f"Error loading {model}: {e}")
                self.model = None
@@ -122,37 +189,11 @@ class Loader:
        if self.pipe is not None:
            self.pipe.set_progress_bar_config(disable=progress is not None)
 
-    def _load_refiner(self, refiner, progress, **kwargs):
-        if refiner and self.refiner is None:
-            model = Config.REFINER_MODEL
-            pipeline = Config.PIPELINES["img2img"]
-            try:
-                print(f"Loading {model}...")
-                self.refiner = pipeline.from_pretrained(model, **kwargs).to("cuda")
-            except Exception as e:
-                print(f"Error loading {model}: {e}")
-                self.refiner = None
-                return
-        if self.refiner is not None:
-            self.refiner.set_progress_bar_config(disable=progress is not None)
-
-    def _load_upscaler(self, scale=1):
-        if scale == 2 and self.upscaler_2x is None:
-            try:
-                print("Loading 2x upscaler...")
-                self.upscaler_2x = RealESRGAN(2, "cuda")
-                self.upscaler_2x.load_weights()
-            except Exception as e:
-                print(f"Error loading 2x upscaler: {e}")
-                self.upscaler_2x = None
-        if scale == 4 and self.upscaler_4x is None:
-            try:
-                print("Loading 4x upscaler...")
-                self.upscaler_4x = RealESRGAN(4, "cuda")
-                self.upscaler_4x.load_weights()
-            except Exception as e:
-                print(f"Error loading 4x upscaler: {e}")
-                self.upscaler_4x = None
+    def collect(self):
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+        torch.cuda.reset_peak_memory_stats()
+        torch.cuda.synchronize()
 
     def load(self, kind, model, scheduler, deepcache, scale, karras, refiner, progress):
         scheduler_kwargs = {
@@ -185,7 +226,7 @@ class Loader:
            "vae": AutoencoderKL.from_pretrained(Config.VAE_MODEL, torch_dtype=dtype),
        }
 
-        self._unload(model, deepcache)
+        self._unload(model, refiner, deepcache, scale)
        self._load_pipeline(kind, model, progress, **pipe_kwargs)
 
        # error loading model
@@ -201,9 +242,9 @@ class Loader:
        # same model, different scheduler
        if self.model.lower() == model.lower():
            if not same_scheduler:
-                print(f"Switching to {scheduler}...")
+                print(f"Switching to {scheduler}")
            if not same_karras:
-                print(f"{'Enabling' if karras else 'Disabling'} Karras sigmas...")
+                print(f"{'Enabling' if karras else 'Disabling'} Karras sigmas")
            if not same_scheduler or not same_karras:
                self.pipe.scheduler = Config.SCHEDULERS[scheduler](**scheduler_kwargs)
            if self.refiner is not None:
@@ -222,6 +263,6 @@ class Loader:
            "text_encoder_2": self.pipe.text_encoder_2,
        }
 
-        self._load_refiner(refiner, progress, **refiner_kwargs)
+        self._load_refiner(refiner, progress, **refiner_kwargs)  # load refiner before deepcache
        self._load_deepcache(deepcache)
        self._load_upscaler(scale)
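
Note: the refactor separates two levels of teardown — offloading a component (move its weights to CPU so VRAM is freed but the object survives for quick reuse) and unloading it (drop the reference entirely so gc and the CUDA allocator can reclaim everything). A minimal sketch of that distinction, assuming only PyTorch; TinyLoader is a hypothetical stand-in for the Loader singleton, not code from this repo:

import gc

import torch

class TinyLoader:
    # Hypothetical stand-in for Loader, reduced to the offload/unload split.
    def __init__(self):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = torch.nn.Linear(256, 256).to(device)

    def offload(self):
        # Frees VRAM but keeps the weights in system RAM; a later
        # .to("cuda") restores them without re-reading checkpoints.
        if self.model is not None:
            self.model.to("cpu")

    def unload(self):
        self.model = None  # drop the reference...
        gc.collect()       # ...collect Python-side garbage...
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # ...then release cached CUDA blocks

loader = TinyLoader()
loader.offload()  # cheap: weights parked in RAM, VRAM freed
loader.unload()   # full teardown: all memory reclaimed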
lib/upscaler.py CHANGED
@@ -264,6 +264,10 @@ class RealESRGAN:
            scale=scale,
        )
 
+    def to(self, device):
+        self.device = device
+        self.model.to(device=device)
+
    def load_weights(self):
        assert self.scale in [2, 4], "You can download models only with scales: 2, 4"
        config = HF_MODELS[self.scale]
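
Note: the new to() method mirrors torch.nn.Module.to, which is what lets Loader._unload_upscaler park the upscaler on CPU instead of deleting it and re-downloading weights later. A hedged usage sketch; the predict call is assumed from the upstream Real-ESRGAN wrapper this module is based on and may differ here:

from lib.upscaler import RealESRGAN

upscaler = RealESRGAN(2, device="cuda")
upscaler.load_weights()
# result = upscaler.predict(image)  # assumed inference API, see caveat above
upscaler.to("cpu")   # park on CPU: VRAM freed, weights kept in RAM
upscaler.to("cuda")  # move back when needed, with no re-download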