UniVTG / run_on_video /video_extractor.py
KevinQHLin's picture
Update run_on_video/video_extractor.py
e9f28ca
import pdb
import torch as th
import math
import numpy as np
import torch
from run_on_video.video_loader import VideoLoader
from torch.utils.data import DataLoader
import argparse
from run_on_video.preprocessing import Preprocessing
import torch.nn.functional as F
from tqdm import tqdm
import os
import sys
from run_on_video import clip
import argparse
#################################
@torch.no_grad()
def vid2clip(model, vid_path, output_file,
model_version="ViT-B/32", output_feat_size=512,
clip_len=2, overwrite=True, num_decoding_thread=4, half_precision=False):
dataset = VideoLoader(
vid_path,
framerate=1/clip_len,
size=224,
centercrop=True,
overwrite=overwrite,
model_version=model_version
)
n_dataset = len(dataset)
loader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
num_workers=num_decoding_thread,
sampler=None,
)
preprocess = Preprocessing()
device_id = next(model.parameters()).device
totatl_num_frames = 0
with th.no_grad():
for k, data in enumerate(tqdm(loader)):
input_file = data['input'][0]
if os.path.isfile(output_file):
# print(f'Video {input_file} already processed.')
continue
elif not os.path.isfile(input_file):
print(f'{input_file}, does not exist.\n')
elif len(data['video'].shape) > 4:
video = data['video'].squeeze(0)
if len(video.shape) == 4:
video = preprocess(video)
n_chunk = len(video)
vid_features = th.cuda.FloatTensor(
n_chunk, output_feat_size).fill_(0)
n_iter = int(math.ceil(n_chunk))
for i in range(n_iter):
min_ind = i
max_ind = (i + 1)
video_batch = video[min_ind:max_ind].to(device_id)
batch_features = model.encode_image(video_batch)
vid_features[min_ind:max_ind] = batch_features
vid_features = vid_features.cpu().numpy()
if half_precision:
vid_features = vid_features.astype('float16')
totatl_num_frames += vid_features.shape[0]
# safeguard output path before saving
dirname = os.path.dirname(output_file)
# if not os.path.exists(dirname):
# print(f"Output directory {dirname} does not exists, creating...")
os.makedirs(output_file, exist_ok=True)
np.savez(os.path.join(output_file, 'vid.npz'), features=vid_features)
else:
print(f'{input_file}, failed at ffprobe.\n')
print(f"Total number of frames: {totatl_num_frames}")
return vid_features
def txt2clip(model, text, output_file):
device_id = next(model.parameters()).device
encoded_texts = clip.tokenize(text).to(device_id)
text_feature = model.encode_text(encoded_texts)['last_hidden_state']
valid_lengths = (encoded_texts != 0).sum(1).tolist()[0]
text_feature = text_feature[0, :valid_lengths].detach().cpu().numpy()
np.savez(os.path.join(output_file, 'txt.npz'), features=text_feature)
return text_feature
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='')
parser.add_argument('--vid_path', type=str, default='/data/home/qinghonglin/dataset/charades/videos/Charades_v1_480/0A8CF.mp4')
parser.add_argument('--text', nargs='+', type=str, default='a boy is drinking.')
parser.add_argument('--save_dir', type=str, default='./tmp')
args = parser.parse_args()