# Created by yarramsettinaresh GORAKA DIGITAL PRIVATE LIMITED at 01/11/24
import cv2
import numpy as np
from ultralytics import YOLO
import torch
import gradio as gr

# Load the YOLO segmentation model
model = YOLO("chepala_lekka_v5_yolov11n.pt")

# Initialize global video capture variable
cap = None
text_margin = 10
class_colors = {
    "chepa": (255, 0, 0),       # Red color for "chepa"
    "sanchi": (0, 255, 0),      # Green color for "sanchi"
    "other_class": (0, 0, 255)  # Blue color for other classes
}


class VideoProcessor:
    """Tracks fish ("chepa") and bags ("sanchi") in a video and counts how many fish end up in each bag."""

    def __init__(self):
        self.g_bags = {}    # bag track_id -> number of fish counted into that bag
        self.g_fishes = {}  # fish track_id -> dict(in_sanchi=bool), marks fish already counted
        self.text_margin = 10
        self.class_colors = {
            "chepa": (255, 0, 0),       # Red for "chepa"
            "sanchi": (0, 255, 0),      # Green for "sanchi"
            "other_class": (0, 0, 255)  # Blue for other classes
        }
        self.cap = None

    def gradio_video_stream(self, video_file):
        print(f"gradio_video_stream init : {video_file}")
        self.g_bags = {}
        self.g_fishes = {}
        self.cap = cv2.VideoCapture(video_file)  # Open the uploaded video file
        while True:
            frame = self.process_frame()
            if frame is None:
                break
            yield cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB for Gradio
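
    # process_frame() reads one frame, runs YOLO tracking with instance segmentation,
    # overlays class-colored masks, and updates per-bag fish counts when a fish mask
    # first overlaps a bag ("sanchi") mask.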
    def process_frame(self):
        g_fishes = self.g_fishes
        g_bags = self.g_bags
        cap = self.cap
        ret, frame = cap.read()
        if not ret:
            cap.release()  # Release the video capture if no frame is captured
            return None    # Return None if no frame is captured

        frame_height, frame_width = frame.shape[:2]
        results = model.track(frame, persist=True)
        # person_result = person_model.predict(frame, show=False)
        bag_pos = dict()
        fishes_pos = dict()
        if results[0].boxes.id is not None and results[0].masks is not None:
            masks = results[0].masks.data
            track_ids = results[0].boxes.id.int().cpu().tolist()
            classes = results[0].boxes.cls  # Class labels
            confidences = results[0].boxes.conf
            boxes = results[0].boxes.xyxy
            for mask, track_id, cls, conf, box in zip(masks, track_ids, classes, confidences, boxes):
                # Convert the mask to a NumPy array if it is a tensor
                if isinstance(mask, torch.Tensor):
                    mask = mask.cpu().numpy()
                mask = (mask * 255).astype(np.uint8)  # Convert mask to binary format (0 or 255)
                # Resize mask to match the original frame dimensions
                mask_resized = cv2.resize(mask, (frame_width, frame_height), interpolation=cv2.INTER_NEAREST)
                # Get the class name
                class_name = model.names[int(cls)]
                if class_name == "sanchi":
                    bag_pos[track_id] = dict(mask=mask)
                elif class_name == "chepa":
                    fishes_pos[track_id] = mask
                # Use a static color for each class based on the class name
                color = class_colors.get(class_name, (255, 255, 255))  # Default to white if class not in color map
                color_mask = np.zeros_like(frame, dtype=np.uint8)
                color_mask[mask_resized > 128] = color  # Apply color where the mask is set
                # Blend the original frame with the color mask
                frame = cv2.addWeighted(frame, 1, color_mask, 0.5, 0)
                # Display the label and track ID on the frame
                label = f"{class_name}{track_id}"
                position = np.where(mask_resized > 128)
                if position[0].size > 0 and position[1].size > 0:
                    y, x = int(position[0][0]), int(position[1][0])  # Label position (cast to int for OpenCV)
                    if class_name != "sanchi":
                        cv2.putText(frame, label, (x, y), cv2.FONT_HERSHEY_SIMPLEX, 2, color, 3)
                    else:
                        bag_pos[track_id]["xy"] = (x, y)
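
        # A fish is counted into a bag when its mask overlaps that bag's mask;
        # once marked "in_sanchi", the fish is not checked again in later frames.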
        for fish_id, fish_mask in fishes_pos.items():
            if fish_id not in g_fishes:
                g_fishes[fish_id] = dict(in_sanchi=False)
            if not g_fishes[fish_id]["in_sanchi"]:
                for bag_id, bag_info in bag_pos.items():
                    bag_mask = bag_info["mask"]
                    if np.any(np.logical_and(fish_mask, bag_mask)):
                        if bag_id not in g_bags:
                            g_bags[bag_id] = 0
                        g_bags[bag_id] += 1
                        g_fishes[fish_id]["in_sanchi"] = True
                        print(g_bags)

        for bag_id, v in bag_pos.items():
            color = class_colors.get("sanchi", (255, 255, 255))
            label = f"{g_bags.get(bag_id, 0)}: sanchi{bag_id}"
            if "xy" in v:  # Label position is only set when the bag mask had visible pixels
                cv2.putText(frame, label, v["xy"], cv2.FONT_HERSHEY_SIMPLEX, 2, color, 3)
        # Draw the running count for each bag in the top-right corner
        for bag_id, v in g_bags.items():
            if v:
                color = (0, 0, 255)  # Red in BGR format
                label = f"BAG{bag_id}: {v}"
                # Get the size of the text box
                (text_width, text_height), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 2, 5)
                # Calculate the position of the background rectangle
                x1 = frame_width - text_width - text_margin
                y1 = text_margin + text_height + baseline
                x2 = frame_width - text_margin
                y2 = text_margin
                # Draw a filled black rectangle as the background
                cv2.rectangle(frame, (x1, y2), (x2, y1), (0, 0, 0), thickness=-1)
                # Optional: blend an overlay for a slight transparency effect
                overlay = frame.copy()
                cv2.addWeighted(overlay, 0.5, frame, 0.5, 0, frame)
                # Put the text on top of the rectangle
                cv2.putText(frame, label, (x1, y2 + 100), cv2.FONT_HERSHEY_SIMPLEX, 2, color, 3)
        return frame


# Gradio interface
iface = gr.Interface(
    fn=VideoProcessor().gradio_video_stream,
    inputs=gr.Video(label="Upload Video"),
    outputs=gr.Image(),
)
iface.launch()
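
# Note (assumption about the deployed Gradio version): because gradio_video_stream is a
# generator, Gradio streams each yielded frame to the Image output. On older Gradio 3.x
# releases, generator outputs may require enabling the queue, e.g. iface.queue().launch().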