# Spaces: Running on Zero
# (HuggingFace Spaces page header captured during export; kept as a
#  comment so this file remains valid Python.)
# Standard library
import argparse
import json
import os
import shutil
import sys
import time

# Third-party
import cv2
import gdown
import numpy as np
import psutil
import requests
import torch
import torchvision

# mmcv is an optional runtime dependency; try to install it on the fly if missing.
try:
    from mmcv.cnn import ConvModule
except ImportError:
    # NOTE(review): after `mim install`, ConvModule is still not imported in
    # this process — the script must be re-run for the import to take effect.
    os.system("mim install mmcv")

# Make the repository root importable so the local packages below resolve.
root_path = os.path.abspath('.')
sys.path.append(root_path)
from track_anything_code.model import TrackingAnything
from track_anything_code.track_anything_module import get_frames_from_video, download_checkpoint, parse_augment, sam_refine, vos_tracking_video
from scripts.compress_videos import compress_video
if __name__ == "__main__":
    dataset_path = "Bridge_v1_TT14"  # root folder holding one sub-folder per sequence
    video_name = "combined.mp4"      # expected video file name inside each sub-folder
    verbose = True                   # when True, also dump per-frame masks as PNGs

    ################################################## Model setup ####################################################
    # Check and download checkpoints if needed.
    sam_checkpoint_name = "sam_vit_h_4b8939.pth"
    sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    xmem_checkpoint_name = "XMem-s012.pth"
    xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
    folder = "./pretrained"
    sam_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint_name)
    xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint_name)

    # Parse arguments and force GPU execution.
    args = parse_augment()
    args.device = "cuda"  # any CUDA device is ok

    # Initialize the tracking model (SAM for segmentation + XMem for propagation).
    track_model = TrackingAnything(sam_checkpoint, xmem_checkpoint, args)
    ###################################################################################################################

    # Iterate over all sequence folders under the dataset root.
    for sub_folder_name in sorted(os.listdir(dataset_path)):
        sub_folder_path = os.path.join(dataset_path, sub_folder_name)
        # Skip stray files (e.g. .DS_Store) that are not sequence folders.
        if not os.path.isdir(sub_folder_path):
            continue

        ################################################## Setting ####################################################
        click_state = [[], []]
        interactive_state = {
            "inference_times": 0,
            "negative_click_times": 0,
            "positive_click_times": 0,
            "mask_save": args.mask_save,
            "multi_mask": {
                "mask_names": [],
                "masks": []
            },
            "track_end_number": None,
            "resize_ratio": 1
        }
        ###############################################################################################################

        video_path = os.path.join(sub_folder_path, video_name)
        if not os.path.exists(video_path):
            print("We cannot find the path of the ", video_path, " and we will compress one")
            status = compress_video(sub_folder_path, video_name)
            if not status:
                print("We still cannot generate a video")
                continue

        # Fresh per-video state.
        video_state = {
            "user_name": "",
            "video_name": "",
            "origin_images": None,
            "painted_images": None,
            "masks": None,
            "inpaint_masks": None,
            "logits": None,
            "select_frame_number": 0,
            "fps": 30
        }
        video_state, template_frame = get_frames_from_video(video_path, video_state, track_model)

        ########################################################## Get the sam point based on the data.txt ###########################################################
        data_txt_path = os.path.join(sub_folder_path, "data.txt")
        if not os.path.exists(data_txt_path):
            print("We cannot find data.txt in this folder")
            continue
        # Only the first annotated point is used. `with` guarantees the file is
        # closed; strip() handles both \n and \r\n line endings (the original
        # [:-2] slice assumed a fixed 2-character terminator).
        with open(data_txt_path, 'r') as data_file:
            first_line = data_file.readline()
        frame_idx, horizontal, vertical = first_line.strip().split(' ')
        point_cord = [int(float(horizontal)), int(float(vertical))]

        # Process by SAM.
        track_model.samcontroler.sam_controler.reset_image()  # reset the image to clean history
        painted_image, video_state, interactive_state, operation_log = sam_refine(
            track_model, video_state, "Positive", click_state, interactive_state, point_cord)
        ################################################################################################################################################################

        ######################################################### Get the tracking output ########################################################################
        # Track the video for processing; mask_dropdown is empty for now.
        segment_output_path = os.path.join(sub_folder_path, "segment_output.gif")
        video_state = vos_tracking_video(track_model, segment_output_path, video_state,
                                         interactive_state, mask_dropdown=[])[0]

        # Extract the masks (values in [0, 1]) needed for further point calculation.
        masks = video_state["masks"]
        if verbose:
            for idx, mask in enumerate(masks):
                cv2.imwrite(os.path.join(sub_folder_path, "mask" + str(idx) + ".png"), mask * 255)
        ##############################################################################################################################################################