fffiloni commited on
Commit
4904051
·
verified ·
1 Parent(s): 52273d7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -2
app.py CHANGED
@@ -55,6 +55,7 @@ def models_to(model, device="cpu", excepts=None):
55
  attr_value.to(device)
56
 
57
  torch.cuda.empty_cache()
 
58
 
59
  # Stage C model configuration
60
  config_file = 'third_party/StableCascade/configs/inference/stage_c_3b.yaml'
@@ -214,19 +215,24 @@ def infer(ref_style_file, style_description, caption, progress):
214
  # Remove the batch dimension and keep only the generated image
215
  sampled = sampled[1] # This selects the generated image, discarding the reference style image
216
 
 
 
 
217
  # Ensure the tensor is in [C, H, W] format
218
  if sampled.dim() == 3 and sampled.shape[0] == 3:
219
  sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
220
- sampled_image.save(output_file) # Save the image as a PNG
221
  else:
222
  raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
223
 
224
  progress(1.0, "Inference complete")
225
- return output_file # Return the path to the saved image
 
226
 
227
  finally:
228
  # Clear CUDA cache
229
  torch.cuda.empty_cache()
 
230
 
231
  def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progress):
232
  global models_rbm, models_b, device, sam_model
@@ -348,6 +354,7 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progre
348
  finally:
349
  # Clear CUDA cache
350
  torch.cuda.empty_cache()
 
351
 
352
  def run(style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref):
353
  result = None
 
55
  attr_value.to(device)
56
 
57
  torch.cuda.empty_cache()
58
+ gc.collect()
59
 
60
  # Stage C model configuration
61
  config_file = 'third_party/StableCascade/configs/inference/stage_c_3b.yaml'
 
215
  # Remove the batch dimension and keep only the generated image
216
  sampled = sampled[1] # This selects the generated image, discarding the reference style image
217
 
218
+ # Ensure the tensor values are in the correct range
219
+ sampled = torch.clamp(sampled, 0, 1)
220
+
221
  # Ensure the tensor is in [C, H, W] format
222
  if sampled.dim() == 3 and sampled.shape[0] == 3:
223
  sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
224
+ # sampled_image.save(output_file) # Save the image as a PNG
225
  else:
226
  raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
227
 
228
  progress(1.0, "Inference complete")
229
+ #return output_file # Return the path to the saved image
230
+ return sampled_image
231
 
232
  finally:
233
  # Clear CUDA cache
234
  torch.cuda.empty_cache()
235
+ gc.collect()
236
 
237
  def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progress):
238
  global models_rbm, models_b, device, sam_model
 
354
  finally:
355
  # Clear CUDA cache
356
  torch.cuda.empty_cache()
357
+ gc.collect()
358
 
359
  def run(style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref):
360
  result = None