Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -55,6 +55,7 @@ def models_to(model, device="cpu", excepts=None):
|
|
55 |
attr_value.to(device)
|
56 |
|
57 |
torch.cuda.empty_cache()
|
|
|
58 |
|
59 |
# Stage C model configuration
|
60 |
config_file = 'third_party/StableCascade/configs/inference/stage_c_3b.yaml'
|
@@ -214,19 +215,24 @@ def infer(ref_style_file, style_description, caption, progress):
|
|
214 |
# Remove the batch dimension and keep only the generated image
|
215 |
sampled = sampled[1] # This selects the generated image, discarding the reference style image
|
216 |
|
|
|
|
|
|
|
217 |
# Ensure the tensor is in [C, H, W] format
|
218 |
if sampled.dim() == 3 and sampled.shape[0] == 3:
|
219 |
sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
|
220 |
-
|
221 |
else:
|
222 |
raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
|
223 |
|
224 |
progress(1.0, "Inference complete")
|
225 |
-
return output_file # Return the path to the saved image
|
|
|
226 |
|
227 |
finally:
|
228 |
# Clear CUDA cache
|
229 |
torch.cuda.empty_cache()
|
|
|
230 |
|
231 |
def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progress):
|
232 |
global models_rbm, models_b, device, sam_model
|
@@ -348,6 +354,7 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progre
|
|
348 |
finally:
|
349 |
# Clear CUDA cache
|
350 |
torch.cuda.empty_cache()
|
|
|
351 |
|
352 |
def run(style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref):
|
353 |
result = None
|
|
|
55 |
attr_value.to(device)
|
56 |
|
57 |
torch.cuda.empty_cache()
|
58 |
+
gc.collect()
|
59 |
|
60 |
# Stage C model configuration
|
61 |
config_file = 'third_party/StableCascade/configs/inference/stage_c_3b.yaml'
|
|
|
215 |
# Remove the batch dimension and keep only the generated image
|
216 |
sampled = sampled[1] # This selects the generated image, discarding the reference style image
|
217 |
|
218 |
+
# Ensure the tensor values are in the correct range
|
219 |
+
sampled = torch.clamp(sampled, 0, 1)
|
220 |
+
|
221 |
# Ensure the tensor is in [C, H, W] format
|
222 |
if sampled.dim() == 3 and sampled.shape[0] == 3:
|
223 |
sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
|
224 |
+
# sampled_image.save(output_file) # Save the image as a PNG
|
225 |
else:
|
226 |
raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
|
227 |
|
228 |
progress(1.0, "Inference complete")
|
229 |
+
#return output_file # Return the path to the saved image
|
230 |
+
return sampled_image
|
231 |
|
232 |
finally:
|
233 |
# Clear CUDA cache
|
234 |
torch.cuda.empty_cache()
|
235 |
+
gc.collect()
|
236 |
|
237 |
def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progress):
|
238 |
global models_rbm, models_b, device, sam_model
|
|
|
354 |
finally:
|
355 |
# Clear CUDA cache
|
356 |
torch.cuda.empty_cache()
|
357 |
+
gc.collect()
|
358 |
|
359 |
def run(style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref):
|
360 |
result = None
|