Spaces:
Starting
on
T4
Starting
on
T4
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse | |
import os | |
import numpy as np | |
import torch | |
from PIL import Image | |
from sam2.build_sam import build_sam2_video_predictor | |
def _build_davis_palette():
    """Return the 256-entry RGB palette used by DAVIS 2017 annotation PNGs.

    DAVIS 2017 uses the PASCAL VOC colour map: the bits of the colour index
    are dealt out three at a time (red, green, blue), starting from the most
    significant bit of each channel.  Generating the table avoids carrying a
    768-byte literal in which whitespace-valued bytes (0x20) are easily lost.

    Returns:
        A ``bytes`` object of length 768 laid out as ``R0 G0 B0 R1 G1 B1 ...``.
    """
    palette = bytearray()
    for index in range(256):
        red = green = blue = 0
        remaining = index
        # Consume the index three bits at a time, filling each channel from
        # bit 7 downwards.
        for shift in range(7, -1, -1):
            red |= (remaining & 1) << shift
            green |= ((remaining >> 1) & 1) << shift
            blue |= ((remaining >> 2) & 1) << shift
            remaining >>= 3
        palette.extend((red, green, blue))
    return bytes(palette)


# the PNG palette for DAVIS 2017 dataset
DAVIS_PALETTE = _build_davis_palette()
def load_ann_png(path):
    """Load a PNG annotation file as a mask array together with its palette.

    Args:
        path: path of the PNG file to read.

    Returns:
        A ``(mask, palette)`` tuple. ``mask`` is the image content converted
        to a ``np.uint8`` array; ``palette`` is whatever
        ``Image.getpalette()`` reports for the file.
    """
    # Open the image as a context manager so the underlying file is released
    # as soon as the pixel data has been copied into the array.
    with Image.open(path) as image:
        palette = image.getpalette()
        mask = np.array(image).astype(np.uint8)
    return mask, palette
def save_ann_png(path, mask, palette):
    """Write ``mask`` to ``path`` as a PNG file carrying ``palette``.

    ``mask`` must be a two-dimensional ``np.uint8`` array; anything else
    trips an assertion before any file is touched.
    """
    # only 8-bit, single-channel index maps are accepted
    assert mask.dtype == np.uint8
    assert mask.ndim == 2

    image = Image.fromarray(mask)
    image.putpalette(palette)
    image.save(path)
def get_per_obj_mask(mask):
    """Break a label mask into one boolean mask per object.

    Every distinct value greater than zero in ``mask`` is treated as an
    object ID; the value 0 is left out.  The result maps each object ID to
    the boolean array ``mask == object_id``.
    """
    per_obj_mask = {}
    for label in np.unique(mask).tolist():
        if label > 0:
            per_obj_mask[label] = mask == label
    return per_obj_mask
def put_per_obj_mask(per_obj_mask, height, width):
    """Merge per-object boolean masks into one ``uint8`` label mask.

    Objects are painted from the highest ID down to the lowest, so wherever
    several masks overlap the lowest object ID ends up in the output.
    Pixels covered by no object stay 0.
    """
    combined = np.zeros((height, width), dtype=np.uint8)
    for object_id in sorted(per_obj_mask, reverse=True):
        # each mask is reshaped to (height, width) before being used as an index
        region = per_obj_mask[object_id].reshape(height, width)
        combined[region] = object_id
    return combined
def load_masks_from_dir(input_mask_dir, video_name, frame_name, per_obj_png_file):
    """Load the masks of one frame as a dict of per-object boolean masks.

    Args:
        input_mask_dir: root directory holding one sub-directory per video.
        video_name: name of the video sub-directory.
        frame_name: frame file name without extension; ``.png`` is appended.
        per_obj_png_file: when false, a single ``<video>/<frame>.png`` holds
            all objects as label values; when true, each object has its own
            ``<video>/<object_id>/<frame>.png`` file.

    Returns:
        A ``(per_obj_input_mask, input_palette)`` tuple.  The first element
        maps object IDs to boolean masks.  The second is the palette of the
        (last) PNG that was read, or ``None`` if per-object mode found no
        object directory to read from.
    """
    video_mask_dir = os.path.join(input_mask_dir, video_name)

    if not per_obj_png_file:
        input_mask_path = os.path.join(video_mask_dir, f"{frame_name}.png")
        input_mask, input_palette = load_ann_png(input_mask_path)
        return get_per_obj_mask(input_mask), input_palette

    per_obj_input_mask = {}
    # Stays None when there is nothing to read, so the return below never
    # hits an unbound local.
    input_palette = None
    # each object is a directory in "{object_id:%03d}" format
    for object_name in os.listdir(video_mask_dir):
        object_dir = os.path.join(video_mask_dir, object_name)
        if not os.path.isdir(object_dir):
            # Stray files next to the object directories are not objects.
            continue
        object_id = int(object_name)
        input_mask_path = os.path.join(object_dir, f"{frame_name}.png")
        input_mask, input_palette = load_ann_png(input_mask_path)
        per_obj_input_mask[object_id] = input_mask > 0
    return per_obj_input_mask, input_palette
def save_masks_to_dir(
    output_mask_dir,
    video_name,
    frame_name,
    per_obj_output_mask,
    height,
    width,
    per_obj_png_file,
    output_palette,
):
    """Write the masks of one frame below ``output_mask_dir/video_name``.

    Without ``per_obj_png_file`` all objects are merged into a single
    ``<frame_name>.png``.  With it, each object is written to its own
    ``<object_id:03d>/<frame_name>.png`` as a 0/1 mask.  Needed directories
    are created on the fly.
    """
    video_out_dir = os.path.join(output_mask_dir, video_name)
    os.makedirs(video_out_dir, exist_ok=True)
    png_name = f"{frame_name}.png"

    if per_obj_png_file:
        for object_id, object_mask in per_obj_output_mask.items():
            object_dir = os.path.join(video_out_dir, f"{object_id:03d}")
            os.makedirs(object_dir, exist_ok=True)
            binary_mask = object_mask.reshape(height, width).astype(np.uint8)
            save_ann_png(
                os.path.join(object_dir, png_name), binary_mask, output_palette
            )
        return

    combined_mask = put_per_obj_mask(per_obj_output_mask, height, width)
    save_ann_png(os.path.join(video_out_dir, png_name), combined_mask, output_palette)
def vos_inference(
    predictor,
    base_video_dir,
    input_mask_dir,
    output_mask_dir,
    video_name,
    score_thresh=0.0,
    use_all_masks=False,
    per_obj_png_file=False,
):
    """Run VOS inference on one video and write the predicted masks as PNGs.

    Steps: list the JPEG frames of ``base_video_dir/video_name``, initialise
    the predictor state on that directory, feed the input masks found under
    ``input_mask_dir/video_name`` (frame 0 only, or every frame that has a
    mask file when ``use_all_masks`` is set), propagate through the video,
    threshold the returned logits at ``score_thresh`` and save the results
    under ``output_mask_dir/video_name``.
    """
    # collect the frame names (file stems) of the video, ordered numerically
    video_dir = os.path.join(base_video_dir, video_name)
    jpeg_suffixes = [".jpg", ".jpeg", ".JPG", ".JPEG"]
    frame_names = []
    for entry in os.listdir(video_dir):
        stem, suffix = os.path.splitext(entry)
        if suffix in jpeg_suffixes:
            frame_names.append(stem)
    frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

    inference_state = predictor.init_state(
        video_path=video_dir, async_loading_frames=False
    )
    height = inference_state["video_height"]
    width = inference_state["video_width"]

    # decide which frames provide input masks
    if not use_all_masks:
        # only the mask of the first frame is used as input
        prompt_frames = {0}
    elif not per_obj_png_file:
        # every frame that has "<video>/<frame>.png"
        prompt_frames = {
            idx
            for idx, name in enumerate(frame_names)
            if os.path.exists(os.path.join(input_mask_dir, video_name, f"{name}.png"))
        }
    else:
        # every frame that has a PNG under at least one entry of the video's mask dir
        prompt_frames = set()
        for object_name in os.listdir(os.path.join(input_mask_dir, video_name)):
            for idx, name in enumerate(frame_names):
                candidate = os.path.join(
                    input_mask_dir, video_name, object_name, f"{name}.png"
                )
                if os.path.exists(candidate):
                    prompt_frames.add(idx)
    input_frame_inds = sorted(prompt_frames)

    # register the input masks with the predictor before propagation
    input_palette = None
    for input_frame_idx in input_frame_inds:
        per_obj_input_mask, input_palette = load_masks_from_dir(
            input_mask_dir=input_mask_dir,
            video_name=video_name,
            frame_name=frame_names[input_frame_idx],
            per_obj_png_file=per_obj_png_file,
        )
        for object_id, object_mask in per_obj_input_mask.items():
            predictor.add_new_mask(
                inference_state=inference_state,
                frame_idx=input_frame_idx,
                obj_id=object_id,
                mask=object_mask,
            )

    # propagate through the video, keeping the thresholded masks per frame
    os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
    # reuse the palette of the last input mask read, else the DAVIS palette
    output_palette = input_palette or DAVIS_PALETTE
    video_segments = {}
    propagation = predictor.propagate_in_video(inference_state)
    for out_frame_idx, out_obj_ids, out_mask_logits in propagation:
        frame_masks = {}
        for position, out_obj_id in enumerate(out_obj_ids):
            above_thresh = out_mask_logits[position] > score_thresh
            frame_masks[out_obj_id] = above_thresh.cpu().numpy()
        video_segments[out_frame_idx] = frame_masks

    # write everything out as palette PNG files
    for out_frame_idx, per_obj_output_mask in video_segments.items():
        save_masks_to_dir(
            output_mask_dir=output_mask_dir,
            video_name=video_name,
            frame_name=frame_names[out_frame_idx],
            per_obj_output_mask=per_obj_output_mask,
            height=height,
            width=width,
            per_obj_png_file=per_obj_png_file,
            output_palette=output_palette,
        )
def main():
    """Parse the command line and run VOS inference on every selected video.

    The predictor is built once from ``--sam2_cfg`` / ``--sam2_checkpoint``
    and reused for all videos.  Videos come from ``--video_list_file`` when
    given (one name per line, blank lines ignored), otherwise from the
    sub-directories of ``--base_video_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sam2_cfg",
        type=str,
        default="sam2_hiera_b+.yaml",
        help="SAM 2 model configuration file",
    )
    parser.add_argument(
        "--sam2_checkpoint",
        type=str,
        default="./checkpoints/sam2_hiera_b+.pt",
        help="path to the SAM 2 model checkpoint",
    )
    parser.add_argument(
        "--base_video_dir",
        type=str,
        required=True,
        help="directory containing videos (as JPEG files) to run VOS prediction on",
    )
    parser.add_argument(
        "--input_mask_dir",
        type=str,
        required=True,
        help="directory containing input masks (as PNG files) of each video",
    )
    parser.add_argument(
        "--video_list_file",
        type=str,
        default=None,
        help="text file containing the list of video names to run VOS prediction on",
    )
    parser.add_argument(
        "--output_mask_dir",
        type=str,
        required=True,
        help="directory to save the output masks (as PNG files)",
    )
    parser.add_argument(
        "--score_thresh",
        type=float,
        default=0.0,
        help="threshold for the output mask logits (default: 0.0)",
    )
    parser.add_argument(
        "--use_all_masks",
        action="store_true",
        help="whether to use all available PNG files in input_mask_dir "
        "(default without this flag: just the first PNG file as input to the SAM 2 model; "
        "usually we don't need this flag, since semi-supervised VOS evaluation usually takes input from the first frame only)",
    )
    parser.add_argument(
        "--per_obj_png_file",
        action="store_true",
        help="whether use separate per-object PNG files for input and output masks "
        "(default without this flag: all object masks are packed into a single PNG file on each frame following DAVIS format; "
        "note that the SA-V dataset stores each object mask as an individual PNG file and requires this flag)",
    )
    parser.add_argument(
        "--apply_postprocessing",
        action="store_true",
        help="whether to apply postprocessing (e.g. hole-filling) to the output masks "
        "(we don't apply such post-processing in the SAM 2 model evaluation)",
    )
    args = parser.parse_args()

    # if we use per-object PNG files, they could possibly overlap in inputs and outputs
    hydra_overrides_extra = [
        "++model.non_overlap_masks=" + ("false" if args.per_obj_png_file else "true")
    ]
    predictor = build_sam2_video_predictor(
        config_file=args.sam2_cfg,
        ckpt_path=args.sam2_checkpoint,
        apply_postprocessing=args.apply_postprocessing,
        hydra_overrides_extra=hydra_overrides_extra,
    )

    if args.use_all_masks:
        print("using all available masks in input_mask_dir as input to the SAM 2 model")
    else:
        print(
            "using only the first frame's mask in input_mask_dir as input to the SAM 2 model"
        )

    # if a video list file is provided, read the video names from the file
    # (otherwise, we use all subdirectories in base_video_dir)
    if args.video_list_file is not None:
        with open(args.video_list_file, "r") as f:
            # Drop blank lines: an empty name would otherwise be joined onto
            # base_video_dir and treated as a video of its own.
            video_names = [name for name in (line.strip() for line in f) if name]
    else:
        video_names = [
            p
            for p in os.listdir(args.base_video_dir)
            if os.path.isdir(os.path.join(args.base_video_dir, p))
        ]
    print(f"running VOS prediction on {len(video_names)} videos:\n{video_names}")

    for n_video, video_name in enumerate(video_names):
        print(f"\n{n_video + 1}/{len(video_names)} - running on {video_name}")
        vos_inference(
            predictor=predictor,
            base_video_dir=args.base_video_dir,
            input_mask_dir=args.input_mask_dir,
            output_mask_dir=args.output_mask_dir,
            video_name=video_name,
            score_thresh=args.score_thresh,
            use_all_masks=args.use_all_masks,
            per_obj_png_file=args.per_obj_png_file,
        )

    print(
        f"completed VOS prediction on {len(video_names)} videos -- "
        f"output masks saved to {args.output_mask_dir}"
    )


if __name__ == "__main__":
    main()