fffiloni commited on
Commit
3760c54
·
verified ·
1 Parent(s): 25265e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -3
app.py CHANGED
@@ -316,6 +316,7 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
316
 
317
  # Initialize a list to store file paths of saved images
318
  jpeg_images = []
 
319
 
320
  # run propagation throughout the video and collect the results in a dict
321
  video_segments = {} # video_segments contains the per-frame segmentation results
@@ -352,11 +353,18 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
352
  if f"frame_{out_frame_idx}.jpg" not in available_frames_to_check:
353
  available_frames_to_check.append(f"frame_{out_frame_idx}.jpg")
354
 
 
 
 
 
 
 
 
355
  torch.cuda.empty_cache()
356
  print(f"JPEG_IMAGES: {jpeg_images}")
357
 
358
  if vis_frame_type == "check":
359
- return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=available_frames_to_check, value=working_frame, visible=True), available_frames_to_check, gr.update(visible=True)
360
  elif vis_frame_type == "render":
361
  # Create a video clip from the image sequence
362
  original_fps = get_video_fps(video_in)
@@ -371,8 +379,12 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
371
  final_vid_output_path,
372
  codec='libx264'
373
  )
 
 
 
 
374
 
375
- return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True)
376
 
377
  def update_ui(vis_frame_type):
378
  if vis_frame_type == "check":
@@ -478,6 +490,7 @@ with gr.Blocks(css=css) as demo:
478
  reset_prpgt_brn = gr.Button("Reset", visible=False)
479
  output_propagated = gr.Gallery(label="Propagated Mask samples gallery", columns=4, visible=False)
480
  output_video = gr.Video(visible=False)
 
481
  # output_result_mask = gr.Image()
482
 
483
 
@@ -581,7 +594,7 @@ with gr.Blocks(css=css) as demo:
581
  ).then(
582
  fn = propagate_to_all,
583
  inputs = [video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame],
584
- outputs = [output_propagated, output_video, working_frame, available_frames_to_check, reset_prpgt_brn]
585
  )
586
 
587
  demo.launch(show_api=False, show_error=True)
 
316
 
317
  # Initialize a list to store file paths of saved images
318
  jpeg_images = []
319
+ masks_frames = []
320
 
321
  # run propagation throughout the video and collect the results in a dict
322
  video_segments = {} # video_segments contains the per-frame segmentation results
 
353
  if f"frame_{out_frame_idx}.jpg" not in available_frames_to_check:
354
  available_frames_to_check.append(f"frame_{out_frame_idx}.jpg")
355
 
356
+ # Save the raw binary mask as a separate image
357
+ mask_filename = os.path.join(frames_output_dir, f"mask_{out_frame_idx}.jpg")
358
+ binary_mask = (out_mask * 255).astype(np.uint8) # Scale mask to 0-255
359
+ mask_image = Image.fromarray(binary_mask)
360
+ mask_image.save(mask_filename) # Save the mask as a JPEG
361
+ masks_frames.append(mask_filename) # Append to the list of masks
362
+
363
  torch.cuda.empty_cache()
364
  print(f"JPEG_IMAGES: {jpeg_images}")
365
 
366
  if vis_frame_type == "check":
367
+ return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=available_frames_to_check, value=working_frame, visible=True), available_frames_to_check, gr.update(visible=True), None
368
  elif vis_frame_type == "render":
369
  # Create a video clip from the image sequence
370
  original_fps = get_video_fps(video_in)
 
379
  final_vid_output_path,
380
  codec='libx264'
381
  )
382
+
383
+ mask_clip = ImageSequenceClip(masks_frames, fps=fps)
384
+ # Write the result to a file
385
+ mask_final_vid_output_path = "mask_output_video.mp4"
386
 
387
+ return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True), mask_final_vid_output_path
388
 
389
  def update_ui(vis_frame_type):
390
  if vis_frame_type == "check":
 
490
  reset_prpgt_brn = gr.Button("Reset", visible=False)
491
  output_propagated = gr.Gallery(label="Propagated Mask samples gallery", columns=4, visible=False)
492
  output_video = gr.Video(visible=False)
493
+ mask_final_output = gr.Video(label="Mask Video")
494
  # output_result_mask = gr.Image()
495
 
496
 
 
594
  ).then(
595
  fn = propagate_to_all,
596
  inputs = [video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame],
597
+ outputs = [output_propagated, output_video, working_frame, available_frames_to_check, reset_prpgt_brn, mask_final_output]
598
  )
599
 
600
  demo.launch(show_api=False, show_error=True)