shin-mashita committed
Commit
2a7c856
1 Parent(s): abe3bfd

Added documentation

Files changed (1):
  app.py +33 -12
app.py CHANGED
@@ -9,33 +9,39 @@ from pytorch_i3d import InceptionI3d
 
 
 def preprocess(vidpath):
+    # Fetch video
     cap = cv2.VideoCapture(vidpath)
 
     frames = []
     cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
     num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-
+
+    # Extract frames from video
     for _ in range(num):
         _, img = cap.read()
 
+        # Skip NoneType frames
         if img is None:
             continue
 
+        # Resize if (w,h) < (226,226)
         w, h, c = img.shape
         if w < 226 or h < 226:
             d = 226. - min(w, h)
             sc = 1 + d / min(w, h)
             img = cv2.resize(img, dsize=(0, 0), fx=sc, fy=sc)
+
+        # Normalize pixel values to [-1, 1]
         img = (img / 255.) * 2 - 1
 
         frames.append(img)
 
-    # frames = torch.cuda.FloatTensor(np.asarray(frames, dtype=np.float32)) if torch.cuda.is_available() else torch.Tensor(np.asarray(frames, dtype=np.float32))
     frames = torch.Tensor(np.asarray(frames, dtype=np.float32))
 
+    # Transform tensor and reshape to (1, c, t, w, h)
     transform = transforms.Compose([videotransforms.CenterCrop(224)])
     frames = transform(frames)
-    frames = rearrange(frames, 't h w c-> 1 c t h w')
+    frames = rearrange(frames, 't w h c-> 1 c t w h')
 
     return frames
 
 
@@ -45,42 +51,53 @@ def classify(video,dataset='WLASL100'):
         'WLASL2000':{'logits':2000,'path':'weights/asl2000/FINAL_nslt_2000_iters=5104_top1=32.48_top5=57.31_top10=66.31.pt'}
     }
 
+    # Preprocess video
     input = preprocess(video)
 
+    # Load model
     model = InceptionI3d()
     model.load_state_dict(torch.load('weights/rgb_imagenet.pt',map_location=torch.device('cpu')))
     model.replace_logits(to_load[dataset]['logits'])
     model.load_state_dict(torch.load(to_load[dataset]['path'],map_location=torch.device('cpu')))
 
-    # device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
-    # model.to(device)
+    # Run on CPU. The Spaces environment is limited to CPU for free users.
     model.cpu()
+
+    # Evaluation mode
    model.eval()
 
-    with torch.no_grad():
-        per_frame_logits = model(input)
+    with torch.no_grad():  # Disable gradient computation
+        per_frame_logits = model(input)  # Inference
 
     per_frame_logits.cpu()
     model.cpu()
 
+    # Load predictions
     predictions = rearrange(per_frame_logits,'1 j k -> j k')
     predictions = torch.mean(predictions, dim = 1)
 
-    top = torch.argmax(predictions).item()
+    # Fetch top 10 predictions
     _, index = torch.topk(predictions,10)
     index = index.cpu().numpy()
 
+    # Load labels
     with open('wlasl_class_list.txt') as f:
         idx2label = dict()
         for line in f:
             idx2label[int(line.split()[0])]=line.split()[1]
-
+
+    # Get probabilities
     predictions = torch.nn.functional.softmax(predictions, dim=0).cpu().numpy()
 
+    # Return dict {label:pred}
     return {idx2label[i]:float(predictions[i]) for i in index}
 
+# Gradio App config
 title = "I3D Sign Language Recognition"
-description = "Description here"
+description = "Gradio demo of word-level sign language classification using an I3D model pretrained on the WLASL video dataset. " \
+              "WLASL is a large-scale dataset containing more than 2000 words in American Sign Language. " \
+              "Examples used in the demo are videos from the test subset. " \
+              "Note that WLASL100 contains 100 words while WLASL2000 contains 2000."
 examples = [
     ['videos/no.mp4','WLASL100'],
     ['videos/all.mp4','WLASL100'],
@@ -90,11 +107,15 @@ examples = [
     ['videos/accident2.mp4','WLASL2000']
 ]
 
+article = "NOTE: This is not the official demonstration of I3D sign language classification on the WLASL dataset. " \
+          "More information about the WLASL dataset and pretrained I3D models can be found <a href=https://github.com/dxli94/WLASL>here</a>."
 
+# Gradio App interface
 gr.Interface( fn=classify,
-              inputs=[gr.inputs.Video(label="VIDEO"),gr.inputs.Dropdown(choices=['WLASL100','WLASL2000'], default='WLASL100', label='DATASET USED')],
+              inputs=[gr.inputs.Video(label="Video (*.mp4)"),gr.inputs.Radio(choices=['WLASL100','WLASL2000'], default='WLASL100', label='Trained on:')],
              outputs=[gr.outputs.Label(num_top_classes=5, label='Top 5 Predictions')],
              allow_flagging="never",
              title=title,
              description=description,
-              examples=examples).launch()
+              examples=examples,
+              article=article).launch()
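The reshape that the new `# Transform tensor and reshape to (1, c, t, w, h)` comment documents is easy to sanity-check in isolation. A minimal sketch with dummy data (the 16-frame clip length and 224x224 size are illustrative assumptions, not values from the commit):

```python
import torch
from einops import rearrange

# Dummy clip shaped like the output of preprocess() just before the rearrange:
# (t, w, h, c), i.e. 16 frames of 224x224 RGB, using the commit's axis names
frames = torch.zeros(16, 224, 224, 3)

# Add a batch axis and move channels first, matching the I3D input layout
clip = rearrange(frames, 't w h c -> 1 c t w h')

print(clip.shape)  # torch.Size([1, 3, 16, 224, 224])
```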
 
 
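Likewise, the post-processing that the new comments in classify() walk through (average per-frame logits over time, take the top 10 classes, softmax the averaged scores into probabilities) can be exercised on random logits. A standalone sketch; the 100-class, 16-step shape is an assumption chosen to match WLASL100, not real model output:

```python
import torch
from einops import rearrange

# Random stand-in for per_frame_logits: (batch=1, classes, time)
per_frame_logits = torch.randn(1, 100, 16)

predictions = rearrange(per_frame_logits, '1 j k -> j k')  # drop the batch axis
predictions = torch.mean(predictions, dim=1)               # average logits over time

_, index = torch.topk(predictions, 10)                     # indices of the top-10 classes
probs = torch.nn.functional.softmax(predictions, dim=0)    # scores -> probabilities

# Same {label: probability} mapping the app returns, with indices as stand-in labels
print({int(i): float(probs[i]) for i in index})
```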
 
 
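The commit also deletes the commented-out CUDA branch in favor of an unconditional model.cpu(), which matches the CPU-only Spaces tier. For running the app elsewhere, a device-portable version of the inference step might look like the sketch below; run_inference is a hypothetical helper, not part of the commit, and `model`/`clip` stand for the objects built in classify() and preprocess():

```python
import torch

def run_inference(model, clip):
    # Hypothetical helper mirroring the CUDA comments this commit removes.
    # `model` is the InceptionI3d built in classify(); `clip` is the
    # (1, c, t, w, h) tensor returned by preprocess().
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    model = model.to(device)
    model.eval()

    with torch.no_grad():  # Disable gradient computation
        per_frame_logits = model(clip.to(device))

    # Bring logits back to the CPU for the numpy post-processing that follows
    return per_frame_logits.cpu()
```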