Spaces: Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -273,7 +273,7 @@ def infer(ref_style_file, style_description, caption):
|
|
273 |
reset_inference_state()
|
274 |
|
275 |
def reset_compo_inference_state():
|
276 |
-
global models_rbm, models_b, extras, extras_b, device, core, core_b
|
277 |
|
278 |
# Reset sampling configurations
|
279 |
extras.sampling_configs['cfg'] = 4
|
@@ -290,6 +290,13 @@ def reset_compo_inference_state():
|
|
290 |
models_to(models_rbm, device="cpu")
|
291 |
models_b.generator.to("cpu")
|
292 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
293 |
# Clear CUDA cache
|
294 |
torch.cuda.empty_cache()
|
295 |
gc.collect()
|
@@ -305,7 +312,7 @@ def reset_compo_inference_state():
|
|
305 |
gc.collect()
|
306 |
|
307 |
def infer_compo(style_description, ref_style_file, caption, ref_sub_file):
|
308 |
-
global models_rbm, models_b, device
|
309 |
try:
|
310 |
caption = f"{caption} in {style_description}"
|
311 |
sam_prompt = f"{caption}"
|
@@ -331,7 +338,12 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file):
|
|
331 |
use_sam_mask = False
|
332 |
x0_preview = models_rbm.previewer(x0_forward)
|
333 |
sam_model = LangSAM()
|
334 |
-
|
|
|
|
|
|
|
|
|
|
|
335 |
|
336 |
x0_preview_pil = T.ToPILImage()(x0_preview[0].cpu())
|
337 |
sam_mask, boxes, phrases, logits = sam_model.predict(x0_preview_pil, sam_prompt)
|
@@ -344,8 +356,10 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file):
|
|
344 |
|
345 |
if low_vram:
|
346 |
models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
|
347 |
-
|
348 |
-
|
|
|
|
|
349 |
|
350 |
# Stage C reverse process.
|
351 |
sampling_c = extras.gdf.sample(
|
|
|
273 |
reset_inference_state()
|
274 |
|
275 |
def reset_compo_inference_state():
|
276 |
+
global models_rbm, models_b, extras, extras_b, device, core, core_b, sam_model
|
277 |
|
278 |
# Reset sampling configurations
|
279 |
extras.sampling_configs['cfg'] = 4
|
|
|
290 |
models_to(models_rbm, device="cpu")
|
291 |
models_b.generator.to("cpu")
|
292 |
|
293 |
+
# Move SAM model components to CPU if they exist
|
294 |
+
if 'sam_model' in globals():
|
295 |
+
if hasattr(sam_model, 'sam'):
|
296 |
+
sam_model.sam.to("cpu")
|
297 |
+
if hasattr(sam_model, 'text_encoder'):
|
298 |
+
sam_model.text_encoder.to("cpu")
|
299 |
+
|
300 |
# Clear CUDA cache
|
301 |
torch.cuda.empty_cache()
|
302 |
gc.collect()
|
|
|
312 |
gc.collect()
|
313 |
|
314 |
def infer_compo(style_description, ref_style_file, caption, ref_sub_file):
|
315 |
+
global models_rbm, models_b, device, sam_model
|
316 |
try:
|
317 |
caption = f"{caption} in {style_description}"
|
318 |
sam_prompt = f"{caption}"
|
|
|
338 |
use_sam_mask = False
|
339 |
x0_preview = models_rbm.previewer(x0_forward)
|
340 |
sam_model = LangSAM()
|
341 |
+
|
342 |
+
# Move SAM model components to the correct device
|
343 |
+
if hasattr(sam_model, 'sam'):
|
344 |
+
sam_model.sam.to(device)
|
345 |
+
if hasattr(sam_model, 'text_encoder'):
|
346 |
+
sam_model.text_encoder.to(device)
|
347 |
|
348 |
x0_preview_pil = T.ToPILImage()(x0_preview[0].cpu())
|
349 |
sam_mask, boxes, phrases, logits = sam_model.predict(x0_preview_pil, sam_prompt)
|
|
|
356 |
|
357 |
if low_vram:
|
358 |
models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
|
359 |
+
if hasattr(sam_model, 'sam'):
|
360 |
+
sam_model.sam.to("cpu")
|
361 |
+
if hasattr(sam_model, 'text_encoder'):
|
362 |
+
sam_model.text_encoder.to("cpu")
|
363 |
|
364 |
# Stage C reverse process.
|
365 |
sampling_c = extras.gdf.sample(
|