# Created by yarramsettinaresh GORAKA DIGITAL PRIVATE LIMITED at 01/11/24
import cv2
import numpy as np
from ultralytics import YOLO
import torch
import gradio as gr
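
# Gradio demo that counts fish ("chepa") placed into bags ("sanchi") in a video,
# using an Ultralytics YOLO segmentation model with object tracking.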

# Load the YOLO segmentation/tracking model
model = YOLO("chepala_lekka_v5_yolov11n.pt")

# Module-level drawing settings (read by VideoProcessor.process_frame)
text_margin = 10
class_colors = {
    "chepa": (255, 0, 0),   # BGR blue for "chepa" (fish)
    "sanchi": (0, 255, 0),  # BGR green for "sanchi" (bag)
    "other_class": (0, 0, 255)  # BGR red fallback for other classes
}

class VideoProcessor:
    def __init__(self):
        self.g_bags = {}    # bag track_id -> number of fish counted into that bag
        self.g_fishes = {}  # fish track_id -> {"in_sanchi": True/False}
        # Per-instance copies of the drawing settings (process_frame currently
        # reads the module-level text_margin / class_colors above)
        self.text_margin = 10
        self.class_colors = {
            "chepa": (255, 0, 0),   # BGR blue for "chepa" (fish)
            "sanchi": (0, 255, 0),  # BGR green for "sanchi" (bag)
            "other_class": (0, 0, 255)  # BGR red fallback for other classes
        }
        self.cap = None

    def gradio_video_stream(self, video_file):
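        """Generator used by Gradio: reset the counters, open the uploaded
        video, and yield annotated RGB frames until the video ends."""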
        print(f"gradio_video_stream init : {video_file}")
        self.g_bags = {}
        self.g_fishes = {}
        self.cap = cv2.VideoCapture(video_file)  # Open the uploaded video file

        while True:
            frame = self.process_frame()
            if frame is None:
                break
            yield cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB for Gradio

    def process_frame(self):
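        """Read one frame, run segmentation + tracking, draw masks and labels,
        update the per-bag fish counts, and return the annotated BGR frame
        (or None when the video is finished)."""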
        g_fishes = self.g_fishes
        g_bags = self.g_bags
        cap = self.cap
        ret, frame = cap.read()
        if not ret:
            cap.release()  # Release the video capture if no frame is captured
            return None  # Return None if no frame is captured

        frame_height, frame_width = frame.shape[:2]

        # Run detection + tracking on the frame; persist=True keeps track IDs
        # stable across consecutive frames
        results = model.track(frame, persist=True)
        bag_pos = dict()
        fishes_pos = dict()
        if results[0].boxes.id is not None and results[0].masks is not None:
            masks = results[0].masks.data
            track_ids = results[0].boxes.id.int().cpu().tolist()
            classes = results[0].boxes.cls  # Class labels
            confidences = results[0].boxes.conf
            boxes = results[0].boxes.xyxy

            for mask, track_id, cls, conf, box in zip(masks, track_ids, classes, confidences, boxes):
                # Convert the mask tensor to a NumPy array if needed
                if isinstance(mask, torch.Tensor):
                    mask = mask.cpu().numpy()
                mask = (mask * 255).astype(np.uint8)  # Binarise the mask to 0/255

                # Resize mask to match the original frame dimensions
                mask_resized = cv2.resize(mask, (frame_width, frame_height), interpolation=cv2.INTER_NEAREST)

                # Get the class name
                class_name = model.names[int(cls)]
                if class_name == "sanchi":
                    bag_pos[track_id] = dict(mask=mask)
                elif class_name == "chepa":
                    fishes_pos[track_id] = mask
                # Use static color for each class based on the class name
                color = class_colors.get(class_name, (255, 255, 255))  # Default to white if class not in color map
                color_mask = np.zeros_like(frame, dtype=np.uint8)
                color_mask[mask_resized > 128] = color  # Apply color where mask is

                # Blend original frame with color mask
                frame = cv2.addWeighted(frame, 1, color_mask, 0.5, 0)

                # Draw the class name and track id near the top of the mask
                label = f"{class_name}{track_id}"
                position = np.where(mask_resized > 128)
                if position[0].size > 0 and position[1].size > 0:
                    # Use the first (top-most) mask pixel as the label anchor
                    y, x = int(position[0][0]), int(position[1][0])
                    if class_name != "sanchi":
                        cv2.putText(frame, label, (x, y), cv2.FONT_HERSHEY_SIMPLEX, 2, color, 3)
                    else:
                        # For bags, defer drawing until the count label is known
                        bag_pos[track_id]["xy"] = (x, y)
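        # Counting rule: a fish is credited to a bag the first time its mask
        # overlaps that bag's mask (both masks are at the model's mask
        # resolution, so they are directly comparable); the in_sanchi flag
        # ensures each tracked fish is counted at most once.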
        for fish_id, fish_mask in fishes_pos.items():
            if fish_id not in g_fishes:
                g_fishes[fish_id] = dict(in_sanchi=False)
            if not g_fishes[fish_id]["in_sanchi"]:
                for bag_id, bag_info in bag_pos.items():
                    bag_mask = bag_info["mask"]
                    if np.any(np.logical_and(fish_mask, bag_mask)):
                        if bag_id not in g_bags:
                            g_bags[bag_id] = 0
                        g_bags[bag_id] += 1
                        g_fishes[fish_id]["in_sanchi"] = True
        print(g_bags)  # Debug: current per-bag counts
        for bag_id, v in bag_pos.items():
            # Skip bags whose mask produced no visible pixels (no label anchor)
            if "xy" not in v:
                continue
            color = class_colors.get("sanchi", (255, 255, 255))
            label = f"{g_bags.get(bag_id, 0)}: sanchi{bag_id}"
            cv2.putText(frame, label, v["xy"], cv2.FONT_HERSHEY_SIMPLEX, 2, color, 3)
        # Draw a running summary (fish count per bag) in the top-right corner
        for bag_id, v in g_bags.items():
            if v:
                color = (0, 0, 255)  # Red text (BGR)
                label = f"BAG{bag_id}: {v}"

                # Measure the text so the background box can be sized to fit it
                (text_width, text_height), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 2, 5)

                # Background rectangle anchored to the top-right of the frame
                x1 = frame_width - text_width - text_margin
                y1 = text_margin + text_height + baseline
                x2 = frame_width - text_margin
                y2 = text_margin

                # Draw the rectangle on a copy of the frame, then blend it back
                # so the text background is semi-transparent
                overlay = frame.copy()
                cv2.rectangle(overlay, (x1, y2), (x2, y1), (0, 0, 0), thickness=-1)
                cv2.addWeighted(overlay, 0.5, frame, 0.5, 0, frame)

                # Put the count text on top of the rectangle
                cv2.putText(frame, label, (x1, y1 - baseline), cv2.FONT_HERSHEY_SIMPLEX, 2, color, 3)

        return frame

# Gradio interface: gradio_video_stream is a generator, so Gradio can stream
# each yielded frame to the Image output while the video is being processed
iface = gr.Interface(fn=VideoProcessor().gradio_video_stream,
                     inputs=gr.Video(label="Upload Video"),
                     outputs=gr.Image(),
                     )

iface.launch()
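# launch() serves the app locally by default; launch(share=True) can be used
# to get a temporary public link.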