nehulagrawal's picture
Upload 2 files
b5045f8
raw
history blame
1.72 kB
import cv2
import IPython
from PIL import ImageColor
from ultralytics import YOLO
class ObjectDetection:
    """Run an Ultralytics YOLO model on frames and draw the detections.

    Weights are loaded from ``weights/<model_name>_best.pt``.
    """

    def __init__(self, model_name='yolov8', device='cpu'):
        """Load the model and cache its class-name mapping.

        Args:
            model_name: Prefix of the weights file under ``weights/``.
            device: Device string forwarded to every inference call
                (e.g. ``'cpu'``, ``'cuda:0'``).
        """
        self.model_name = model_name
        # Set before loading so the attribute exists even if loading fails
        # part-way; it is used on every inference call.
        self.device = device
        self.model = self.load_model()
        # Class index -> class name, as exposed by the loaded YOLO model.
        self.classes = self.model.names

    def load_model(self):
        """Return a ``YOLO`` model built from ``weights/<model_name>_best.pt``."""
        return YOLO(f"weights/{self.model_name}_best.pt")

    def v8_score_frame(self, frame):
        """Run the model on one frame and return its detections.

        Args:
            frame: A single image accepted by the Ultralytics model
                (for numpy arrays, Ultralytics expects BGR channel order).

        Returns:
            A ``(labels, confidences, coords)`` tuple of plain Python lists:
            integer class ids, float confidences, and ``[x1, y1, x2, y2]``
            pixel boxes. All three are empty when nothing is detected or the
            model produces no boxes (e.g. a classification model).
        """
        results = self.model(frame, device=self.device)
        # Ultralytics returns a list of Results, one per input image; a
        # single frame yields a single Results object.
        if not results or results[0].boxes is None:
            return [], [], []
        boxes = results[0].boxes
        # Class ids come back as floats; cast so they index ``self.classes``.
        labels = [int(c) for c in boxes.cls.tolist()]
        confidences = boxes.conf.tolist()
        coords = boxes.xyxy.tolist()
        return labels, confidences, coords

    def get_coords(self, frame, row):
        """Return the first four values of ``row`` truncated to ints.

        ``frame`` is unused; it is kept for call-site compatibility.
        """
        return int(row[0]), int(row[1]), int(row[2]), int(row[3])

    def class_to_label(self, x):
        """Map a (possibly float) class id to its class name."""
        return self.classes[int(x)]

    def get_color(self, code):
        """Return the ``(r, g, b)`` tuple for a PIL colour name or hex string."""
        rgb = ImageColor.getcolor(code, "RGB")
        return rgb

    def plot_bboxes(self, results, frame, threshold=0.5, box_color='red', text_color='white'):
        """Draw a box and a ``"<class> - <conf>%"`` label for each detection.

        Args:
            results: ``(labels, confidences, coords)`` as returned by
                ``v8_score_frame``.
            frame: Image to annotate; it is copied, not modified.
            threshold: Minimum confidence for a detection to be drawn.
            box_color: PIL colour name/hex for the rectangle.
            text_color: PIL colour name/hex for the label text.

        Returns:
            The annotated copy of ``frame``.
        """
        labels, conf, coord = results
        frame = frame.copy()
        # PIL yields RGB, but OpenCV applies colour tuples in the frame's own
        # channel order, which is BGR for OpenCV/Ultralytics numpy frames.
        box_bgr = self.get_color(box_color)[::-1]
        text_bgr = self.get_color(text_color)[::-1]
        for label, score, box in zip(labels, conf, coord):
            if score >= threshold:
                x1, y1, x2, y2 = self.get_coords(frame, box)
                class_name = self.class_to_label(label)
                cv2.rectangle(frame, (x1, y1), (x2, y2), box_bgr, 2)
                cv2.putText(frame, f"{class_name} - {score*100:.2f}%", (x1, y1), cv2.FONT_HERSHEY_COMPLEX, 0.5, text_bgr)
        return frame