kiiwee committed on
Commit 79183dd • 1 Parent(s): 0e05626

Upload 5 files

Files changed (5)
  1. app.py +26 -0
  2. detectron_utils.py +130 -0
  3. requirements.txt +12 -0
  4. sam_utils.py +224 -0
  5. yolo_utils.py +18 -0
app.py ADDED
@@ -0,0 +1,26 @@
+ import gradio as gr
+ import json
+ import numpy as np
+ from sam_utils import grounded_segmentation, create_yellow_background_with_insects
+ from yolo_utils import yolo_processimage
+ from detectron_utils import detectron_process_image
+ def process_image(image, include_json):
+     detectron_result = detectron_process_image(image)
+     yolo_result = yolo_processimage(image)
+     insectsam_result = create_yellow_background_with_insects(image)
+
+     return insectsam_result, yolo_result, detectron_result
+
+ examples = [
+     ["demo.jpg"]
+ ]
+
+ gr.Interface(
+     fn=process_image,
+     inputs=[gr.Image(type="pil"), gr.Checkbox(label="Include JSON", value=False)],
+     outputs=[gr.Image(label='InsectSAM', type="numpy"),
+              gr.Image(label='Yolov8', type="numpy"),
+              gr.Image(label='Detectron', type="numpy")],
+     title="RB-IBDM Model Zoo Demo 🐞",
+     examples=examples
+ ).launch()
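app.py imports json and exposes an "Include JSON" checkbox, but neither is used by process_image yet; detections_to_json in sam_utils.py is the natural counterpart. A minimal sketch of how the two could be wired together (not part of this commit; process_image_with_json and the JSON output are hypothetical):

import json
from sam_utils import grounded_segmentation, detections_to_json

def process_image_with_json(image, include_json):
    # Run the grounded segmentation once to obtain structured detections
    image_array, detections = grounded_segmentation(image, labels=["insect"], polygon_refinement=True)
    # Serialize scores, boxes and RLE masks only when the checkbox is ticked
    json_text = json.dumps(detections_to_json(detections), indent=2) if include_json else ""
    return image_array, json_text

In the existing Interface this would also need an extra output component, e.g. a gr.JSON or gr.Textbox.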
detectron_utils.py ADDED
@@ -0,0 +1,130 @@
+ import torch
+ import numpy as np
+ import cv2
+ from huggingface_hub import hf_hub_download
+
+ REPO_ID = "kiiwee/Detectron2_FasterRCNN_InsectDetect"
+ FILENAME = "model.pth"
+ FILENAME_CONFIG = "config.yml"
+
+
+ # Ensure you have the model file
+
+ from detectron2.config import get_cfg
+ from detectron2.engine import DefaultPredictor
+ from detectron2.data import MetadataCatalog
+ from detectron2.utils.visualizer import Visualizer, ColorMode
+ import matplotlib.pyplot as plt
+
+
+ viz_classes = {'thing_classes': ['Acrididae',
+                                  'Agapeta',
+                                  'Agapeta hamana',
+                                  'Animalia',
+                                  'Anisopodidae',
+                                  'Aphididae',
+                                  'Apidae',
+                                  'Arachnida',
+                                  'Araneae',
+                                  'Arctiidae',
+                                  'Auchenorrhyncha indet.',
+                                  'Baetidae',
+                                  'Cabera',
+                                  'Caenidae',
+                                  'Carabidae',
+                                  'Cecidomyiidae',
+                                  'Ceratopogonidae',
+                                  'Cercopidae',
+                                  'Chironomidae',
+                                  'Chrysomelidae',
+                                  'Chrysopidae',
+                                  'Chrysoteuchia culmella',
+                                  'Cicadellidae',
+                                  'Coccinellidae',
+                                  'Coleophoridae',
+                                  'Coleoptera',
+                                  'Collembola',
+                                  'Corixidae',
+                                  'Crambidae',
+                                  'Culicidae',
+                                  'Curculionidae',
+                                  'Dermaptera',
+                                  'Diptera',
+                                  'Eilema',
+                                  'Empididae',
+                                  'Ephemeroptera',
+                                  'Erebidae',
+                                  'Fanniidae',
+                                  'Formicidae',
+                                  'Gastropoda',
+                                  'Gelechiidae',
+                                  'Geometridae',
+                                  'Hemiptera',
+                                  'Hydroptilidae',
+                                  'Hymenoptera',
+                                  'Ichneumonidae',
+                                  'Idaea',
+                                  'Insecta',
+                                  'Lepidoptera',
+                                  'Leptoceridae',
+                                  'Limoniidae',
+                                  'Lomaspilis marginata',
+                                  'Miridae',
+                                  'Mycetophilidae',
+                                  'Nepticulidae',
+                                  'Neuroptera',
+                                  'Noctuidae',
+                                  'Notodontidae',
+                                  'Object',
+                                  'Opiliones',
+                                  'Orthoptera',
+                                  'Panorpa germanica',
+                                  'Panorpa vulgaris',
+                                  'Parasitica indet.',
+                                  'Plutellidae',
+                                  'Psocodea',
+                                  'Psychodidae',
+                                  'Pterophoridae',
+                                  'Pyralidae',
+                                  'Pyrausta',
+                                  'Sepsidae',
+                                  'Spilosoma',
+                                  'Staphylinidae',
+                                  'Stratiomyidae',
+                                  'Syrphidae',
+                                  'Tettigoniidae',
+                                  'Tipulidae',
+                                  'Tomoceridae',
+                                  'Tortricidae',
+                                  'Trichoptera',
+                                  'Triodia sylvina',
+                                  'Yponomeuta',
+                                  'Yponomeutidae']}
+
+
+
+ def detectron_process_image(image):
+     cfg = get_cfg()
+
+     # Config and weights come from the Hub; hf_hub_download caches them locally
+     cfg.merge_from_file(hf_hub_download(repo_id=REPO_ID, filename=FILENAME_CONFIG))
+     cfg.MODEL.WEIGHTS = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
+     cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.2
+     cfg.MODEL.DEVICE = 'cpu'
+     predictor = DefaultPredictor(cfg)
+
+     numpy_image = np.array(image)
+
+     # Detectron2's default predictor expects BGR input
+     im = cv2.cvtColor(numpy_image, cv2.COLOR_RGB2BGR)
+
+     v = Visualizer(im[:, :, ::-1],
+                    viz_classes,
+                    scale=0.5)
+     outputs = predictor(im)
+     out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
+     results = out.get_image()[:, :, ::-1]
+     rgb_image = cv2.cvtColor(results, cv2.COLOR_BGR2RGB)
+
+     return rgb_image
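A quick way to exercise detectron_utils.py outside the Gradio app, assuming demo.jpg (the example image referenced in app.py) is available locally; the config and weights are fetched from kiiwee/Detectron2_FasterRCNN_InsectDetect on first use:

from PIL import Image
import matplotlib.pyplot as plt
from detectron_utils import detectron_process_image

annotated = detectron_process_image(Image.open("demo.jpg"))  # RGB array with drawn predictions
plt.imshow(annotated)
plt.axis("off")
plt.show()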
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ gradio==4.29.0
+ torch
+ transformers
+ opencv-python
+ Pillow
+ numpy
+ requests
+ matplotlib
+ ultralytics
+ onnxruntime
+ efficientnet
+ detectron2 @ git+https://github.com/facebookresearch/detectron2.git
sam_utils.py ADDED
@@ -0,0 +1,224 @@
+ import os
+
+ import random
+ from dataclasses import dataclass
+ from typing import Any, List, Dict, Optional, Union, Tuple
+ import cv2
+ import torch
+ import requests
+ import numpy as np
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
+ import gradio as gr
+ import json
+
+
+ @dataclass
+ class BoundingBox:
+     xmin: int
+     ymin: int
+     xmax: int
+     ymax: int
+
+     @property
+     def xyxy(self) -> List[float]:
+         return [self.xmin, self.ymin, self.xmax, self.ymax]
+ @dataclass
+ class DetectionResult:
+     score: float
+     label: str
+     box: BoundingBox
+     mask: Optional[np.ndarray] = None
+
+     @classmethod
+     def from_dict(cls, detection_dict: Dict) -> 'DetectionResult':
+         return cls(
+             score=detection_dict['score'],
+             label=detection_dict['label'],
+             box=BoundingBox(
+                 xmin=detection_dict['box']['xmin'],
+                 ymin=detection_dict['box']['ymin'],
+                 xmax=detection_dict['box']['xmax'],
+                 ymax=detection_dict['box']['ymax']
+             )
+         )
+
+ def annotate(image: Union[Image.Image, np.ndarray], detection_results: List[DetectionResult], include_bboxes: bool = True) -> np.ndarray:
+     image_cv2 = np.array(image) if isinstance(image, Image.Image) else image
+     image_cv2 = cv2.cvtColor(image_cv2, cv2.COLOR_RGB2BGR)
+
+     for detection in detection_results:
+         label = detection.label
+         score = detection.score
+         box = detection.box
+         mask = detection.mask
+
+         if include_bboxes:
+             color = np.random.randint(0, 256, size=3).tolist()
+             cv2.rectangle(image_cv2, (box.xmin, box.ymin),
+                           (box.xmax, box.ymax), color, 2)
+             cv2.putText(image_cv2, f'{label}: {score:.2f}', (box.xmin, box.ymin - 10),
+                         cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
+
+     return cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
+
+
+ def plot_detections(image: Union[Image.Image, np.ndarray], detections: List[DetectionResult], include_bboxes: bool = True) -> np.ndarray:
+     annotated_image = annotate(image, detections, include_bboxes)
+     return annotated_image
+
+
+ def load_image(image: Union[str, Image.Image]) -> Image.Image:
+     if isinstance(image, str) and image.startswith("http"):
+         image = Image.open(requests.get(image, stream=True).raw).convert("RGB")
+     elif isinstance(image, str):
+         image = Image.open(image).convert("RGB")
+     else:
+         image = image.convert("RGB")
+     return image
+
+
+ def get_boxes(detection_results: List[DetectionResult]) -> List[List[List[float]]]:
+     boxes = []
+     for result in detection_results:
+         xyxy = result.box.xyxy
+         boxes.append(xyxy)
+     return [boxes]
+
+
+ def mask_to_polygon(mask: np.ndarray) -> np.ndarray:
+     contours, _ = cv2.findContours(
+         mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+     if len(contours) == 0:
+         return np.array([])
+     largest_contour = max(contours, key=cv2.contourArea)
+     return largest_contour
+
+
+ def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
+     masks = masks.cpu().float().permute(0, 2, 3, 1).mean(
+         axis=-1).numpy().astype(np.uint8)
+     masks = (masks > 0).astype(np.uint8)
+     if polygon_refinement:
+         for idx, mask in enumerate(masks):
+             shape = mask.shape
+             polygon = mask_to_polygon(mask)
+             masks[idx] = cv2.fillPoly(
+                 np.zeros(shape, dtype=np.uint8), [polygon], 1)
+     return list(masks)
+
+
+ def detect(image: Image.Image, labels: List[str], threshold: float = 0.3, detector_id: Optional[str] = None) -> List[DetectionResult]:
+     detector_id = detector_id if detector_id else "IDEA-Research/grounding-dino-base"
+     object_detector = pipeline(
+         model=detector_id, task="zero-shot-object-detection", device="cpu")
+     labels = [label if label.endswith(".") else label + "." for label in labels]
+     results = object_detector(
+         image, candidate_labels=labels, threshold=threshold)
+     return [DetectionResult.from_dict(result) for result in results]
+
+
+ def segment(image: Image.Image, detection_results: List[DetectionResult], polygon_refinement: bool = False, segmenter_id: Optional[str] = None) -> List[DetectionResult]:
+     segmenter_id = segmenter_id if segmenter_id else "martintmv/InsectSAM"
+     segmentator = AutoModelForMaskGeneration.from_pretrained(
+         segmenter_id).to("cpu")
+     processor = AutoProcessor.from_pretrained(segmenter_id)
+     boxes = get_boxes(detection_results)
+     inputs = processor(images=image, input_boxes=boxes,
+                        return_tensors="pt").to("cpu")
+     outputs = segmentator(**inputs)
+     masks = processor.post_process_masks(
+         masks=outputs.pred_masks, original_sizes=inputs.original_sizes, reshaped_input_sizes=inputs.reshaped_input_sizes)[0]
+     masks = refine_masks(masks, polygon_refinement)
+     for detection_result, mask in zip(detection_results, masks):
+         detection_result.mask = mask
+     return detection_results
+
+
+ def grounded_segmentation(image: Union[Image.Image, str], labels: List[str], threshold: float = 0.3, polygon_refinement: bool = False, detector_id: Optional[str] = None, segmenter_id: Optional[str] = None) -> Tuple[np.ndarray, List[DetectionResult]]:
+     image = load_image(image)
+     detections = detect(image, labels, threshold, detector_id)
+     detections = segment(image, detections, polygon_refinement, segmenter_id)
+     return np.array(image), detections
+
+
+ def mask_to_min_max(mask: np.ndarray) -> Tuple[int, int, int, int]:
+     y, x = np.where(mask)
+     return x.min(), y.min(), x.max(), y.max()
+
+
+ def extract_and_paste_insect(original_image: np.ndarray, detection: DetectionResult, background: np.ndarray) -> None:
+     mask = detection.mask
+     xmin, ymin, xmax, ymax = mask_to_min_max(mask)
+     insect_crop = original_image[ymin:ymax, xmin:xmax]
+     mask_crop = mask[ymin:ymax, xmin:xmax]
+
+     insect = cv2.bitwise_and(insect_crop, insect_crop, mask=mask_crop)
+
+     x_offset, y_offset = xmin, ymin
+     x_end, y_end = x_offset + insect.shape[1], y_offset + insect.shape[0]
+
+     insect_area = background[y_offset:y_end, x_offset:x_end]
+     insect_area[mask_crop == 1] = insect[mask_crop == 1]
+
+
+ def create_yellow_background_with_insects(image: Image.Image) -> np.ndarray:
+     labels = ["insect"]
+
+     original_image, detections = grounded_segmentation(
+         image, labels, threshold=0.3, polygon_refinement=True)
+
+     yellow_background = np.full(
+         (original_image.shape[0], original_image.shape[1], 3), (0, 255, 255), dtype=np.uint8)  # BGR for yellow
+     for detection in detections:
+         if detection.mask is not None:
+             extract_and_paste_insect(
+                 original_image, detection, yellow_background)
+     # Convert back to RGB to match Gradio's expected input format
+     yellow_background = cv2.cvtColor(yellow_background, cv2.COLOR_BGR2RGB)
+     return yellow_background
+
+
+ def run_length_encoding(mask):
+     pixels = mask.flatten()
+     rle = []
+     last_val = 0
+     count = 0
+     for pixel in pixels:
+         if pixel == last_val:
+             count += 1
+         else:
+             if count > 0:
+                 rle.append(count)
+             count = 1
+             last_val = pixel
+     if count > 0:
+         rle.append(count)
+     return rle
+
+
+ def detections_to_json(detections):
+     detections_list = []
+     for detection in detections:
+         detection_dict = {
+             "score": detection.score,
+             "label": detection.label,
+             "box": {
+                 "xmin": detection.box.xmin,
+                 "ymin": detection.box.ymin,
+                 "xmax": detection.box.xmax,
+                 "ymax": detection.box.ymax
+             },
+             "mask": run_length_encoding(detection.mask) if detection.mask is not None else None
+         }
+         detections_list.append(detection_dict)
+     return detections_list
+
+
+ def crop_bounding_boxes_with_yellow_background(image: np.ndarray, yellow_background: np.ndarray, detections: List[DetectionResult]) -> List[np.ndarray]:
+     crops = []
+     for detection in detections:
+         xmin, ymin, xmax, ymax = detection.box.xyxy
+         crop = yellow_background[ymin:ymax, xmin:xmax]
+         crops.append(crop)
+     return crops
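The same pipeline can be driven directly for offline use, assuming demo.jpg is present: Grounding DINO proposes the boxes, InsectSAM refines them into masks, and detections_to_json serializes the result with run-length-encoded masks. A minimal sketch:

import json
from sam_utils import grounded_segmentation, plot_detections, detections_to_json

image_array, detections = grounded_segmentation(
    "demo.jpg", labels=["insect"], threshold=0.3, polygon_refinement=True)
annotated = plot_detections(image_array, detections)  # boxes and scores drawn on an RGB copy
with open("detections.json", "w") as f:
    json.dump(detections_to_json(detections), f, indent=2)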
yolo_utils.py ADDED
@@ -0,0 +1,18 @@
+ from ultralytics import YOLO
+ import torch
+ import numpy as np
+ import cv2
+ from huggingface_hub import hf_hub_download
+
+ REPO_ID = "kiiwee/Yolov8_InsectDetect"
+ FILENAME = "insectYolo.pt"
+
+
+ # Ensure you have the model file
+ model = YOLO(hf_hub_download(repo_id=REPO_ID, filename=FILENAME))
+ def yolo_processimage(image):
+     # device='mps' targets Apple-silicon GPUs; use 'cpu' or a CUDA index elsewhere
+     results = model(source=image, show=True, save=True,
+                     conf=0.2, device='mps', save_crop=True)
+     rgb_image = cv2.cvtColor(results[0].plot(), cv2.COLOR_BGR2RGB)
+     return rgb_image
+
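A minimal local check of the YOLOv8 wrapper, assuming demo.jpg is available; note that yolo_processimage as committed requests device='mps', so it only runs unchanged on Apple-silicon hardware, and the weights are pulled from kiiwee/Yolov8_InsectDetect on first use:

from PIL import Image
from yolo_utils import yolo_processimage

annotated = yolo_processimage(Image.open("demo.jpg"))  # RGB array with YOLO detections plotted
print(annotated.shape)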