MNGames commited on
Commit
e006d42
1 Parent(s): f3cd8b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -24
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
- from transformers import VideoMAEForVideoClassification, VideoMAEFeatureExtractor, VideoClassificationPipeline
 
3
  import cv2 # OpenCV for video processing
4
 
5
  # Model ID for video classification (UCF101 subset)
@@ -8,22 +9,25 @@ model_id = "MCG-NJU/videomae-base"
8
  def analyze_video(video):
9
  # Extract key frames from the video using OpenCV
10
  frames = extract_key_frames(video)
11
-
12
  # Load model and feature extractor manually
13
  model = VideoMAEForVideoClassification.from_pretrained(model_id)
14
- feature_extractor = VideoMAEFeatureExtractor.from_pretrained(model_id)
 
 
 
15
 
16
- # Create the pipeline
17
- classifier = VideoClassificationPipeline(model=model, feature_extractor=feature_extractor, device=-1)
 
18
 
19
- # Analyze key frames using video classification model
 
 
 
20
  results = []
21
- for frame in frames:
22
- # OpenCV uses BGR, convert to RGB for the model
23
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
24
- predictions = classifier([frame_rgb]) # Assuming model outputs probabilities
25
- # Analyze predictions for insights related to the play
26
- result = analyze_predictions_ucf101(predictions)
27
  results.append(result)
28
 
29
  # Aggregate results across frames and provide a final analysis
@@ -40,24 +44,29 @@ def extract_key_frames(video):
40
  for i in range(frame_count):
41
  ret, frame = cap.read()
42
  if ret and i % (fps // 2) == 0: # Extract a frame every half second
43
- frames.append(frame)
44
 
45
  cap.release()
46
  return frames
47
 
48
- def analyze_predictions_ucf101(predictions):
49
- # Analyze the model's predictions (probabilities) for insights relevant to baseball plays
50
- # For simplicity, we'll assume predictions return the top-1 class
51
- actions = [pred['label'] for pred in predictions]
 
 
 
 
 
52
 
53
  relevant_actions = ["running", "sliding", "jumping"]
54
- runner_actions = [action for action in actions if action in relevant_actions]
55
-
56
- # Check for 'running', 'sliding' actions as key indicators for safe/out decision
57
- if "sliding" in runner_actions:
58
- return "potentially safe"
59
- elif "running" in runner_actions:
60
- return "potentially out"
61
  else:
62
  return "inconclusive"
63
 
@@ -80,6 +89,7 @@ interface = gr.Interface(
80
  outputs="text",
81
  title="Baseball Play Analysis (UCF101 Subset Exploration)",
82
  description="Upload a video of a baseball play (safe/out at a base). This app explores using a video classification model (UCF101 subset) for analysis. Note: The model might not be specifically trained for baseball plays.",
 
83
  )
84
 
85
  interface.launch()
 
1
  import gradio as gr
2
+ from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
3
+ import torch
4
  import cv2 # OpenCV for video processing
5
 
6
  # Model ID for video classification (UCF101 subset)
 
9
  def analyze_video(video):
10
  # Extract key frames from the video using OpenCV
11
  frames = extract_key_frames(video)
12
+
13
  # Load model and feature extractor manually
14
  model = VideoMAEForVideoClassification.from_pretrained(model_id)
15
+ processor = VideoMAEImageProcessor.from_pretrained(model_id)
16
+
17
+ # Prepare frames for the model
18
+ inputs = processor(images=frames, return_tensors="pt")
19
 
20
+ # Make predictions
21
+ with torch.no_grad():
22
+ outputs = model(**inputs)
23
 
24
+ logits = outputs.logits
25
+ predictions = torch.argmax(logits, dim=-1)
26
+
27
+ # Analyze predictions for insights related to the play
28
  results = []
29
+ for prediction in predictions:
30
+ result = analyze_predictions_ucf101(prediction.item())
 
 
 
 
31
  results.append(result)
32
 
33
  # Aggregate results across frames and provide a final analysis
 
44
  for i in range(frame_count):
45
  ret, frame = cap.read()
46
  if ret and i % (fps // 2) == 0: # Extract a frame every half second
47
+ frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) # Convert to RGB
48
 
49
  cap.release()
50
  return frames
51
 
52
def analyze_predictions_ucf101(prediction):
    """Turn a single class-index prediction into a coarse play verdict.

    The integer *prediction* is looked up in a (hypothetical) index-to-action
    table; sliding suggests a safe call, running suggests an out, and any
    other action — including "jumping" or an unknown index — is inconclusive.
    """
    # Hypothetical mapping from model class index to action label.
    index_to_action = {
        0: "running",
        1: "sliding",
        2: "jumping",
        # Add more labels as necessary
    }
    action = index_to_action.get(prediction, "unknown")

    relevant_actions = ["running", "sliding", "jumping"]
    if action not in relevant_actions:
        return "inconclusive"

    # Flattened decision table: sliding -> safe, running -> out,
    # any other relevant action (e.g. jumping) -> inconclusive.
    verdicts = {
        "sliding": "potentially safe",
        "running": "potentially out",
    }
    return verdicts.get(action, "inconclusive")
72
 
 
89
  outputs="text",
90
  title="Baseball Play Analysis (UCF101 Subset Exploration)",
91
  description="Upload a video of a baseball play (safe/out at a base). This app explores using a video classification model (UCF101 subset) for analysis. Note: The model might not be specifically trained for baseball plays.",
92
+ share=True
93
  )
94
 
95
  interface.launch()