### demo.py
# Define model classes for inference.
###
import json

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from einops import rearrange
from transformers import BertTokenizer
from torchvision import transforms
from torchvision.transforms._transforms_video import (
    NormalizeVideo,
)

from svitt.model import SViTT
from svitt.config import load_cfg, setup_config
from svitt.base_dataset import read_frames_cv2_egoclip

class VideoModel(nn.Module):
    """ Base model for video understanding based on the SViTT architecture. """
    def __init__(self, config):
        """ Initializes the model.
        Parameters:
            config: path to the config file
        """
        super().__init__()
        self.cfg = load_cfg(config)
        self.model = self.build_model()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.templates = ['{}']
        self.dataset = self.cfg['data']['dataset']
        self.eval()

    def build_model(self):
        cfg = self.cfg
        if cfg['model'].get('pretrain', False):
            ckpt_path = cfg['model']['pretrain']
        else:
            raise Exception('no checkpoint found')

        if cfg['model'].get('config', False):
            config_path = cfg['model']['config']
        else:
            raise Exception('no model config found')

        self.model_cfg = setup_config(config_path)
        self.tokenizer = BertTokenizer.from_pretrained(self.model_cfg.text_encoder)
        model = SViTT(config=self.model_cfg, tokenizer=self.tokenizer)

        print(f"Loading checkpoint from {ckpt_path}")
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        state_dict = checkpoint["model"]

        # fix for zero-shot evaluation: duplicate "bert."-prefixed weights
        # under the key names the text encoder expects
        for key in list(state_dict.keys()):
            if "bert" in key:
                encoder_key = key.replace("bert.", "")
                state_dict[encoder_key] = state_dict[key]

        if torch.cuda.is_available():
            model.cuda()
        model.load_state_dict(state_dict, strict=False)
        return model

    def eval(self):
        # Freeze all parameters and put the backbone into inference mode.
        cudnn.benchmark = True
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()

class VideoCLSModel(VideoModel):
    """ Video model for video classification tasks (Charades-Ego, EGTEA). """
    def __init__(self, config, sample_videos):
        super().__init__(config)
        self.sample_videos = sample_videos
        self.video_transform = self.init_video_transform()

    #def load_data(self, idx=None):
    #    filename = f"{self.cfg['data']['root']}/{idx}/tensors.pt"
    #    return torch.load(filename)

    def init_video_transform(self,
                             input_res=224,
                             center_crop=256,
                             norm_mean=(0.485, 0.456, 0.406),
                             norm_std=(0.229, 0.224, 0.225),
                             ):
        print('Video Transform is used!')
        normalize = NormalizeVideo(mean=norm_mean, std=norm_std)
        return transforms.Compose(
            [
                transforms.Resize(center_crop),
                transforms.CenterCrop(center_crop),
                transforms.Resize(input_res),
                normalize,
            ]
        )

    def load_data(self, idx):
        num_frames = self.model_cfg.video_input.num_frames
        video_paths = self.sample_videos[idx]
        clips = [None] * len(video_paths)
        for i, path in enumerate(video_paths):
            # Sample frames uniformly, apply the video transform in
            # (C, T, H, W) layout, then switch back to (T, C, H, W).
            imgs = read_frames_cv2_egoclip(path, num_frames, 'uniform')
            imgs = imgs.transpose(0, 1)
            imgs = self.video_transform(imgs)
            imgs = imgs.transpose(0, 1)
            clips[i] = imgs
        return torch.stack(clips)

    def load_meta(self, idx=None):
        filename = f"{self.cfg['data']['root']}/{idx}/meta.json"
        with open(filename, "r") as f:
            meta = json.load(f)
        return meta

    def get_text_features(self, text):
        print('=> Extracting text features')
        embeddings = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.model_cfg.max_txt_l.video,
            return_tensors="pt",
        ).to(self.device)
        _, class_embeddings = self.model.encode_text(embeddings)
        return class_embeddings

    def forward(self, idx, text=None):
        print('=> Start forwarding')
        meta = self.load_meta(idx)
        clips = self.load_data(idx)
        if text is None:
            text = meta["text"][4:]
        text_features = self.get_text_features(text)
        target = meta["correct"]

        # encode images
        pooled_image_feat_all = []
        for i in range(clips.shape[0]):
            images = clips[i, :].unsqueeze(0).to(self.device)
            bsz = images.shape[0]
            _, pooled_image_feat, *outputs = self.model.encode_image(images)
            if pooled_image_feat.ndim == 3:
                pooled_image_feat = rearrange(pooled_image_feat, '(b k) n d -> b (k n) d', b=bsz)
            else:
                pooled_image_feat = rearrange(pooled_image_feat, '(b k) d -> b k d', b=bsz)
            pooled_image_feat_all.append(pooled_image_feat)

        pooled_image_feat_all = torch.cat(pooled_image_feat_all, dim=0)
        similarity = self.model.get_sim(pooled_image_feat_all, text_features)[0]
        return similarity.argmax(), target

    def predict(self, idx, text=None):
        output, target = self.forward(idx, text)
        return output.cpu().numpy(), target
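

# Minimal usage sketch (not part of the original demo): how these classes might
# be wired up for a single zero-shot prediction. The config path and sample
# video paths below are hypothetical placeholders; presumably the Space's app
# code constructs `sample_videos` and calls `predict` with its own inputs.
if __name__ == "__main__":
    example_config = "configs/egomcq.yml"            # assumed config file
    example_videos = {0: ["samples/0/clip_0.mp4",    # assumed candidate clips
                          "samples/0/clip_1.mp4"]}

    model = VideoCLSModel(example_config, example_videos)
    pred, target = model.predict(0)
    print(f"predicted option: {pred}, ground truth: {target}")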