Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -107,28 +107,32 @@ models_b = WurstCoreB.Models(
|
|
107 |
)
|
108 |
models_b.generator.bfloat16().eval().requires_grad_(False)
|
109 |
|
110 |
-
if low_vram:
|
111 |
-
# Off-load old generator (which is not used in models_rbm)
|
112 |
-
models.generator.to("cpu")
|
113 |
-
torch.cuda.empty_cache()
|
114 |
|
115 |
-
generator_rbm = StageCRBM()
|
116 |
-
for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
|
117 |
-
set_module_tensor_to_device(generator_rbm, param_name, "cpu", value=param)
|
118 |
-
generator_rbm = generator_rbm.to(getattr(torch, core.config.dtype)).to(device)
|
119 |
-
generator_rbm = core.load_model(generator_rbm, 'generator')
|
120 |
-
|
121 |
-
models_rbm = core.Models(
|
122 |
-
effnet=models.effnet, previewer=models.previewer,
|
123 |
-
generator=generator_rbm, generator_ema=models.generator_ema,
|
124 |
-
tokenizer=models.tokenizer, text_model=models.text_model, image_model=models.image_model
|
125 |
-
)
|
126 |
-
models_rbm.generator.eval().requires_grad_(False)
|
127 |
|
128 |
sam_model = LangSAM()
|
129 |
|
130 |
def infer(ref_style_file, style_description, caption, progress):
|
131 |
global models_rbm, models_b, device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
if low_vram:
|
133 |
models_to(models_rbm, device=device, excepts=["generator", "previewer"])
|
134 |
try:
|
@@ -236,6 +240,26 @@ def infer(ref_style_file, style_description, caption, progress):
|
|
236 |
|
237 |
def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progress):
|
238 |
global models_rbm, models_b, device, sam_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
239 |
if low_vram:
|
240 |
models_to(models_rbm, device=device, excepts=["generator", "previewer"])
|
241 |
models_to(sam_model, device=device)
|
|
|
107 |
)
|
108 |
models_b.generator.bfloat16().eval().requires_grad_(False)
|
109 |
|
|
|
|
|
|
|
|
|
110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
|
112 |
sam_model = LangSAM()
|
113 |
|
114 |
def infer(ref_style_file, style_description, caption, progress):
|
115 |
global models_rbm, models_b, device
|
116 |
+
|
117 |
+
if low_vram:
|
118 |
+
# Off-load old generator (which is not used in models_rbm)
|
119 |
+
models.generator.to("cpu")
|
120 |
+
torch.cuda.empty_cache()
|
121 |
+
gc.collect()
|
122 |
+
|
123 |
+
generator_rbm = StageCRBM()
|
124 |
+
for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
|
125 |
+
set_module_tensor_to_device(generator_rbm, param_name, "cpu", value=param)
|
126 |
+
generator_rbm = generator_rbm.to(getattr(torch, core.config.dtype)).to(device)
|
127 |
+
generator_rbm = core.load_model(generator_rbm, 'generator')
|
128 |
+
|
129 |
+
models_rbm = core.Models(
|
130 |
+
effnet=models.effnet, previewer=models.previewer,
|
131 |
+
generator=generator_rbm, generator_ema=models.generator_ema,
|
132 |
+
tokenizer=models.tokenizer, text_model=models.text_model, image_model=models.image_model
|
133 |
+
)
|
134 |
+
models_rbm.generator.eval().requires_grad_(False)
|
135 |
+
|
136 |
if low_vram:
|
137 |
models_to(models_rbm, device=device, excepts=["generator", "previewer"])
|
138 |
try:
|
|
|
240 |
|
241 |
def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progress):
|
242 |
global models_rbm, models_b, device, sam_model
|
243 |
+
|
244 |
+
if low_vram:
|
245 |
+
# Off-load old generator (which is not used in models_rbm)
|
246 |
+
models.generator.to("cpu")
|
247 |
+
torch.cuda.empty_cache()
|
248 |
+
gc.collect()
|
249 |
+
|
250 |
+
generator_rbm = StageCRBM()
|
251 |
+
for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
|
252 |
+
set_module_tensor_to_device(generator_rbm, param_name, "cpu", value=param)
|
253 |
+
generator_rbm = generator_rbm.to(getattr(torch, core.config.dtype)).to(device)
|
254 |
+
generator_rbm = core.load_model(generator_rbm, 'generator')
|
255 |
+
|
256 |
+
models_rbm = core.Models(
|
257 |
+
effnet=models.effnet, previewer=models.previewer,
|
258 |
+
generator=generator_rbm, generator_ema=models.generator_ema,
|
259 |
+
tokenizer=models.tokenizer, text_model=models.text_model, image_model=models.image_model
|
260 |
+
)
|
261 |
+
models_rbm.generator.eval().requires_grad_(False)
|
262 |
+
|
263 |
if low_vram:
|
264 |
models_to(models_rbm, device=device, excepts=["generator", "previewer"])
|
265 |
models_to(sam_model, device=device)
|