Commit 67ca03a • adamelliotfields committed • 1 parent: 11ee0ea

Don't unload refiner and upscaler

Files changed:
- lib/inference.py +12 -4
- lib/loader.py +36 -25
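The gist of the change: Loader.load() no longer returns a (pipe, refiner, upscaler) tuple, and the refiner and upscalers are no longer unloaded between generations; callers read the cached components off the Loader singleton instead. A minimal sketch of the new call pattern, assembled from the hunks below (argument order follows the load() signature in lib/loader.py):

    loader = Loader()  # singleton; cached components survive across calls
    loader.load(KIND, model, scheduler, deepcache, scale, karras, refiner, TQDM)

    pipe = loader.pipe        # read components off the instance...
    refiner = loader.refiner  # ...instead of unpacking a return tuple

    upscaler = None           # one cached upscaler per scale factor
    if scale == 2:
        upscaler = loader.upscaler_2x
    if scale == 4:
        upscaler = loader.upscaler_4x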
lib/inference.py CHANGED

@@ -154,7 +154,7 @@ def generate(
 
     start = time.perf_counter()
     loader = Loader()
-    pipe, refiner, upscaler = loader.load(
+    loader.load(
         KIND,
         model,
         scheduler,
@@ -165,6 +165,15 @@ def generate(
         TQDM,
     )
 
+    pipe = loader.pipe
+    refiner = loader.refiner
+
+    upscaler = None
+    if scale == 2:
+        upscaler = loader.upscaler_2x
+    if scale == 4:
+        upscaler = loader.upscaler_4x
+
     # prompt embeds for base and refiner
     compel_1 = Compel(
         text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
@@ -204,7 +213,7 @@ def generate(
     # refiner expects latents; upscaler expects numpy array
     pipe_output_type = "pil"
     refiner_output_type = "pil"
-    if refiner is not None:
+    if use_refiner:
         pipe_output_type = "latent"
     if scale > 1:
         refiner_output_type = "np"
@@ -215,7 +224,7 @@ def generate(
     pipe_kwargs = {
         "width": width,
         "height": height,
-        "denoising_end": 0.8 if refiner is not None else None,
+        "denoising_end": 0.8 if use_refiner else None,
         "generator": generator,
         "output_type": pipe_output_type,
         "guidance_scale": guidance_scale,
@@ -255,7 +264,6 @@ def generate(
     except Exception as e:
         raise Error(f"RuntimeError: {e}")
     finally:
-        # reset step and increment image
         CURRENT_STEP = 0
         CURRENT_IMAGE += 1
         current_seed += 1
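For context on the output_type flags above: SDXL's base pipeline can emit raw latents for the refiner to finish, and the Real-ESRGAN upscaler consumes numpy arrays, so each stage's output format depends on which stage runs next. A condensed sketch of the handoff exactly as these hunks configure it (use_refiner and scale are assumed to be set earlier in generate()):

    # each stage's output feeds the next stage's input
    pipe_output_type = "pil"         # default: base pipe returns a finished PIL image
    refiner_output_type = "pil"
    if use_refiner:
        pipe_output_type = "latent"  # refiner picks up the base pipe's latents
    if scale > 1:
        refiner_output_type = "np"   # upscaler consumes a numpy array

    # denoising_end=0.8 stops the base pipe at 80% of the schedule and hands
    # the remaining steps to the refiner (the standard SDXL base/refiner split)
    pipe_kwargs = {"denoising_end": 0.8 if use_refiner else None, "output_type": pipe_output_type}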
lib/loader.py CHANGED

@@ -25,7 +25,8 @@ class Loader:
         cls._instance.pipe = None
         cls._instance.model = None
         cls._instance.refiner = None
-        cls._instance.upscaler = None
+        cls._instance.upscaler_2x = None
+        cls._instance.upscaler_4x = None
         return cls._instance
 
     def _flush(self):
@@ -43,23 +44,11 @@ class Loader:
             return True
         return False
 
-    def _should_unload_refiner(self, refiner):
-        if self.refiner is not None and not refiner:
-            return True
-        return False
-
-    def _should_unload_upscaler(self, scale=1):
-        return self.upscaler is not None and scale == 1
-
-    def _unload(self, model, refiner, scale):
+    def _unload(self, model):
         to_unload = []
         if self._should_unload_pipeline(model):
             to_unload.append("model")
             to_unload.append("pipe")
-        if self._should_unload_refiner(refiner):
-            to_unload.append("refiner")
-        if self._should_unload_upscaler(scale):
-            to_unload.append("upscaler")
         for component in to_unload:
             delattr(self, component)
         self._flush()
@@ -89,7 +78,7 @@ class Loader:
         self.pipe.set_progress_bar_config(disable=not tqdm)
 
     def _load_refiner(self, refiner, tqdm, **kwargs):
-        if self.refiner is None:
+        if refiner and self.refiner is None:
             model = Config.REFINER_MODEL
             pipeline = Config.PIPELINES["img2img"]
             try:
@@ -103,22 +92,45 @@ class Loader:
         self.refiner.set_progress_bar_config(disable=not tqdm)
 
     def _load_upscaler(self, scale=1):
-        if scale > 1 and self.upscaler is None:
-            print(f"Loading {scale}x upscaler...")
-            self.upscaler = RealESRGAN(scale, "cuda")
-            self.upscaler.load_weights()
+        if scale == 2 and self.upscaler_2x is None:
+            try:
+                print("Loading 2x upscaler...")
+                self.upscaler_2x = RealESRGAN(2, "cuda")
+                self.upscaler_2x.load_weights()
+            except Exception as e:
+                print(f"Error loading 2x upscaler: {e}")
+                self.upscaler_2x = None
+        if scale == 4 and self.upscaler_4x is None:
+            try:
+                print("Loading 4x upscaler...")
+                self.upscaler_4x = RealESRGAN(4, "cuda")
+                self.upscaler_4x.load_weights()
+            except Exception as e:
+                print(f"Error loading 4x upscaler: {e}")
+                self.upscaler_4x = None
 
     def _load_deepcache(self, interval=1):
-        has_deepcache = hasattr(self.pipe, "deepcache")
-        if has_deepcache and self.pipe.deepcache.params["cache_interval"] == interval:
+        pipe_has_deepcache = hasattr(self.pipe, "deepcache")
+        if pipe_has_deepcache and self.pipe.deepcache.params["cache_interval"] == interval:
             return
-        if has_deepcache:
+        if pipe_has_deepcache:
             self.pipe.deepcache.disable()
         else:
             self.pipe.deepcache = DeepCacheSDHelper(pipe=self.pipe)
         self.pipe.deepcache.set_params(cache_interval=interval)
         self.pipe.deepcache.enable()
 
+        if self.refiner is not None:
+            refiner_has_deepcache = hasattr(self.refiner, "deepcache")
+            if refiner_has_deepcache and self.refiner.deepcache.params["cache_interval"] == interval:
+                return
+            if refiner_has_deepcache:
+                self.refiner.deepcache.disable()
+            else:
+                self.refiner.deepcache = DeepCacheSDHelper(pipe=self.refiner)
+            self.refiner.deepcache.set_params(cache_interval=interval)
+            self.refiner.deepcache.enable()
+
     def load(self, kind, model, scheduler, deepcache, scale, karras, refiner, tqdm):
         model_lower = model.lower()
 
@@ -153,12 +165,12 @@ class Loader:
             "vae": AutoencoderKL.from_pretrained(Config.VAE_MODEL, torch_dtype=dtype),
         }
 
-        self._unload(model, refiner, scale)
+        self._unload(model)
         self._load_pipeline(kind, model, tqdm, **pipe_kwargs)
 
         # error loading model
         if self.pipe is None:
-            return None, None, None
+            return
 
         same_scheduler = isinstance(self.pipe.scheduler, Config.SCHEDULERS[scheduler])
         same_karras = (
@@ -193,4 +205,3 @@ class Loader:
         self._load_refiner(refiner, tqdm, **refiner_kwargs)
         self._load_upscaler(scale)
         self._load_deepcache(deepcache)
-        return self.pipe, self.refiner, self.upscaler
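Loader is a singleton, which is what makes the caching work: attributes set on cls._instance in __new__ persist across generate() calls, so a refiner or upscaler loaded once stays resident. The diff only shows the attribute initialization; a minimal sketch of the conventional __new__ pattern it implies:

    class Loader:
        _instance = None

        def __new__(cls):
            # create the single instance on first call; every later Loader()
            # returns the same object, with any loaded models still attached
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance.pipe = None
                cls._instance.model = None
                cls._instance.refiner = None
                cls._instance.upscaler_2x = None
                cls._instance.upscaler_4x = None
            return cls._instance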
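_load_deepcache attaches a DeepCacheSDHelper to each pipe and reconfigures it only when the requested cache interval changes. For reference, a minimal standalone sketch of the DeepCache helper API this code relies on (set_params, enable, and disable are the library's actual methods; the surrounding pipeline setup is illustrative):

    from DeepCache import DeepCacheSDHelper

    helper = DeepCacheSDHelper(pipe=pipe)  # wrap an existing diffusers pipeline
    helper.set_params(cache_interval=3)    # recompute deep UNet features every 3rd step
    helper.enable()                        # patch the pipeline to reuse cached features
    image = pipe(prompt).images[0]         # faster inference while caching is active
    helper.disable()                       # restore the original forward pass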