# Irpan - init commit (7fa2fcb)
from ultralytics import YOLO
import supervision as sv
import pickle
import os
import numpy as np
import pandas as pd
import cv2
import sys
sys.path.append('../')
from utils import get_center_of_bbox, get_bbox_width, get_foot_position
class Tracker:
    """Detect, track, and annotate players, referees, and the ball in video frames.

    Detection is delegated to an Ultralytics ``YOLO`` model and id assignment
    to a supervision ``ByteTrack`` instance.  Tracks are kept in a dict of the
    form ``{"players": [...], "referees": [...], "ball": [...]}`` where every
    list holds one ``{track_id: {"bbox": [x1, y1, x2, y2], ...}}`` dict per
    frame.  The ball is always stored under track id ``1``.
    """

    def __init__(self, model_path):
        """Load the YOLO weights at ``model_path`` and create a fresh ByteTrack."""
        self.model = YOLO(model_path)
        self.tracker = sv.ByteTrack()

    def add_position_to_tracks(self, tracks):
        """Add a ``"position"`` entry to every track in ``tracks``, in place.

        The ball's position is ``get_center_of_bbox(bbox)``; every other
        object's position is ``get_foot_position(bbox)``.

        Args:
            tracks: Mapping of object name to a per-frame list of
                ``{track_id: {"bbox": ...}}`` dicts.
        """
        for object_name, object_tracks in tracks.items():
            locate = get_center_of_bbox if object_name == "ball" else get_foot_position
            for frame_tracks in object_tracks:
                for track_info in frame_tracks.values():
                    track_info["position"] = locate(track_info["bbox"])

    def interpolate_ball_positions(self, ball_positions):
        """Fill in missing ball bounding boxes by interpolation.

        Frames without a ball (no key ``1`` or no ``"bbox"``) are filled by
        pandas' default (linear) interpolation between known boxes; frames
        before the first known box are back-filled from it.

        Args:
            ball_positions: Per-frame list of ``{1: {"bbox": [x1, y1, x2, y2]}}``
                dicts; frames without a detection are empty dicts.

        Returns:
            A new per-frame list of ``{1: {"bbox": [x1, y1, x2, y2]}}`` dicts
            with float coordinates.  If no frame contains a ball there is
            nothing to interpolate from, so one empty dict per input frame is
            returned instead of raising.
        """
        boxes = [frame.get(1, {}).get("bbox", []) for frame in ball_positions]
        if not any(len(box) for box in boxes):
            return [{} for _ in boxes]

        # Pad missing frames with explicit NaNs so every row has four numeric
        # columns instead of relying on pandas to pad ragged rows.
        rows = [list(box) if len(box) else [np.nan] * 4 for box in boxes]
        df_ball_positions = pd.DataFrame(rows, columns=["x1", "y1", "x2", "y2"])

        # Interpolate gaps, then back-fill anything before the first detection.
        df_ball_positions = df_ball_positions.interpolate()
        df_ball_positions = df_ball_positions.bfill()

        return [{1: {"bbox": box}} for box in df_ball_positions.to_numpy().tolist()]

    def detect_frames(self, frames, batch_size=20, conf=0.1):
        """Run the detection model over ``frames`` in batches.

        Args:
            frames: Sequence of frames accepted by ``self.model.predict``.
            batch_size: Number of frames passed to the model per call.
            conf: Confidence threshold forwarded to ``self.model.predict``.

        Returns:
            A flat list with the model's result for every frame, in order.
        """
        detections = []
        for start in range(0, len(frames), batch_size):
            detections.extend(
                self.model.predict(frames[start:start + batch_size], conf=conf)
            )
        return detections

    def get_object_tracks(self, frames, read_from_stub=False, stub_path=None):
        """Detect and track players, referees, and the ball across ``frames``.

        Args:
            frames: Sequence of video frames.
            read_from_stub: When true and ``stub_path`` exists, return the
                pickled tracks stored there instead of running the model.
            stub_path: Optional path of a pickle file.  Used as the cache to
                read from (see ``read_from_stub``) and, whenever tracks are
                computed, as the file the result is written to.

        Returns:
            ``{"players": [...], "referees": [...], "ball": [...]}`` with one
            ``{track_id: {"bbox": [x1, y1, x2, y2]}}`` dict per frame.
        """
        if read_from_stub and stub_path is not None and os.path.exists(stub_path):
            # NOTE(security): pickle.load can execute arbitrary code; only
            # point stub_path at files this program wrote itself.
            with open(stub_path, "rb") as f:
                return pickle.load(f)

        detections = self.detect_frames(frames)

        tracks = {"players": [], "referees": [], "ball": []}

        for detection in detections:
            cls_names = detection.names
            cls_names_inv = {name: class_id for class_id, name in cls_names.items()}

            # Convert to supervision Detections format.
            detection_supervision = sv.Detections.from_ultralytics(detection)

            # Treat goalkeepers as regular players.
            for object_ind, class_id in enumerate(detection_supervision.class_id):
                if cls_names[class_id] == "goalkeeper":
                    detection_supervision.class_id[object_ind] = cls_names_inv["player"]

            # Assign track ids.
            detection_with_tracks = self.tracker.update_with_detections(
                detection_supervision
            )

            players, referees, ball = {}, {}, {}
            tracks["players"].append(players)
            tracks["referees"].append(referees)
            tracks["ball"].append(ball)

            # Each iterated detection is a tuple; index 0 is the box, 3 the
            # class id and 4 the tracker id.
            for frame_detection in detection_with_tracks:
                bbox = frame_detection[0].tolist()
                cls_id = frame_detection[3]
                track_id = frame_detection[4]

                if cls_id == cls_names_inv["player"]:
                    players[track_id] = {"bbox": bbox}
                elif cls_id == cls_names_inv["referee"]:
                    referees[track_id] = {"bbox": bbox}

            # The ball is taken from the raw (untracked) detections and always
            # stored under id 1; a later ball detection in the same frame
            # overwrites an earlier one.
            for frame_detection in detection_supervision:
                bbox = frame_detection[0].tolist()
                cls_id = frame_detection[3]

                if cls_id == cls_names_inv["ball"]:
                    ball[1] = {"bbox": bbox}

        if stub_path is not None:
            with open(stub_path, "wb") as f:
                pickle.dump(tracks, f)

        return tracks

    def draw_ellipse(self, frame, bbox, color, track_id=None):
        """Draw an open ellipse at the bottom of ``bbox`` and an optional id label.

        Args:
            frame: Image to draw on; it is modified in place.
            bbox: ``[x1, y1, x2, y2]`` box of the object.
            color: Color passed to OpenCV for the ellipse and the label box.
            track_id: When given, a filled 40x20 rectangle containing the id
                is drawn just below the ellipse.

        Returns:
            The same ``frame`` object.
        """
        y2 = int(bbox[3])
        x_center, _ = get_center_of_bbox(bbox)
        width = get_bbox_width(bbox)

        cv2.ellipse(
            frame,
            center=(x_center, y2),
            axes=(int(width), int(0.35 * width)),
            angle=0.0,
            startAngle=-45,
            endAngle=235,
            color=color,
            thickness=2,
            lineType=cv2.LINE_4,
        )

        if track_id is None:
            return frame

        rectangle_width = 40
        rectangle_height = 20
        x1_rect = x_center - rectangle_width // 2
        x2_rect = x_center + rectangle_width // 2
        y1_rect = (y2 - rectangle_height // 2) + 15
        y2_rect = (y2 + rectangle_height // 2) + 15

        cv2.rectangle(
            frame,
            (int(x1_rect), int(y1_rect)),
            (int(x2_rect), int(y2_rect)),
            color,
            cv2.FILLED,
        )

        # Shift three-digit ids left so they stay inside the rectangle.
        x1_text = x1_rect + 12
        if track_id > 99:
            x1_text -= 10

        cv2.putText(
            frame,
            f"{track_id}",
            (int(x1_text), int(y1_rect + 15)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            (0, 0, 0),
            2,
        )

        return frame

    def draw_traingle(self, frame, bbox, color):
        """Draw a filled, black-outlined triangle pointing at the top of ``bbox``.

        The (misspelled) method name is kept because callers depend on it.

        Args:
            frame: Image to draw on; it is modified in place.
            bbox: ``[x1, y1, x2, y2]`` box of the object.
            color: Fill color passed to OpenCV.

        Returns:
            The same ``frame`` object.
        """
        y = int(bbox[1])
        x, _ = get_center_of_bbox(bbox)

        # Explicit int32 so the contour has the integer point type OpenCV
        # expects regardless of what get_center_of_bbox returns.
        triangle_points = np.array(
            [
                [x, y],
                [x - 10, y - 20],
                [x + 10, y - 20],
            ],
            dtype=np.int32,
        )
        cv2.drawContours(frame, [triangle_points], 0, color, cv2.FILLED)
        cv2.drawContours(frame, [triangle_points], 0, (0, 0, 0), 2)

        return frame

    @staticmethod
    def _possession_shares(team_ball_control, frame_num):
        """Return ``(team_1_share, team_2_share)`` for frames ``0..frame_num``.

        Shares are fractions of the frames in which team 1 or team 2 had the
        ball; entries with any other value are ignored.  When neither team
        has had the ball yet both shares are ``0.0``.
        """
        history = np.asarray(team_ball_control)[: frame_num + 1]
        team_1_frames = int(np.count_nonzero(history == 1))
        team_2_frames = int(np.count_nonzero(history == 2))

        controlled_frames = team_1_frames + team_2_frames
        if controlled_frames == 0:
            return 0.0, 0.0
        return team_1_frames / controlled_frames, team_2_frames / controlled_frames

    def draw_team_ball_control(self, frame, frame_num, team_ball_control):
        """Overlay both teams' cumulative possession percentages on ``frame``.

        Args:
            frame: Image to draw on; it is modified in place.
            frame_num: Index of ``frame``; possession is accumulated over
                frames ``0..frame_num`` inclusive.
            team_ball_control: Per-frame sequence whose entries are ``1`` or
                ``2`` for the team in control of the ball.

        Returns:
            The same ``frame`` object.
        """
        # Draw a semi-transparent rectangle as the text background.  The
        # coordinates are fixed pixels (presumably chosen for 1920x1080
        # frames - confirm against the input video).
        overlay = frame.copy()
        cv2.rectangle(overlay, (1350, 850), (1900, 970), (255, 255, 255), -1)
        alpha = 0.4
        cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0, frame)

        team_1, team_2 = self._possession_shares(team_ball_control, frame_num)

        cv2.putText(frame, f"Team 1 Possession: {team_1*100:.2f}%", (1400, 900), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 3)
        cv2.putText(frame, f"Team 2 Possession: {team_2*100:.2f}%", (1400, 950), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 3)

        return frame

    def draw_annotations(self, video_frames, tracks, team_ball_control):
        """Return annotated copies of ``video_frames``.

        Every frame is copied before drawing, so the input frames are left
        untouched.  Players get an ellipse with their track id (plus a red
        triangle when ``has_ball`` is set), referees a yellow ellipse, the
        ball a triangle, and each frame the possession overlay.

        Args:
            video_frames: Sequence of frames.
            tracks: Track dict with ``"players"``, ``"referees"`` and
                ``"ball"`` per-frame lists, indexed like ``video_frames``.
            team_ball_control: Per-frame team-in-control sequence forwarded
                to ``draw_team_ball_control``.

        Returns:
            List of annotated frames, one per input frame.
        """
        output_video_frames = []
        for frame_num, frame in enumerate(video_frames):
            frame = frame.copy()

            player_dict = tracks["players"][frame_num]
            ball_dict = tracks["ball"][frame_num]
            referee_dict = tracks["referees"][frame_num]

            # Draw players.
            for track_id, player in player_dict.items():
                color = player.get("team_color", (0, 0, 255))
                frame = self.draw_ellipse(frame, player["bbox"], color, track_id)

                if player.get("has_ball", False):
                    frame = self.draw_traingle(frame, player["bbox"], (0, 0, 255))

            # Draw referees.
            for referee in referee_dict.values():
                frame = self.draw_ellipse(frame, referee["bbox"], (0, 255, 255))

            # Draw ball.
            for ball in ball_dict.values():
                frame = self.draw_traingle(frame, ball["bbox"], (255, 255, 50))

            # Draw team ball control.
            frame = self.draw_team_ball_control(frame, frame_num, team_ball_control)

            output_video_frames.append(frame)

        return output_video_frames