|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import torch |
|
import argparse |
|
import imageio.v3 as iio |
|
import numpy as np |
|
|
|
from cotracker.utils.visualizer import Visualizer |
|
from cotracker.predictor import CoTrackerOnlinePredictor |
|
|
|
|
|
|
|
|
|
# Compute device chosen once at import time: prefer CUDA when available,
# otherwise fall back to CPU.
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
if __name__ == "__main__":
    # Command-line interface for the online CoTracker demo.
    parser = argparse.ArgumentParser()
    parser.add_argument("--video_path", default="./assets/apple.mp4", help="path to a video")
    parser.add_argument("--checkpoint", default=None, help="CoTracker model parameters")
    parser.add_argument("--grid_size", type=int, default=10, help="Regular grid size")
    parser.add_argument(
        "--grid_query_frame",
        type=int,
        default=0,
        help="Compute dense and grid tracks starting from this frame",
    )
    args = parser.parse_args()

    # Fail fast on a missing video before any (potentially slow) model setup.
    if not os.path.isfile(args.video_path):
        raise ValueError("Video file does not exist")

    # A local checkpoint takes precedence; otherwise fetch the published
    # online model from torch.hub.
    if args.checkpoint is None:
        model = torch.hub.load("facebookresearch/co-tracker", "cotracker2_online")
    else:
        model = CoTrackerOnlinePredictor(checkpoint=args.checkpoint)
    model = model.to(DEFAULT_DEVICE)

    # All frames seen so far; the tracker reads a sliding window off its tail.
    window_frames = []
|
|
|
def _process_step(window_frames, is_first_step, grid_size, grid_query_frame): |
|
video_chunk = ( |
|
torch.tensor(np.stack(window_frames[-model.step * 2 :]), device=DEFAULT_DEVICE) |
|
.float() |
|
.permute(0, 3, 1, 2)[None] |
|
) |
|
return model( |
|
video_chunk, |
|
is_first_step=is_first_step, |
|
grid_size=grid_size, |
|
grid_query_frame=grid_query_frame, |
|
) |
|
|
|
|
|
is_first_step = True |
|
for i, frame in enumerate( |
|
iio.imiter( |
|
args.video_path, |
|
plugin="FFMPEG", |
|
) |
|
): |
|
if i % model.step == 0 and i != 0: |
|
pred_tracks, pred_visibility = _process_step( |
|
window_frames, |
|
is_first_step, |
|
grid_size=args.grid_size, |
|
grid_query_frame=args.grid_query_frame, |
|
) |
|
is_first_step = False |
|
window_frames.append(frame) |
|
|
|
pred_tracks, pred_visibility = _process_step( |
|
window_frames[-(i % model.step) - model.step - 1 :], |
|
is_first_step, |
|
grid_size=args.grid_size, |
|
grid_query_frame=args.grid_query_frame, |
|
) |
|
|
|
print("Tracks are computed") |
|
|
|
|
|
seq_name = os.path.splitext(args.video_path.split("/")[-1])[0] |
|
video = torch.tensor(np.stack(window_frames), device=DEFAULT_DEVICE).permute(0, 3, 1, 2)[None] |
|
vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3) |
|
vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame, filename=seq_name) |
|
|