bwconrad committed
Commit f121c5e
Parent: d19ba33

Init commit

Files changed (2)
  1. app.py +253 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,253 @@
+ import altair as alt
+ import cv2
+ import gradio as gr
+ import numpy as np
+ import open_clip
+ import pandas as pd
+ import torch
+ from PIL import Image
+ from torch.utils.data import DataLoader, Dataset
+ from torchvision.transforms.functional import to_pil_image, to_tensor
+
+
+ def run(
+     path: str,
+     model_key: str,
+     text_search: str,
+     image_search: Image.Image,
+     thresh: float,
+     stride: int,
+     batch_size: int,
+     center_crop: bool,
+ ):
+
+     assert path, "An input video should be provided"
+     assert (
+         text_search is not None or image_search is not None
+     ), "A text or image query should be provided"
+
+     if torch.cuda.is_available():
+         device = torch.device("cuda")
+     else:
+         device = torch.device("cpu")
+
+     # Initialize model
+     name, weights = MODELS[model_key]
+     model, _, preprocess = open_clip.create_model_and_transforms(
+         name, pretrained=weights, device=device
+     )
+     model.eval()
+
+     # Remove center crop transform
+     if not center_crop:
+         del preprocess.transforms[1]
+
+     # Load video
+     dataset = LoadVideo(path, transforms=preprocess, vid_stride=stride)
+     dataloader = DataLoader(
+         dataset, batch_size=batch_size, shuffle=False, num_workers=0
+     )
+
+     # Get text query features
+     if text_search:
+         # Tokenize search phrase
+         tokenizer = open_clip.get_tokenizer(name)
+         text = tokenizer([text_search]).to(device)
+
+         # Encode text query
+         with torch.no_grad():
+             query_features = model.encode_text(text)
+             query_features /= query_features.norm(dim=-1, keepdim=True)
+
+     # Get image query features
+     else:
+         image = preprocess(image_search).unsqueeze(0).to(device)
+         with torch.no_grad():
+             query_features = model.encode_image(image)
+             query_features /= query_features.norm(dim=-1, keepdim=True)
+
+     # Encode each frame and compare with query features
+     matches = []
+     res = pd.DataFrame(columns=["Frame", "Timestamp", "Similarity"])
+     for image, orig, frame, timestamp in dataloader:
+         with torch.no_grad():
+             image = image.to(device)
+             image_features = model.encode_image(image)
+
+         image_features /= image_features.norm(dim=-1, keepdim=True)
+         probs = query_features.cpu().numpy() @ image_features.cpu().numpy().T
+         probs = probs[0]
+
+         # Save frame similarity values
+         df = pd.DataFrame(
+             {
+                 "Frame": frame.tolist(),
+                 "Timestamp": torch.round(timestamp / 1000, decimals=2).tolist(),
+                 "Similarity": probs.tolist(),
+             }
+         )
+         res = pd.concat([res, df])
+
+         # Check if frame is over threshold
+         for i, p in enumerate(probs):
+             if p > thresh:
+                 matches.append(to_pil_image(orig[i]))
+
+         print(f"Frames: {frame.tolist()} - Probs: {probs}")
+
+     # Create plot of similarity values
+     lines = (
+         alt.Chart(res)
+         .mark_line(color="firebrick")
+         .encode(
+             alt.X("Timestamp", title="Timestamp (seconds)"),
+             alt.Y("Similarity", scale=alt.Scale(zero=False)),
+         )
+     ).properties(width=600)
+     rule = alt.Chart().mark_rule(strokeDash=[6, 3], size=2).encode(y=alt.datum(thresh))
+
+     return matches[:30], lines + rule
+
+
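Note (not part of the diff): because both `query_features` and the per-frame `image_features` are L2-normalized before the dot product, the `probs` values compared against `thresh` are cosine similarities in [-1, 1]. A minimal sketch of the same computation on dummy tensors (the 512-dimensional shape is illustrative, not taken from the commit):

import torch

# Stand-ins for a CLIP query embedding and a batch of 4 frame embeddings.
query = torch.randn(1, 512)
frames = torch.randn(4, 512)

# Normalize to unit length, as run() does.
query = query / query.norm(dim=-1, keepdim=True)
frames = frames / frames.norm(dim=-1, keepdim=True)

# Dot products of unit vectors are cosine similarities, one value per frame.
sims = (query.numpy() @ frames.numpy().T)[0]
print(sims)  # values in [-1, 1]; frames with sims > thresh would be kept as matches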
+ class LoadVideo(Dataset):
+     def __init__(self, path, transforms, vid_stride=1):
+
+         self.transforms = transforms
+         self.vid_stride = vid_stride
+         self.cur_frame = 0
+         self.cap = cv2.VideoCapture(path)
+         self.total_frames = int(
+             self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride
+         )
+
+     def __getitem__(self, _):
+         # Read video
+         # Skip over frames
+         for _ in range(self.vid_stride):
+             self.cap.grab()
+             self.cur_frame += 1
+
+         # Read frame
+         _, img = self.cap.retrieve()
+         timestamp = self.cap.get(cv2.CAP_PROP_POS_MSEC)
+
+         # Convert to PIL
+         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+         img = Image.fromarray(np.uint8(img))
+
+         # Apply transforms
+         img_t = self.transforms(img)
+
+         return img_t, to_tensor(img), self.cur_frame, timestamp
+
+     def __len__(self):
+         return self.total_frames
+
+
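For orientation (not part of the commit), a sketch of how `LoadVideo` is consumed through a `DataLoader`, which batches its four return values; the video path "example.mp4" is a placeholder and the preprocess transform is rebuilt here the same way `run` does:

# Hypothetical standalone use of LoadVideo; "example.mp4" is a placeholder path.
_, _, preprocess = open_clip.create_model_and_transforms(
    "convnext_base_w", pretrained="laion2b_s13b_b82k"
)
loader = DataLoader(
    LoadVideo("example.mp4", transforms=preprocess, vid_stride=4),
    batch_size=4,
    shuffle=False,
)
for img_t, orig, frame, timestamp in loader:
    # img_t: preprocessed frames for the CLIP image encoder
    # orig: raw RGB frames as tensors (used to build the match gallery)
    # frame: frame indices; timestamp: position in milliseconds (CAP_PROP_POS_MSEC)
    break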
+ MODELS = {
+     "convnext_base - laion400m_s13b_b51k": ("convnext_base", "laion400m_s13b_b51k"),
+     "convnext_base_w - laion2b_s13b_b82k": (
+         "convnext_base_w",
+         "laion2b_s13b_b82k",
+     ),
+     "convnext_base_w - laion2b_s13b_b82k_augreg": (
+         "convnext_base_w",
+         "laion2b_s13b_b82k_augreg",
+     ),
+     "convnext_base_w - laion_aesthetic_s13b_b82k": (
+         "convnext_base_w",
+         "laion_aesthetic_s13b_b82k",
+     ),
+     "convnext_base_w_320 - laion_aesthetic_s13b_b82k": (
+         "convnext_base_w_320",
+         "laion_aesthetic_s13b_b82k",
+     ),
+     "convnext_base_w_320 - laion_aesthetic_s13b_b82k_augreg": (
+         "convnext_base_w_320",
+         "laion_aesthetic_s13b_b82k_augreg",
+     ),
+     "convnext_large_d - laion2b_s26b_b102k_augreg": (
+         "convnext_large_d",
+         "laion2b_s26b_b102k_augreg",
+     ),
+     "convnext_large_d_320 - laion2b_s29b_b131k_ft": (
+         "convnext_large_d_320",
+         "laion2b_s29b_b131k_ft",
+     ),
+     "convnext_large_d_320 - laion2b_s29b_b131k_ft_soup": (
+         "convnext_large_d_320",
+         "laion2b_s29b_b131k_ft_soup",
+     ),
+     "convnext_xxlarge - laion2b_s34b_b82k_augreg": (
+         "convnext_xxlarge",
+         "laion2b_s34b_b82k_augreg",
+     ),
+     "convnext_xxlarge - laion2b_s34b_b82k_augreg_rewind": (
+         "convnext_xxlarge",
+         "laion2b_s34b_b82k_augreg_rewind",
+     ),
+     "convnext_xxlarge - laion2b_s34b_b82k_augreg_soup": (
+         "convnext_xxlarge",
+         "laion2b_s34b_b82k_augreg_soup",
+     ),
+ }
+
+
+ if __name__ == "__main__":
+     text_app = gr.Interface(
+         description="Search the contents of a video with a text description.",
+         fn=run,
+         inputs=[
+             gr.Video(label="Video"),
+             gr.Dropdown(
+                 label="Model",
+                 choices=list(MODELS.keys()),
+                 value="convnext_base_w - laion2b_s13b_b82k",
+             ),
+             gr.Textbox(label="Text Search Query"),
+             gr.Image(label="Image Search Query", visible=False),
+             gr.Slider(label="Threshold", maximum=1.0, value=0.3),
+             gr.Slider(label="Frame-rate Stride", value=4, step=1),
+             gr.Slider(label="Batch Size", value=4, step=1),
+             gr.Checkbox(label="Center Crop"),
+         ],
+         outputs=[
+             gr.Gallery(label="Matched Frames").style(
+                 columns=2, object_fit="contain", height="auto"
+             ),
+             gr.Plot(label="Similarity Plot"),
+         ],
+         allow_flagging="never",
+     )
+
+     image_app = gr.Interface(
+         description="Search the contents of a video with an image query.",
+         fn=run,
+         inputs=[
+             gr.Video(label="Video"),
+             gr.Dropdown(
+                 label="Model",
+                 choices=list(MODELS.keys()),
+                 value="convnext_base_w - laion2b_s13b_b82k",
+             ),
+             gr.Textbox(label="Text Search Query", visible=False),
+             gr.Image(label="Image Search Query", type="pil"),
+             gr.Slider(label="Threshold", maximum=1.0, value=0.3),
+             gr.Slider(label="Frame-rate Stride", value=4, step=1),
+             gr.Slider(label="Batch Size", value=4, step=1),
+             gr.Checkbox(label="Center Crop"),
+         ],
+         outputs=[
+             gr.Gallery(label="Matched Frames").style(
+                 columns=2, object_fit="contain", height="auto"
+             ),
+             gr.Plot(label="Similarity Plot"),
+         ],
+         allow_flagging="never",
+     )
+     app = gr.TabbedInterface(
+         interface_list=[text_app, image_app],
+         tab_names=["Text Query Search", "Image Query Search"],
+         title="CLIP Video Content Search",
+     )
+     app.launch()
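
For reference (not part of the diff), `run` can also be called directly with the same arguments the Gradio inputs supply; a minimal sketch, where the video path and text query are placeholders:

# Hypothetical direct call to run(); "example.mp4" and the query string are placeholders.
matches, chart = run(
    path="example.mp4",
    model_key="convnext_base_w - laion2b_s13b_b82k",
    text_search="a person riding a bicycle",
    image_search=None,
    thresh=0.3,
    stride=4,
    batch_size=4,
    center_crop=False,
)
# matches: up to 30 PIL images of frames above the similarity threshold
# chart: an Altair plot of per-frame similarity over time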
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ altair==4.2.2
+ gradio==3.27.0
+ numpy==1.24.2
+ open_clip_torch==2.16.2
+ opencv_python_headless==4.7.0.72
+ pandas==1.5.3
+ Pillow==9.5.0
+ torch==2.0.0
+ torchvision==0.15.1