MNGames committed on
Commit
93e8a7e
1 Parent(s): eda9783

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -13
app.py CHANGED
@@ -1,24 +1,77 @@
1
  import gradio as gr
2
  from transformers import pipeline
 
3
 
4
- # Replace with a suitable image classification model ID
5
- model_id = "sayakpaul/resnet-50-finetuned-imagenet"
6
 
7
def analyze_image(image):
    """Classify *image* and report the single most likely class.

    Args:
        image: Image input accepted by the transformers pipeline
            (presumably a PIL image, numpy array, or file path — confirm
            against the Gradio "image" component's output type).

    Returns:
        A human-readable string with the top class label and its
        probability formatted to two decimal places.
    """
    # NOTE(review): the pipeline is rebuilt on every call — acceptable for
    # a one-shot demo, but hoist to module level if called repeatedly.
    classifier = pipeline("image-classification", model=model_id)
    predictions = classifier(images=image)  # Assuming the model outputs probabilities
    # Extract the most likely class and its probability
    top_class = predictions[0]["label"]
    top_prob = predictions[0]["score"]
    return f"Top Class: {top_class} (Probability: {top_prob:.2f})"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
# Gradio interface: single image in, classification summary text out.
interface = gr.Interface(
    fn=analyze_image,
    inputs="image",
    outputs="text",
    title="Image Analyzer (Generic)",
    description="Upload an image and get the most likely classification based on the chosen model.",
)

# Launch the local Gradio server (blocks until the app is stopped).
interface.launch()
 
1
  import gradio as gr
2
  from transformers import pipeline
3
+ import cv2 # OpenCV for video processing
4
 
5
+ # Model ID for video classification (UCF101 subset)
6
+ model_id = "sayakpaul/videomae-base-finetuned-ucf101-subset"
7
 
8
def analyze_video(video):
    """Classify a baseball play video as Safe / Out / Inconclusive.

    Args:
        video: Path to the uploaded video file (as provided by the
            Gradio "video" input component).

    Returns:
        The aggregated verdict string produced by ``aggregate_results``.
    """
    # Build the classifier once, then score every sampled frame.
    classifier = pipeline("video-classification", model=model_id)

    # Extract key frames from the video using OpenCV, classify each one,
    # and map the raw predictions to a per-frame verdict.
    per_frame_verdicts = [
        analyze_predictions_ucf101(classifier(images=frame))  # Assuming model outputs probabilities
        for frame in extract_key_frames(video)
    ]

    # Aggregate results across frames and provide a final analysis.
    return aggregate_results(per_frame_verdicts)
25
+
26
def extract_key_frames(video):
    """Sample frames from *video* at roughly two frames per second.

    Args:
        video: Path to a video file readable by OpenCV.

    Returns:
        A list of decoded frames (numpy arrays in OpenCV's BGR layout),
        one roughly every half second of footage.
    """
    cap = cv2.VideoCapture(video)
    frames = []
    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        # Bug fix: if the container reports FPS < 2 (cv2 returns 0 when the
        # metadata is missing), fps // 2 is 0 and the modulo below raised
        # ZeroDivisionError. Clamp the sampling step to at least 1 frame.
        step = max(1, fps // 2)  # extract a frame every half second

        for i in range(frame_count):
            ret, frame = cap.read()
            if not ret:
                # Stop early if the stream ends before the reported count
                # instead of spinning through failed reads.
                break
            if i % step == 0:
                frames.append(frame)
    finally:
        # Release the capture even if decoding raises mid-way.
        cap.release()
    return frames
39
+
40
def analyze_predictions_ucf101(predictions):
    """Map raw UCF101 predictions to a per-frame safe/out verdict.

    Args:
        predictions: Sequence of prediction dicts, each with at least a
            ``label`` key, as returned by the classification pipeline.

    Returns:
        "potentially safe" if any sliding action was detected,
        "potentially out" if only running was detected,
        "inconclusive" otherwise.
    """
    # Keep only the action labels relevant to a safe/out call at a base.
    relevant = {"running", "sliding", "jumping"}
    observed = {pred["label"] for pred in predictions if pred["label"] in relevant}

    # A slide is the strongest cue for "safe"; bare running leans "out".
    if "sliding" in observed:
        return "potentially safe"
    if "running" in observed:
        return "potentially out"
    return "inconclusive"
55
+
56
def aggregate_results(results):
    """Reduce per-frame verdicts to a final Safe/Out/Inconclusive call.

    Args:
        results: List of per-frame verdict strings produced by
            ``analyze_predictions_ucf101``.

    Returns:
        "Safe" or "Out" by majority vote between the two verdict counts,
        or "Inconclusive" on a tie (including an empty list).
    """
    # Tally the two decisive verdicts; "inconclusive" frames don't vote.
    safe_votes = sum(1 for verdict in results if verdict == "potentially safe")
    out_votes = sum(1 for verdict in results if verdict == "potentially out")

    if safe_votes > out_votes:
        return "Safe"
    if out_votes > safe_votes:
        return "Out"
    return "Inconclusive"
67
 
68
# Gradio interface: video file in, aggregated Safe/Out/Inconclusive text out.
interface = gr.Interface(
    fn=analyze_video,
    inputs="video",
    outputs="text",
    title="Baseball Play Analysis (UCF101 Subset Exploration)",
    description="Upload a video of a baseball play (safe/out at a base). This app explores using a video classification model (UCF101 subset) for analysis. Note: The model might not be specifically trained for baseball plays.",
)

# Launch the local Gradio server (blocks until the app is stopped).
interface.launch()