Update app.py

app.py (CHANGED)
@@ -316,6 +316,7 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
 
     # Initialize a list to store file paths of saved images
     jpeg_images = []
+    masks_frames = []
 
     # run propagation throughout the video and collect the results in a dict
     video_segments = {}  # video_segments contains the per-frame segmentation results
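Note (not part of the commit): masks_frames collects the file path of every per-frame mask written in the loop below so the "render" branch can later assemble them into a video. The directory the masks are written to, frames_output_dir, is assumed to already exist from the earlier frame-extraction step; if it did not, something like the following would be needed before the loop.

os.makedirs(frames_output_dir, exist_ok=True)  # assumption: only needed if the directory is not created elsewhere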
@@ -352,11 +353,18 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
         if f"frame_{out_frame_idx}.jpg" not in available_frames_to_check:
             available_frames_to_check.append(f"frame_{out_frame_idx}.jpg")
 
+        # Save the raw binary mask as a separate image
+        mask_filename = os.path.join(frames_output_dir, f"mask_{out_frame_idx}.jpg")
+        binary_mask = (out_mask * 255).astype(np.uint8)  # Scale mask to 0-255
+        mask_image = Image.fromarray(binary_mask)
+        mask_image.save(mask_filename)  # Save the mask as a JPEG
+        masks_frames.append(mask_filename)  # Append to the list of masks
+
    torch.cuda.empty_cache()
    print(f"JPEG_IMAGES: {jpeg_images}")
 
    if vis_frame_type == "check":
-        return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=available_frames_to_check, value=working_frame, visible=True), available_frames_to_check, gr.update(visible=True)
+        return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=available_frames_to_check, value=working_frame, visible=True), available_frames_to_check, gr.update(visible=True), None
    elif vis_frame_type == "render":
        # Create a video clip from the image sequence
        original_fps = get_video_fps(video_in)
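The added mask-saving lines assume out_mask is already a NumPy array with values in 0-1 (or booleans) and a plain (H, W) shape; SAM2's propagation yields mask logits as tensors, so if that conversion has not already happened earlier in the loop, a small guard keeps Image.fromarray from rejecting the input. A minimal sketch, with the helper name save_binary_mask being illustrative rather than part of the app:

import numpy as np
from PIL import Image

def save_binary_mask(out_mask, path):
    # Sketch only: normalise whatever mask representation arrives into an
    # 8-bit grayscale image. Assumes out_mask is array-like, with positive
    # values where the object is present.
    mask = np.asarray(out_mask)
    if mask.ndim == 3:                      # e.g. (1, H, W) -> (H, W)
        mask = mask.squeeze(0)
    binary = (mask > 0).astype(np.uint8) * 255
    Image.fromarray(binary, mode="L").save(path)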
@@ -371,8 +379,12 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
            final_vid_output_path,
            codec='libx264'
        )
+
+        mask_clip = ImageSequenceClip(masks_frames, fps=fps)
+        # Write the result to a file
+        mask_final_vid_output_path = "mask_output_video.mp4"
 
-        return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True)
+        return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True), mask_final_vid_output_path
 
 def update_ui(vis_frame_type):
    if vis_frame_type == "check":
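Within the lines shown, the render branch builds mask_clip and chooses mask_final_vid_output_path, but no write call for the mask clip appears, so the path returned below may point to a file that is never written. If the intent is a playable mask video, a write step mirroring the existing write_videofile call for the overlay video would be needed; a sketch, assuming the same MoviePy clip object and fps:

mask_clip = ImageSequenceClip(masks_frames, fps=fps)
mask_final_vid_output_path = "mask_output_video.mp4"
# Assumption: the clip still has to be encoded before the path is returned.
mask_clip.write_videofile(mask_final_vid_output_path, codec='libx264')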
@@ -478,6 +490,7 @@ with gr.Blocks(css=css) as demo:
                reset_prpgt_brn = gr.Button("Reset", visible=False)
                output_propagated = gr.Gallery(label="Propagated Mask samples gallery", columns=4, visible=False)
                output_video = gr.Video(visible=False)
+                mask_final_output = gr.Video(label="Mask Video")
                # output_result_mask = gr.Image()
 
 
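Unlike output_video and the gallery, the new mask_final_output component is created without visible=False, so it is shown (empty) before any propagation runs. If it should behave like its neighbours, a possible adjustment (not part of this commit) would be:

mask_final_output = gr.Video(label="Mask Video", visible=False)
# ...then reveal it from whichever callback currently flips the other
# result widgets to visible, e.g. by also returning gr.update(visible=True)
# for this component.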
@@ -581,7 +594,7 @@ with gr.Blocks(css=css) as demo:
    ).then(
        fn = propagate_to_all,
        inputs = [video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame],
-        outputs = [output_propagated, output_video, working_frame, available_frames_to_check, reset_prpgt_brn]
+        outputs = [output_propagated, output_video, working_frame, available_frames_to_check, reset_prpgt_brn, mask_final_output]
    )
 
 demo.launch(show_api=False, show_error=True)
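Gradio maps a callback's return values onto its outputs list by position, which is why propagate_to_all now returns a sixth value on both paths: None in the "check" branch (leaving the mask video empty) and mask_final_vid_output_path in the "render" branch. Roughly, the mapping assumed by this wiring (placeholders, not app code):

# outputs[0] output_propagated          <- gallery update / None
# outputs[1] output_video               <- None / final_vid_output_path
# outputs[2] working_frame              <- selected frame name
# outputs[3] available_frames_to_check  <- updated frame list (state)
# outputs[4] reset_prpgt_brn            <- gr.update(visible=True)
# outputs[5] mask_final_output          <- None / mask_final_vid_output_path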