adamelliotfields commited on
Commit
67ca03a
1 Parent(s): 11ee0ea

Don't unload refiner and upscaler

Browse files
Files changed (2) hide show
  1. lib/inference.py +12 -4
  2. lib/loader.py +36 -25
lib/inference.py CHANGED
@@ -154,7 +154,7 @@ def generate(
154
 
155
  start = time.perf_counter()
156
  loader = Loader()
157
- pipe, refiner, upscaler = loader.load(
158
  KIND,
159
  model,
160
  scheduler,
@@ -165,6 +165,15 @@ def generate(
165
  TQDM,
166
  )
167
 
 
 
 
 
 
 
 
 
 
168
  # prompt embeds for base and refiner
169
  compel_1 = Compel(
170
  text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
@@ -204,7 +213,7 @@ def generate(
204
  # refiner expects latents; upscaler expects numpy array
205
  pipe_output_type = "pil"
206
  refiner_output_type = "pil"
207
- if refiner:
208
  pipe_output_type = "latent"
209
  if scale > 1:
210
  refiner_output_type = "np"
@@ -215,7 +224,7 @@ def generate(
215
  pipe_kwargs = {
216
  "width": width,
217
  "height": height,
218
- "denoising_end": 0.8 if refiner else None,
219
  "generator": generator,
220
  "output_type": pipe_output_type,
221
  "guidance_scale": guidance_scale,
@@ -255,7 +264,6 @@ def generate(
255
  except Exception as e:
256
  raise Error(f"RuntimeError: {e}")
257
  finally:
258
- # reset step and increment image
259
  CURRENT_STEP = 0
260
  CURRENT_IMAGE += 1
261
  current_seed += 1
 
154
 
155
  start = time.perf_counter()
156
  loader = Loader()
157
+ loader.load(
158
  KIND,
159
  model,
160
  scheduler,
 
165
  TQDM,
166
  )
167
 
168
+ pipe = loader.pipe
169
+ refiner = loader.refiner
170
+
171
+ upscaler = None
172
+ if scale == 2:
173
+ upscaler = loader.upscaler_2x
174
+ if scale == 4:
175
+ upscaler = loader.upscaler_4x
176
+
177
  # prompt embeds for base and refiner
178
  compel_1 = Compel(
179
  text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
 
213
  # refiner expects latents; upscaler expects numpy array
214
  pipe_output_type = "pil"
215
  refiner_output_type = "pil"
216
+ if use_refiner:
217
  pipe_output_type = "latent"
218
  if scale > 1:
219
  refiner_output_type = "np"
 
224
  pipe_kwargs = {
225
  "width": width,
226
  "height": height,
227
+ "denoising_end": 0.8 if use_refiner else None,
228
  "generator": generator,
229
  "output_type": pipe_output_type,
230
  "guidance_scale": guidance_scale,
 
264
  except Exception as e:
265
  raise Error(f"RuntimeError: {e}")
266
  finally:
 
267
  CURRENT_STEP = 0
268
  CURRENT_IMAGE += 1
269
  current_seed += 1
lib/loader.py CHANGED
@@ -25,7 +25,8 @@ class Loader:
25
  cls._instance.pipe = None
26
  cls._instance.model = None
27
  cls._instance.refiner = None
28
- cls._instance.upscaler = None
 
29
  return cls._instance
30
 
31
  def _flush(self):
@@ -43,23 +44,11 @@ class Loader:
43
  return True
44
  return False
45
 
46
- def _should_unload_refiner(self, refiner):
47
- if self.refiner is not None and not refiner:
48
- return True
49
- return False
50
-
51
- def _should_unload_upscaler(self, scale=1):
52
- return self.upscaler is not None and scale == 1
53
-
54
- def _unload(self, model, refiner, scale):
55
  to_unload = []
56
  if self._should_unload_pipeline(model):
57
  to_unload.append("model")
58
  to_unload.append("pipe")
59
- if self._should_unload_refiner(refiner):
60
- to_unload.append("refiner")
61
- if self._should_unload_upscaler(scale):
62
- to_unload.append("upscaler")
63
  for component in to_unload:
64
  delattr(self, component)
65
  self._flush()
@@ -89,7 +78,7 @@ class Loader:
89
  self.pipe.set_progress_bar_config(disable=not tqdm)
90
 
91
  def _load_refiner(self, refiner, tqdm, **kwargs):
92
- if self.refiner is None and refiner:
93
  model = Config.REFINER_MODEL
94
  pipeline = Config.PIPELINES["img2img"]
95
  try:
@@ -103,22 +92,45 @@ class Loader:
103
  self.refiner.set_progress_bar_config(disable=not tqdm)
104
 
105
  def _load_upscaler(self, scale=1):
106
- if scale > 1 and self.upscaler is None:
107
- print(f"Loading {scale}x upscaler...")
108
- self.upscaler = RealESRGAN(scale, "cuda")
109
- self.upscaler.load_weights()
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  def _load_deepcache(self, interval=1):
112
- has_deepcache = hasattr(self.pipe, "deepcache")
113
- if has_deepcache and self.pipe.deepcache.params["cache_interval"] == interval:
114
  return
115
- if has_deepcache:
116
  self.pipe.deepcache.disable()
117
  else:
118
  self.pipe.deepcache = DeepCacheSDHelper(pipe=self.pipe)
119
  self.pipe.deepcache.set_params(cache_interval=interval)
120
  self.pipe.deepcache.enable()
121
 
 
 
 
 
 
 
 
 
 
 
 
122
  def load(self, kind, model, scheduler, deepcache, scale, karras, refiner, tqdm):
123
  model_lower = model.lower()
124
 
@@ -153,12 +165,12 @@ class Loader:
153
  "vae": AutoencoderKL.from_pretrained(Config.VAE_MODEL, torch_dtype=dtype),
154
  }
155
 
156
- self._unload(model, refiner, scale)
157
  self._load_pipeline(kind, model, tqdm, **pipe_kwargs)
158
 
159
  # error loading model
160
  if self.pipe is None:
161
- return None, None, None
162
 
163
  same_scheduler = isinstance(self.pipe.scheduler, Config.SCHEDULERS[scheduler])
164
  same_karras = (
@@ -193,4 +205,3 @@ class Loader:
193
  self._load_refiner(refiner, tqdm, **refiner_kwargs)
194
  self._load_upscaler(scale)
195
  self._load_deepcache(deepcache)
196
- return self.pipe, self.refiner, self.upscaler
 
25
  cls._instance.pipe = None
26
  cls._instance.model = None
27
  cls._instance.refiner = None
28
+ cls._instance.upscaler_2x = None
29
+ cls._instance.upscaler_4x = None
30
  return cls._instance
31
 
32
  def _flush(self):
 
44
  return True
45
  return False
46
 
47
+ def _unload(self, model):
 
 
 
 
 
 
 
 
48
  to_unload = []
49
  if self._should_unload_pipeline(model):
50
  to_unload.append("model")
51
  to_unload.append("pipe")
 
 
 
 
52
  for component in to_unload:
53
  delattr(self, component)
54
  self._flush()
 
78
  self.pipe.set_progress_bar_config(disable=not tqdm)
79
 
80
  def _load_refiner(self, refiner, tqdm, **kwargs):
81
+ if refiner and self.refiner is None:
82
  model = Config.REFINER_MODEL
83
  pipeline = Config.PIPELINES["img2img"]
84
  try:
 
92
  self.refiner.set_progress_bar_config(disable=not tqdm)
93
 
94
  def _load_upscaler(self, scale=1):
95
+ if scale == 2 and self.upscaler_2x is None:
96
+ try:
97
+ print("Loading 2x upscaler...")
98
+ self.upscaler_2x = RealESRGAN(2, "cuda")
99
+ self.upscaler_2x.load_weights()
100
+ except Exception as e:
101
+ print(f"Error loading 2x upscaler: {e}")
102
+ self.upscaler_2x = None
103
+ if scale == 4 and self.upscaler_4x is None:
104
+ try:
105
+ print("Loading 4x upscaler...")
106
+ self.upscaler_4x = RealESRGAN(4, "cuda")
107
+ self.upscaler_4x.load_weights()
108
+ except Exception as e:
109
+ print(f"Error loading 4x upscaler: {e}")
110
+ self.upscaler_4x = None
111
 
112
  def _load_deepcache(self, interval=1):
113
+ pipe_has_deepcache = hasattr(self.pipe, "deepcache")
114
+ if pipe_has_deepcache and self.pipe.deepcache.params["cache_interval"] == interval:
115
  return
116
+ if pipe_has_deepcache:
117
  self.pipe.deepcache.disable()
118
  else:
119
  self.pipe.deepcache = DeepCacheSDHelper(pipe=self.pipe)
120
  self.pipe.deepcache.set_params(cache_interval=interval)
121
  self.pipe.deepcache.enable()
122
 
123
+ if self.refiner is not None:
124
+ refiner_has_deepcache = hasattr(self.refiner, "deepcache")
125
+ if refiner_has_deepcache and self.refiner.deepcache.params["cache_interval"] == interval:
126
+ return
127
+ if refiner_has_deepcache:
128
+ self.refiner.deepcache.disable()
129
+ else:
130
+ self.refiner.deepcache = DeepCacheSDHelper(pipe=self.refiner)
131
+ self.refiner.deepcache.set_params(cache_interval=interval)
132
+ self.refiner.deepcache.enable()
133
+
134
  def load(self, kind, model, scheduler, deepcache, scale, karras, refiner, tqdm):
135
  model_lower = model.lower()
136
 
 
165
  "vae": AutoencoderKL.from_pretrained(Config.VAE_MODEL, torch_dtype=dtype),
166
  }
167
 
168
+ self._unload(model)
169
  self._load_pipeline(kind, model, tqdm, **pipe_kwargs)
170
 
171
  # error loading model
172
  if self.pipe is None:
173
+ return
174
 
175
  same_scheduler = isinstance(self.pipe.scheduler, Config.SCHEDULERS[scheduler])
176
  same_karras = (
 
205
  self._load_refiner(refiner, tqdm, **refiner_kwargs)
206
  self._load_upscaler(scale)
207
  self._load_deepcache(deepcache)