File size: 1,758 Bytes
8323f05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b78d94b
8323f05
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import functools

from decord import VideoReader, cpu
import torch
import numpy as np

from transformers import VideoMAEFeatureExtractor, VideoMAEForVideoClassification
from huggingface_hub import hf_hub_download
import gradio as gr

# Seed NumPy's global RNG once at import time. sample_frame_indices draws its
# clip window with np.random.randint, so the sequence of windows chosen is the
# same on every fresh start of the app.
np.random.seed(0)

def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    """Pick ``clip_len`` evenly spaced frame indices from a random window of a video.

    The window spans ``clip_len * frame_sample_rate`` frames, and its end is drawn
    from NumPy's global RNG. If the video is not longer than that window, the
    whole video is used instead. Frames may repeat when ``seg_len < clip_len``.

    Args:
        clip_len: Number of frame indices to return.
        frame_sample_rate: Nominal stride between sampled frames.
        seg_len: Total number of frames in the video.

    Returns:
        ``np.ndarray`` of ``int64`` indices of length ``clip_len``, non-decreasing,
        each in ``[0, seg_len - 1]``.

    Raises:
        ValueError: If ``seg_len`` is not positive.
    """
    if seg_len <= 0:
        raise ValueError(f"seg_len must be positive, got {seg_len}")
    converted_len = int(clip_len * frame_sample_rate)
    if seg_len > converted_len:
        # randint's upper bound is exclusive, so end_idx <= seg_len - 1.
        end_idx = np.random.randint(converted_len, seg_len)
        start_idx = end_idx - converted_len
    else:
        # randint(low, high) raises when low >= high (short video); use it all.
        start_idx, end_idx = 0, seg_len
    indices = np.linspace(start_idx, end_idx, num=clip_len)
    # linspace includes end_idx itself; clip keeps every index strictly before it.
    indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
    return indices


_CHECKPOINT = "MCG-NJU/videomae-base-finetuned-kinetics"


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the VideoMAE feature extractor and classifier once and reuse them.

    ``from_pretrained`` is expensive, so the pair is built on the first request
    and every later call returns the same cached objects.
    """
    feature_extractor = VideoMAEFeatureExtractor.from_pretrained(_CHECKPOINT)
    model = VideoMAEForVideoClassification.from_pretrained(_CHECKPOINT)
    return feature_extractor, model


def inference(file_path):
    """Classify a video file and return the predicted label name.

    Sixteen frames are sampled with ``sample_frame_indices`` (stride 4), run
    through the VideoMAE feature extractor and classifier, and the arg-max
    logit is mapped to a name via ``model.config.id2label``.

    Args:
        file_path: Path to the video file, as supplied by the ``gr.Video`` input.

    Returns:
        The label string for the highest-scoring class.

    Raises:
        ValueError: If no video was provided (``file_path`` is empty or None).
    """
    if not file_path:
        # Run was clicked before any upload; fail with a readable message.
        raise ValueError("Please upload a video before clicking Run.")

    videoreader = VideoReader(file_path, num_threads=1, ctx=cpu(0))
    videoreader.seek(0)

    # Sample 16 frame indices spread over a window of the video.
    indices = sample_frame_indices(clip_len=16, frame_sample_rate=4, seg_len=len(videoreader))
    # presumably (num_frames, H, W, 3) uint8 per decord's layout — TODO confirm
    frames = videoreader.get_batch(indices).asnumpy()

    feature_extractor, model = _load_model()
    inputs = feature_extractor(list(frames), return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_label = logits.argmax(-1).item()
    return model.config.id2label[predicted_label]
 
# Two-column layout: video upload and Run button on the left, predicted label on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            video = gr.Video()
            btn = gr.Button(value="Run")
        with gr.Column():
            label = gr.Textbox(label="Predicted Label")

    # Clicking Run hands the uploaded video to inference() and shows the string it returns.
    btn.click(inference, inputs=video, outputs=label)

# Start the server only when run as a script (e.g. `python app.py`), so importing
# this module from tests or tooling does not block on launching a web server.
if __name__ == "__main__":
    demo.launch()