# Author: shin-mashita | commit bb83661 ("Minor edits") | 2.78 kB
import functools

import torch
import cv2
import videotransforms
import numpy as np
import gradio as gr
from einops import rearrange
from torchvision import transforms
from pytorch_i3d import InceptionI3d
def preprocess(vidpath):
    """Decode a video into a single-clip tensor for the I3D model.

    Each frame is read with OpenCV. If its shorter side is below 226 px, the
    frame is upscaled with the aspect ratio preserved. Pixel values are then
    mapped from [0, 255] to [-1, 1]. The stacked clip is center-cropped with
    ``videotransforms.CenterCrop(224)``. Channels are kept in the order
    OpenCV returns them; no colour conversion is applied here.

    Args:
        vidpath: Path (or other source) accepted by ``cv2.VideoCapture``.

    Returns:
        A float32 tensor laid out as ``(1, c, t, h, w)``.

    Raises:
        ValueError: If no frame could be decoded from ``vidpath``.
    """
    cap = cv2.VideoCapture(vidpath)
    frames = []
    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
        # CAP_PROP_FRAME_COUNT is only an estimate for many containers, so
        # decode until read() reports failure instead of trusting it. A
        # count-driven loop crashes on ``None.shape`` when it overestimates.
        while True:
            ok, img = cap.read()
            if not ok or img is None:
                break
            h, w = img.shape[:2]  # OpenCV images are (rows, cols, channels)
            short_side = min(h, w)
            if short_side < 226:
                # Scale up so the shorter side reaches 226 px, which keeps
                # the 224 px center crop below inside the frame.
                scale = 1 + (226. - short_side) / short_side
                img = cv2.resize(img, dsize=(0, 0), fx=scale, fy=scale)
            frames.append((img / 255.) * 2 - 1)
    finally:
        # Always free the decoder handle, even if a frame fails to process.
        cap.release()

    if not frames:
        raise ValueError(f"No frames could be decoded from {vidpath!r}")

    clip = torch.from_numpy(np.asarray(frames, dtype=np.float32))
    clip = videotransforms.CenterCrop(224)(clip)
    return rearrange(clip, 't h w c -> 1 c t h w')
# Fine-tuned checkpoints selectable from the UI. For each dataset this gives
# the number of output classes and the weights file to load.
_CHECKPOINTS = {
    'WLASL100': {'logits': 100, 'path': 'weights/asl100/FINAL_nslt_100_iters=896_top1=65.89_top5=84.11_top10=89.92.pt'},
    'WLASL2000': {'logits': 2000, 'path': 'weights/asl2000/FINAL_nslt_2000_iters=5104_top1=32.48_top5=57.31_top10=66.31.pt'},
}


@functools.lru_cache(maxsize=None)
def _load_model(dataset):
    """Build (once) and return the eval-mode I3D model for ``dataset``.

    Only keys of ``_CHECKPOINTS`` succeed. Unknown names raise KeyError,
    which lru_cache does not memoise, so the cache stays bounded.
    """
    spec = _CHECKPOINTS[dataset]
    model = InceptionI3d()
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
    # The model is kept on the CPU.
    model.load_state_dict(torch.load('weights/rgb_imagenet.pt', map_location='cpu'))
    model.replace_logits(spec['logits'])
    model.load_state_dict(torch.load(spec['path'], map_location='cpu'))
    model.eval()
    return model


@functools.lru_cache(maxsize=None)
def _load_labels(path='wlasl_class_list.txt'):
    """Read (once) the ``<index> <gloss>`` class list into an index->gloss dict."""
    idx2label = {}
    with open(path) as f:
        for line in f:
            parts = line.split(maxsplit=1)
            if len(parts) < 2:
                continue  # skip blank or malformed lines instead of crashing
            # maxsplit=1 keeps multi-word glosses intact.
            idx2label[int(parts[0])] = parts[1].strip()
    return idx2label


def classify(video, dataset='WLASL100'):
    """Predict sign-language glosses for a video.

    Args:
        video: Video source handed to ``preprocess``. Presumably a file path
            supplied by the Gradio Video input; confirm for the pinned version.
        dataset: Checkpoint key, ``'WLASL100'`` or ``'WLASL2000'``.

    Returns:
        Dict mapping up to 10 gloss labels to their softmax probability.
        The probabilities are computed from the logits averaged over the
        model's last output axis.

    Raises:
        KeyError: If ``dataset`` is not a known checkpoint.
    """
    clip = preprocess(video)
    model = _load_model(dataset)  # cached: weights are read from disk once
    with torch.no_grad():
        per_frame_logits = model(clip)
    # (1, classes, k) -> (classes, k), then average over k. k is presumably
    # the temporal axis ("per frame"); verify against pytorch_i3d.
    logits = rearrange(per_frame_logits, '1 j k -> j k').mean(dim=1)
    probs = torch.nn.functional.softmax(logits, dim=0)
    # Softmax is monotonic, so ranking logits equals ranking probabilities.
    top_idx = torch.topk(logits, min(10, logits.numel())).indices.tolist()
    idx2label = _load_labels()
    return {idx2label[i]: float(probs[i]) for i in top_idx}
title = "I3D Sign Language Recognition"
description = "Description here"

# Demo inputs shown under the UI: each row is [video file, dataset choice].
examples = [
    ['videos/no.mp4', 'WLASL100'],
    ['videos/all.mp4', 'WLASL100'],
    ['videos/blue.mp4', 'WLASL2000'],
    ['videos/white.mp4', 'WLASL2000'],
    ['videos/accident.mp4', 'WLASL2000'],
]

# Wrap classify() in a web UI (video + dataset dropdown in, top-5 label
# chart out) and start serving it immediately.
# NOTE(review): newer Gradio releases take cache_examples on Interface(),
# not launch(). Confirm against the pinned Gradio version.
gr.Interface(
    fn=classify,
    inputs=[
        gr.inputs.Video(label="VIDEO"),
        gr.inputs.Dropdown(
            choices=['WLASL100', 'WLASL2000'],
            default='WLASL100',
            label='DATASET USED',
        ),
    ],
    outputs=[
        gr.outputs.Label(num_top_classes=5, label='Top 5 Predictions'),
    ],
    title=title,
    description=description,
    examples=examples,
    allow_flagging="never",
).launch(cache_examples=True)