itsanurag commited on
Commit
a6ded40
·
verified ·
1 Parent(s): 202a5c4

Create render.py

Browse files
Files changed (1) hide show
  1. render.py +63 -0
render.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from sahi.utils.cv import read_image_as_pil,get_bool_mask_from_coco_segmentation
4
+ from sahi.prediction import ObjectPrediction, PredictionScore,visualize_object_predictions
5
+ from PIL import Image
6
+ def custom_render_result(model,image, result,rect_th=2,text_th=2):
7
+ if model.overrides["task"] not in ["detect", "segment"]:
8
+ raise ValueError(
9
+ f"Model task must be either 'detect' or 'segment'. Got {model.overrides['task']}"
10
+ )
11
+
12
+ image = read_image_as_pil(image)
13
+ np_image = np.ascontiguousarray(image)
14
+
15
+ names = model.model.names
16
+
17
+ masks = result.masks
18
+ boxes = result.boxes
19
+
20
+ object_predictions = []
21
+ if boxes is not None:
22
+ det_ind = 0
23
+ for xyxy, conf, cls in zip(boxes.xyxy, boxes.conf, boxes.cls):
24
+ if masks:
25
+ img_height = np_image.shape[0]
26
+ img_width = np_image.shape[1]
27
+ segments = masks.segments
28
+ segments = segments[det_ind] # segments: np.array([[x1, y1], [x2, y2]])
29
+ # convert segments into full shape
30
+ segments[:, 0] = segments[:, 0] * img_width
31
+ segments[:, 1] = segments[:, 1] * img_height
32
+ segmentation = [segments.ravel().tolist()]
33
+
34
+ bool_mask = get_bool_mask_from_coco_segmentation(
35
+ segmentation, width=img_width, height=img_height
36
+ )
37
+ if sum(sum(bool_mask == 1)) <= 2:
38
+ continue
39
+ object_prediction = ObjectPrediction.from_coco_segmentation(
40
+ segmentation=segmentation,
41
+ category_name=names[int(cls)],
42
+ category_id=int(cls),
43
+ full_shape=[img_height, img_width],
44
+ )
45
+ object_prediction.score = PredictionScore(value=conf)
46
+ else:
47
+ object_prediction = ObjectPrediction(
48
+ bbox=xyxy.tolist(),
49
+ category_name=names[int(cls)],
50
+ category_id=int(cls),
51
+ score=conf,
52
+ )
53
+ object_predictions.append(object_prediction)
54
+ det_ind += 1
55
+
56
+ result = visualize_object_predictions(
57
+ image=np_image,
58
+ object_prediction_list=object_predictions,
59
+ rect_th=rect_th,
60
+ text_th=text_th,
61
+ )
62
+
63
+ return Image.fromarray(result["image"])