import gradio as gr
import torch
import cv2
import numpy as np
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor

# VideoMAE checkpoint for video classification.
# Note: "MCG-NJU/videomae-base" is the self-supervised base checkpoint, so its
# classification head is randomly initialized; swap in a fine-tuned variant for
# meaningful action labels.
classification_model_id = "MCG-NJU/videomae-base"

# Object detection model name (you can replace this with a more accurate one if needed;
# it is only used by the YOLO-based detector sketch below, not by the detect_object placeholder)
object_detection_model = "yolov5s"

# Parameters for frame extraction
TARGET_FRAME_COUNT = 16
FRAME_SIZE = (224, 224)  # Expected frame size for the model

def analyze_video(video):
    # Extract key frames from the video using OpenCV
    frames = extract_key_frames(video)
    if not frames:
        return "Could not read any frames from the uploaded video."

    # Load classification model and image processor
    # (loading on every call keeps the example simple; load once at module level for speed)
    classification_model = VideoMAEForVideoClassification.from_pretrained(classification_model_id)
    processor = VideoMAEImageProcessor.from_pretrained(classification_model_id)

    # Prepare the clip (a list of frames) for the classification model
    inputs = processor(images=frames, return_tensors="pt")

    # Make a prediction using the classification model
    with torch.no_grad():
        outputs = classification_model(**inputs)

    logits = outputs.logits
    # VideoMAE classifies the clip as a whole, so there is a single predicted label
    clip_prediction = torch.argmax(logits, dim=-1).item()

    # Object detection (ball and baseman) on each extracted frame
    object_detection_results = []
    for frame in frames:
        ball_position = detect_object(frame, "ball")
        baseman_position = detect_object(frame, "baseman")
        object_detection_results.append((ball_position, baseman_position))

    # Combine the clip-level prediction with the per-frame object positions
    analysis_results = []
    for ball_position, baseman_position in object_detection_results:
        result = analyze_frame(clip_prediction, ball_position, baseman_position)
        analysis_results.append(result)

    # Aggregate analysis results
    final_result = aggregate_results(analysis_results)

    return final_result

def extract_key_frames(video):
    cap = cv2.VideoCapture(video)
    frames = []
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    interval = max(1, frame_count // TARGET_FRAME_COUNT)

    for i in range(frame_count):
        ret, frame = cap.read()
        if not ret:
            break
        if i % interval == 0:  # Sample frames at regular intervals
            frame = cv2.resize(frame, FRAME_SIZE)  # Resize to the model's expected input size
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV reads BGR; the model expects RGB
            frames.append(frame)
    cap.release()

    # VideoMAE expects exactly TARGET_FRAME_COUNT frames: trim extras and pad short clips
    # by repeating the last frame.
    frames = frames[:TARGET_FRAME_COUNT]
    while frames and len(frames) < TARGET_FRAME_COUNT:
        frames.append(frames[-1])
    return frames

def detect_object(frame, object_type):
    # Placeholder function for object detection (replace with actual implementation)
    # Here, we assume that the object is detected at the center of the frame
    h, w, _ = frame.shape
    if object_type == "ball":
        return (w // 2, h // 2)  # Return center coordinates for the ball
    elif object_type == "baseman":
        return (w // 2, h // 2)  # Return center coordinates for the baseman
    else:
        return None
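
# The detect_object placeholder above always reports the frame centre. Below is a
# minimal, hedged sketch of what a real detector could look like, assuming the public
# ultralytics/yolov5 hub model named by object_detection_model is available. COCO
# (which yolov5s is trained on) has no "baseman" class, so "person" is used as a
# stand-in and "ball" maps to "sports ball". detect_object_yolo is an illustrative
# assumption, not part of the original pipeline.
def detect_object_yolo(frame, object_type, model=None):
    if model is None:
        # Loading the hub model on every call is slow; pass a preloaded model in practice.
        model = torch.hub.load("ultralytics/yolov5", object_detection_model)
    class_name = {"ball": "sports ball", "baseman": "person"}.get(object_type)
    if class_name is None:
        return None
    # Frames produced by extract_key_frames are RGB, which YOLOv5's AutoShape accepts directly.
    results = model(frame)
    detections = results.pandas().xyxy[0]
    matches = detections[detections["name"] == class_name]
    if matches.empty:
        return None
    best = matches.loc[matches["confidence"].idxmax()]  # keep the most confident detection
    # Return the centre of the best bounding box as (x, y)
    return (int((best["xmin"] + best["xmax"]) / 2), int((best["ymin"] + best["ymax"]) / 2))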

def analyze_frame(prediction, ball_position, baseman_position):
    # Placeholder logic for analyzing a single frame; replace it with your own rules.
    # With a fine-tuned checkpoint, classification_model.config.id2label could supply
    # real label names instead of this hand-written placeholder map.
    action_labels = {
        0: "running",
        1: "sliding",
        2: "jumping",
        # Add more labels as necessary
    }
    action = action_labels.get(prediction, "unknown")
    return {"action": action, "ball_position": ball_position, "baseman_position": baseman_position}

def aggregate_results(results):
    # Placeholder function for aggregating analysis results
    # You can implement this based on your specific requirements
    return results
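
# aggregate_results above simply returns the raw per-frame dictionaries, which Gradio
# renders as their string representation. A minimal alternative sketch that condenses
# them into one line of text (an assumed output format, not the original author's) is
# shown below.
def summarize_results(results):
    if not results:
        return "No frames could be analyzed."
    action = results[0]["action"]  # clip-level action, identical across frames
    ball_path = [r["ball_position"] for r in results if r["ball_position"] is not None]
    return (
        f"Predicted action: {action}. "
        f"Ball located in {len(ball_path)} of {len(results)} analyzed frames."
    )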

# Gradio interface
interface = gr.Interface(
    fn=analyze_video,
    inputs="video",
    outputs="text",
    title="Baseball Play Analysis",
    description="Upload a video of a baseball play to analyze.",
)

interface.launch()