fffiloni committed (verified)
Commit b6f94c1 · Parent(s): f3011b8

Update app.py

Files changed (1): app.py (+19, -5)
app.py CHANGED
@@ -273,7 +273,7 @@ def infer(ref_style_file, style_description, caption):
     reset_inference_state()
 
 def reset_compo_inference_state():
-    global models_rbm, models_b, extras, extras_b, device, core, core_b
+    global models_rbm, models_b, extras, extras_b, device, core, core_b, sam_model
 
     # Reset sampling configurations
     extras.sampling_configs['cfg'] = 4
@@ -290,6 +290,13 @@ def reset_compo_inference_state():
     models_to(models_rbm, device="cpu")
     models_b.generator.to("cpu")
 
+    # Move SAM model components to CPU if they exist
+    if 'sam_model' in globals():
+        if hasattr(sam_model, 'sam'):
+            sam_model.sam.to("cpu")
+        if hasattr(sam_model, 'text_encoder'):
+            sam_model.text_encoder.to("cpu")
+
     # Clear CUDA cache
     torch.cuda.empty_cache()
     gc.collect()
@@ -305,7 +312,7 @@ def reset_compo_inference_state():
     gc.collect()
 
 def infer_compo(style_description, ref_style_file, caption, ref_sub_file):
-    global models_rbm, models_b, device
+    global models_rbm, models_b, device, sam_model
     try:
         caption = f"{caption} in {style_description}"
         sam_prompt = f"{caption}"
@@ -331,7 +338,12 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file):
         use_sam_mask = False
         x0_preview = models_rbm.previewer(x0_forward)
         sam_model = LangSAM()
-        sam_model.to(device)
+
+        # Move SAM model components to the correct device
+        if hasattr(sam_model, 'sam'):
+            sam_model.sam.to(device)
+        if hasattr(sam_model, 'text_encoder'):
+            sam_model.text_encoder.to(device)
 
         x0_preview_pil = T.ToPILImage()(x0_preview[0].cpu())
         sam_mask, boxes, phrases, logits = sam_model.predict(x0_preview_pil, sam_prompt)
@@ -344,8 +356,10 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file):
 
         if low_vram:
             models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
-            models_to(sam_model, device="cpu")
-            models_to(sam_model.sam, device="cpu")
+            if hasattr(sam_model, 'sam'):
+                sam_model.sam.to("cpu")
+            if hasattr(sam_model, 'text_encoder'):
+                sam_model.text_encoder.to("cpu")
 
         # Stage C reverse process.
         sampling_c = extras.gdf.sample(
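Note on the device-move change: LangSAM (from lang-segment-anything) appears to be a plain wrapper around several sub-models rather than a single torch.nn.Module, so the blanket sam_model.to(device) the diff removes can fail or leave components behind. A minimal sketch of the guarded, per-component pattern the commit repeats four times; the helper name move_components is hypothetical, and the attribute list mirrors the names the diff probes with hasattr():

import torch.nn as nn

def move_components(wrapper, device, names=("sam", "text_encoder")):
    # Probe each candidate sub-model by name and move only the ones
    # this wrapper actually exposes (hypothetical helper, not in app.py).
    for name in names:
        module = getattr(wrapper, name, None)
        if isinstance(module, nn.Module):
            module.to(device)

Calling move_components(sam_model, device) after LangSAM() and move_components(sam_model, "cpu") in the low_vram branch would express the same behavior without the repeated hasattr blocks.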
 
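The reset path adds one more guard: reset_compo_inference_state() can run before infer_compo() has ever constructed sam_model, so it checks 'sam_model' in globals() before offloading, then releases cached allocations. A sketch of that sequence, reusing the hypothetical move_components helper above:

import gc
import torch

def release_sam_vram():
    # Offload only if infer_compo() has already created the global model.
    if "sam_model" in globals():
        move_components(globals()["sam_model"], "cpu")
    # Dropping references alone is not enough: empty_cache() hands the
    # allocator's cached blocks back to the GPU driver.
    torch.cuda.empty_cache()
    gc.collect()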