IbrahimHasani committed on
Commit
56de2d4
1 Parent(s): 3201059

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -0
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from transformers import AutoProcessor, AutoModel
5
+ from PIL import Image
6
+ from decord import VideoReader, cpu
7
+
8
def sample_uniform_frame_indices(clip_len, seg_len):
    """
    Return `clip_len` frame indices sampled from a video of `seg_len` frames.

    When the video has at least `clip_len` frames, the indices are spread
    evenly over the whole video: position ``i`` maps to frame
    ``i * seg_len // clip_len``. When the video is shorter than `clip_len`,
    its frames are cycled (0, 1, ..., seg_len - 1, 0, 1, ...) until
    `clip_len` indices have been produced.

    Args:
        clip_len: Number of indices to return.
        seg_len: Number of frames available in the video.

    Returns:
        A 1-D ``np.int64`` array of length `clip_len` whose values all lie
        in ``[0, seg_len)``.

    Raises:
        ValueError: If `seg_len` is not positive, because there is no frame
            to sample from.
    """
    if seg_len <= 0:
        raise ValueError(f"seg_len must be positive, got {seg_len}.")

    positions = np.arange(clip_len, dtype=np.int64)
    if seg_len < clip_len:
        # Not enough frames: loop over the video until the clip is full.
        indices = positions % seg_len
    else:
        # Multiply before dividing so the samples span the whole video.
        # A pre-truncated step (seg_len // clip_len) can stop about halfway
        # through, e.g. 63 frames / 32 samples would only reach frame 31.
        indices = positions * seg_len // clip_len

    return indices.astype(np.int64)
22
+
23
def read_video_decord(file_path, indices):
    """
    Decode the frames at `indices` from the video stored at `file_path`.

    The video is opened with decord on the CPU using a single decoding
    thread, and the selected frames are returned together as one NumPy array.
    """
    reader = VideoReader(file_path, num_threads=1, ctx=cpu(0))
    return reader.get_batch(indices).asnumpy()
27
+
28
def concatenate_frames(frames, clip_len):
    """
    Tile `clip_len` equally sized frames into a single grid image.

    Frames are pasted left to right, top to bottom. Clips of 32, 16 and 8
    frames use 4x8, 4x4 and 2x4 grids (rows x columns). Any other length uses
    the most square grid that the frames fill exactly.

    Args:
        frames: Sequence of `clip_len` image arrays accepted by
            ``Image.fromarray``. The height and width of the first frame
            determine the size of every grid cell.
        clip_len: Expected number of frames.

    Returns:
        An RGB PIL image containing all frames arranged in the grid.

    Raises:
        ValueError: If `clip_len` is not positive, or if `frames` does not
            contain exactly `clip_len` items.
    """
    if clip_len < 1:
        raise ValueError(f"clip_len must be positive, got {clip_len}.")
    # Explicit check rather than `assert`, which is skipped under `python -O`.
    if len(frames) != clip_len:
        raise ValueError(f"The function expects {clip_len} frames as input.")

    known_layouts = {
        32: (4, 8),
        16: (4, 4),
        8: (2, 4),
    }
    if clip_len in known_layouts:
        rows, cols = known_layouts[clip_len]
    else:
        # Largest divisor not exceeding sqrt(clip_len): the grid is filled
        # exactly and is as close to square as possible (1 x n for primes).
        rows = next(
            r for r in range(int(clip_len ** 0.5), 0, -1) if clip_len % r == 0
        )
        cols = clip_len // rows

    frame_height, frame_width = frames[0].shape[:2]
    combined_image = Image.new('RGB', (frame_width * cols, frame_height * rows))
    for position, frame in enumerate(frames):
        row, col = divmod(position, cols)
        combined_image.paste(
            Image.fromarray(frame),
            (col * frame_width, row * frame_height),
        )

    return combined_image
50
+
51
+
52
def model_interface(uploaded_video, model_choice, activities):
    """
    Score an uploaded video against a list of activity labels.

    Frames are sampled from the video, tiled into a preview image, and passed
    together with the labels to the processor and model loaded from
    `model_choice`. The per-label logits of the video are turned into
    probabilities with a softmax.

    Args:
        uploaded_video: Path of the video file to analyse.
        model_choice: Model identifier passed to ``from_pretrained``. It also
            selects how many frames are sampled (32, 16 or 8 for the three
            listed checkpoints, 32 for any other identifier).
        activities: Comma-separated activity labels.

    Returns:
        A 4-tuple of:
        the preview image of the sampled frames,
        a list of ``(label, "Probability: ..%")`` pairs,
        a list of ``(label, "Raw Score: ..")`` pairs, and
        the ``(label, "Probability: ..%")`` pair with the highest probability.

    Raises:
        ValueError: If the video, the model choice, or every activity label
            is missing.
    """
    if not uploaded_video:
        raise ValueError("Please upload a video file.")
    if not model_choice:
        raise ValueError("Please select a model.")

    # Trim whitespace and drop empty entries so input such as "run, jump,"
    # yields the labels "run" and "jump" rather than " jump" and "".
    activities_list = [
        label.strip()
        for label in (activities or "").split(",")
        if label.strip()
    ]
    if not activities_list:
        raise ValueError("Please enter at least one activity.")

    clip_len = {
        "microsoft/xclip-base-patch16-zero-shot": 32,
        "microsoft/xclip-base-patch32-16-frames": 16,
        "microsoft/xclip-base-patch32": 8
    }.get(model_choice, 32)

    indices = sample_uniform_frame_indices(clip_len, seg_len=len(VideoReader(uploaded_video)))
    video = read_video_decord(uploaded_video, indices)
    concatenated_image = concatenate_frames(video, clip_len)

    processor = AutoProcessor.from_pretrained(model_choice)
    model = AutoModel.from_pretrained(model_choice)

    inputs = processor(
        text=activities_list,
        videos=list(video),
        return_tensors="pt",
        padding=True,
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)

    logits_per_video = outputs.logits_per_video
    probs = logits_per_video.softmax(dim=1)

    # Only one video is submitted, so row 0 holds all of its scores.
    video_probs = probs[0].tolist()
    video_logits = logits_per_video[0].tolist()

    results_probs = [
        (activity, f"Probability: {prob * 100:.2f}%")
        for activity, prob in zip(activities_list, video_probs)
    ]
    results_logits = [
        (activity, f"Raw Score: {logit:.2f}")
        for activity, logit in zip(activities_list, video_logits)
    ]

    # Retrieve most likely predicted label and its probability
    max_prob_idx = probs[0].argmax().item()
    most_likely_activity = activities_list[max_prob_idx]
    most_likely_prob = video_probs[max_prob_idx]

    return concatenated_image, results_probs, results_logits, (most_likely_activity, f"Probability: {most_likely_prob * 100:.2f}%")
95
+
96
# Gradio UI: a video upload, a model selector and a free-text list of
# activities go in; the sampled-frame preview and three text summaries
# produced by `model_interface` come out.
iface = gr.Interface(
    fn=model_interface,
    inputs=[
        gr.components.Video(label="Upload a video file"),
        gr.components.Dropdown(
            choices=[
                "microsoft/xclip-base-patch16-zero-shot",
                "microsoft/xclip-base-patch32-16-frames",
                "microsoft/xclip-base-patch32"
            ],
            # Preselect the first checkpoint so that a model is already
            # chosen when the form is submitted without touching the dropdown.
            value="microsoft/xclip-base-patch16-zero-shot",
            label="Model Choice",
        ),
        gr.components.Textbox(lines=4, label="Enter activities (comma-separated)"),
    ],
    outputs=[
        gr.components.Image(type="pil", label="sampled frames"),
        gr.components.Textbox(type="text", label="Probabilities"),
        gr.components.Textbox(type="text", label="Raw Scores"),
        gr.components.Textbox(type="text", label="Most Likely Prediction")
    ],
    live=False
)

iface.launch()