RashiAgarwal
commited on
Commit
·
165b257
1
Parent(s):
7b3d454
Update display.py
Browse files- display.py +2 -8
display.py
CHANGED
@@ -6,7 +6,7 @@ import random
|
|
6 |
from albumentations.pytorch import ToTensorV2
|
7 |
from yolov3 import YOLOV3_PL
|
8 |
|
9 |
-
def inference(image: np.ndarray, iou_thresh: float = 0.75, thresh: float = 0.75,
|
10 |
model = YOLOV3_PL()
|
11 |
model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')), strict=False)
|
12 |
|
@@ -39,14 +39,8 @@ def inference(image: np.ndarray, iou_thresh: float = 0.75, thresh: float = 0.75,
|
|
39 |
bboxes[0], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
|
40 |
)
|
41 |
plot_img = draw_predictions(image, nms_boxes, class_labels=config.PASCAL_CLASSES)
|
42 |
-
|
43 |
-
return [plot_img]
|
44 |
|
45 |
-
grayscale_cam = cam(transformed_image, scaled_anchors)[0, :, :]
|
46 |
-
img = cv2.resize(image, (416, 416))
|
47 |
-
img = np.float32(img) / 255
|
48 |
-
cam_image = show_cam_on_image(img, grayscale_cam, use_rgb=True, image_weight=transparency)
|
49 |
-
return [plot_img, cam_image]
|
50 |
|
51 |
|
52 |
def draw_predictions(image: np.ndarray, boxes: list[list], class_labels: list[str]) -> np.ndarray:
|
|
|
6 |
from albumentations.pytorch import ToTensorV2
|
7 |
from yolov3 import YOLOV3_PL
|
8 |
|
9 |
+
def inference(image: np.ndarray, iou_thresh: float = 0.75, thresh: float = 0.75, transparency: float = 0.5):
|
10 |
model = YOLOV3_PL()
|
11 |
model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')), strict=False)
|
12 |
|
|
|
39 |
bboxes[0], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
|
40 |
)
|
41 |
plot_img = draw_predictions(image, nms_boxes, class_labels=config.PASCAL_CLASSES)
|
42 |
+
return [plot_img]
|
|
|
43 |
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
|
46 |
def draw_predictions(image: np.ndarray, boxes: list[list], class_labels: list[str]) -> np.ndarray:
|