import gradio as gr
import torch
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
import cv2  # OpenCV for video processing

# Model ID for video classification (UCF101-subset exploration)
model_id = "MCG-NJU/videomae-base"
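# Note (assumption): "MCG-NJU/videomae-base" is the self-supervised backbone,
# so VideoMAEForVideoClassification attaches a randomly initialized
# classification head to it. For meaningful action labels, swap in a
# checkpoint fine-tuned for classification, e.g.:
# model_id = "MCG-NJU/videomae-base-finetuned-kinetics"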
# Load the model and image processor once at startup instead of on every request
model = VideoMAEForVideoClassification.from_pretrained(model_id)
processor = VideoMAEImageProcessor.from_pretrained(model_id)
model.eval()

def analyze_video(video):
    # Extract key frames from the video using OpenCV
    frames = extract_key_frames(video)

    # VideoMAE classifies fixed-length clips (16 frames by default), not
    # single images, so group the sampled frames into clips
    clip_len = model.config.num_frames
    results = []
    for start in range(0, len(frames) - clip_len + 1, clip_len):
        # OpenCV uses BGR; convert each frame to RGB for the model
        clip = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frames[start:start + clip_len]]
        inputs = processor(clip, return_tensors="pt")
        with torch.no_grad():
            logits = model(**inputs).logits
        # Take the top-1 class for this clip
        top_id = logits.argmax(-1).item()
        predictions = [{"label": model.config.id2label[top_id]}]
        # Analyze predictions for insights related to the play
        results.append(analyze_predictions_ucf101(predictions))

    # Aggregate results across clips and provide a final analysis
    return aggregate_results(results)
def extract_key_frames(video):
    cap = cv2.VideoCapture(video)
    frames = []
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    step = max(fps // 2, 1)  # sample a frame every half second; guard against fps < 2
    for i in range(frame_count):
        ret, frame = cap.read()
        if not ret:
            break
        if i % step == 0:
            frames.append(frame)
    cap.release()
    return frames
def analyze_predictions_ucf101(predictions):
    # Analyze the model's predictions for insights relevant to baseball plays.
    # For simplicity, we only look at the top-1 class of each clip.
    actions = [pred["label"] for pred in predictions]
    relevant_actions = ["running", "sliding", "jumping"]
    runner_actions = [action for action in actions if action in relevant_actions]

    # Treat 'sliding' and 'running' as the key indicators for the safe/out call
    if "sliding" in runner_actions:
        return "potentially safe"
    elif "running" in runner_actions:
        return "potentially out"
    else:
        return "inconclusive"
def aggregate_results(results):
    # Combine the per-clip verdicts with a simple majority vote
    safe_count = results.count("potentially safe")
    out_count = results.count("potentially out")
    if safe_count > out_count:
        return "Safe"
    elif out_count > safe_count:
        return "Out"
    else:
        return "Inconclusive"
# Gradio interface
interface = gr.Interface(
    fn=analyze_video,
    inputs="video",
    outputs="text",
    title="Baseball Play Analysis (UCF101 Subset Exploration)",
    description=(
        "Upload a video of a baseball play (safe/out at a base). This app "
        "explores using a video classification model (UCF101 subset) for "
        "analysis. Note: The model might not be specifically trained for "
        "baseball plays."
    ),
)

interface.launch()
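# Local smoke test without the UI (the sample path below is hypothetical):
# print(analyze_video("sample_play.mp4"))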