This-and-That / curation_pipeline /tracking_by_keypoint.py
HikariDawn777's picture
feat: initial push
59b2a81
import os, shutil, sys
import argparse
import gdown
import cv2
import numpy as np
import os
import sys
import requests
import json
import torchvision
import torch
import psutil
import time
try:
from mmcv.cnn import ConvModule
except:
os.system("mim install mmcv")
# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from track_anything_code.model import TrackingAnything
from track_anything_code.track_anything_module import get_frames_from_video, download_checkpoint, parse_augment, sam_refine, vos_tracking_video
from scripts.compress_videos import compress_video
if __name__ == "__main__":
dataset_path = "Bridge_v1_TT14"
video_name = "combined.mp4"
verbose = True # If this is verbose, you will continue to write the code
################################################## Model setup ####################################################
# check and download checkpoints if needed
sam_checkpoint = "sam_vit_h_4b8939.pth"
sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
xmem_checkpoint = "XMem-s012.pth"
xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
folder ="./pretrained"
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
# argument
args = parse_augment()
args.device = "cuda" # Any GPU is ok
# Initialize the Track model
track_model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
###################################################################################################################
# Iterate all files under the folder
for sub_folder_name in sorted(os.listdir(dataset_path)):
################################################## Setting ####################################################
sub_folder_path = os.path.join(dataset_path, sub_folder_name)
click_state = [[],[]]
interactive_state = {
"inference_times": 0,
"negative_click_times" : 0,
"positive_click_times": 0,
"mask_save": args.mask_save,
"multi_mask": {
"mask_names": [],
"masks": []
},
"track_end_number": None,
"resize_ratio": 1
}
###################################################################################################################
video_path = os.path.join(sub_folder_path, video_name)
if not os.path.exists(video_path):
print("We cannot find the path of the ", video_path, " and we will compress one")
status = compress_video(sub_folder_path, video_name)
if not status:
print("We still cannot generate a video")
continue
# Read video state
video_state = {
"user_name": "",
"video_name": "",
"origin_images": None,
"painted_images": None,
"masks": None,
"inpaint_masks": None,
"logits": None,
"select_frame_number": 0,
"fps": 30
}
video_state, template_frame = get_frames_from_video(video_path, video_state, track_model)
########################################################## Get the sam point based on the data.txt ###########################################################
data_txt_path = os.path.join(sub_folder_path, "data.txt")
if not os.path.exists(data_txt_path):
print("We cannot find data.txt in this folder")
continue
data_file = open(data_txt_path, 'r')
lines = data_file.readlines()
frame_idx, horizontal, vertical = lines[0][:-2].split(' ') # Only read the first point
point_cord = [int(float(horizontal)), int(float(vertical))]
# Process by SAM
track_model.samcontroler.sam_controler.reset_image() # Reset the image to clean history
painted_image, video_state, interactive_state, operation_log = sam_refine(track_model, video_state, "Positive", click_state, interactive_state, point_cord)
################################################################################################################################################################
######################################################### Get the tracking output ########################################################################
# Track the video for processing
segment_output_path = os.path.join(sub_folder_path, "segment_output.gif")
video_state = vos_tracking_video(track_model, segment_output_path, video_state, interactive_state, mask_dropdown=[])[0] # mask_dropdown is empty now
# Extract the mask needed by us for further point calculating
masks = video_state["masks"] # In the range [0, 1]
if verbose:
for idx, mask in enumerate(masks):
cv2.imwrite(os.path.join(sub_folder_path, "mask"+str(idx)+".png"), mask*255)
##############################################################################################################################################################