Spaces:
Runtime error
Runtime error
import os | |
import time | |
from loguru import logger | |
import cv2 | |
import torch | |
from yolox.data.data_augment import ValTransform | |
from yolox.data.datasets import COCO_CLASSES | |
from yolox.utils import postprocess, vis | |
class Predictor(object):
    """Run a detection model on a single image and draw the results.

    The instance keeps the model, the class-name list, and the current
    confidence / NMS thresholds and test size. ``inference`` produces the
    post-processed detections plus an ``img_info`` dict; ``visual`` turns
    one detection tensor and that dict into an annotated image.
    """

    def __init__(
        self,
        model,
        cls_names=COCO_CLASSES,
        device="cpu",
        fp16=False,
        legacy=False,
    ):
        """Store the model and default inference settings.

        Args:
            model: Callable network; invoked as ``model(img)`` in ``inference``.
            cls_names: Sequence of class names; its length is used as the
                number of classes passed to ``postprocess``.
            device: ``"gpu"`` moves the input tensor to CUDA; any other value
                leaves it where ``torch.from_numpy`` put it.
            fp16: When true, the input tensor is converted to half precision.
            legacy: Forwarded to ``ValTransform``.
        """
        self.model = model
        self.cls_names = cls_names
        # Derive the class count from the names actually supplied, so a
        # custom ``cls_names`` list stays consistent with ``postprocess``.
        self.num_classes = len(cls_names)
        self.confthre = 0.01
        self.nmsthre = 0.01
        self.test_size = (640, 640)
        self.device = device
        self.fp16 = fp16
        self.preproc = ValTransform(legacy=legacy)

    def inference(self, img, confthre=None, nmsthre=None, test_size=None):
        """Run the model on one image.

        Args:
            img: Either a path string (loaded with ``cv2.imread``) or an
                already-decoded image array.
            confthre: Optional new confidence threshold.
            nmsthre: Optional new NMS threshold.
            test_size: Optional new ``(height, width)`` test size.

        Returns:
            ``(outputs, img_info)`` where ``outputs`` is whatever
            ``postprocess`` returns and ``img_info`` holds ``id``,
            ``file_name``, ``height``, ``width``, ``raw_img`` and ``ratio``.

        Raises:
            ValueError: If ``img`` is a path that ``cv2.imread`` cannot read.
        """
        # Overrides are stored on the instance on purpose: ``visual`` reads
        # ``self.confthre`` afterwards, so they persist across calls.
        if confthre is not None:
            self.confthre = confthre
        if nmsthre is not None:
            self.nmsthre = nmsthre
        if test_size is not None:
            self.test_size = test_size

        img_info = {"id": 0}
        if isinstance(img, str):
            img_path = img
            img_info["file_name"] = os.path.basename(img_path)
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread signals failure by returning None; fail here
                # with a clear message instead of on ``img.shape`` below.
                raise ValueError("could not read image file: {!r}".format(img_path))
        else:
            img_info["file_name"] = None
            # NOTE(review): writes the incoming frame to ./test.png on every
            # call; looks like a debug dump. Kept because code outside this
            # class may read the file -- confirm and remove if unused.
            cv2.imwrite("test.png", img)

        height, width = img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = img

        # Scale factor between the original image and the test size; used by
        # ``visual`` to map boxes back to original-image coordinates.
        ratio = min(self.test_size[0] / height, self.test_size[1] / width)
        img_info["ratio"] = ratio

        img, _ = self.preproc(img, None, self.test_size)
        img = torch.from_numpy(img).unsqueeze(0)  # add a leading batch dimension
        img = img.float()
        if self.device == "gpu":
            img = img.cuda()
            if self.fp16:
                img = img.half()  # to FP16

        with torch.no_grad():
            outputs = self.model(img)
            outputs = postprocess(
                outputs, self.num_classes, self.confthre, self.nmsthre
            )
        return outputs, img_info

    def visual(self, output, img_info):
        """Draw one image's detections onto its raw image.

        Args:
            output: Detection tensor for a single image, or ``None`` when
                there are no detections. Columns 0-3 are used as the box,
                columns 4 and 5 are multiplied to form the score, and
                column 6 is used as the class index (presumably objectness
                and class confidence for 4/5 -- confirm against
                ``postprocess``).
            img_info: The dict returned by ``inference``.

        Returns:
            The raw image unchanged when ``output`` is ``None``; otherwise
            the result of ``vis``.
        """
        ratio = img_info["ratio"]
        img = img_info["raw_img"]
        if output is None:
            return img
        output = output.cpu()

        # Undo the preprocessing resize. Use an out-of-place division:
        # ``.cpu()`` returns the same tensor when it is already on the CPU,
        # so an in-place ``/=`` on the slice would rescale the caller's
        # tensor and compound on repeated calls.
        bboxes = output[:, 0:4] / ratio

        cls = output[:, 6]
        scores = output[:, 4] * output[:, 5]

        vis_res = vis(img, bboxes, scores, cls, self.confthre, self.cls_names)
        return vis_res