File size: 2,182 Bytes
a2b0f6f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
# Ultralytics YOLO 🚀, AGPL-3.0 license
import numpy as np
import torch
from ultralytics.yolo.engine.predictor import BasePredictor
from ultralytics.yolo.engine.results import Results
from ultralytics.yolo.utils.torch_utils import select_device
from .modules.mask_generator import SamAutomaticMaskGenerator
class Predictor(BasePredictor):
    """Predictor that runs a SamAutomaticMaskGenerator behind the BasePredictor interface."""

    def preprocess(self, im):
        """Return the first element of the batch `im` unchanged.

        Args:
            im: Indexable batch of inputs; only `im[0]` is used.

        Returns:
            `im[0]`, with no resizing or tensor conversion applied.
        """
        # TODO: Only support bs=1 for now
        # im = ResizeLongestSide(1024).apply_image(im[0])
        # im = torch.as_tensor(im, device=self.device)
        # im = im.permute(2, 0, 1).contiguous()[None, :, :, :]
        return im[0]

    def setup_model(self, model):
        """Wrap `model` in a SamAutomaticMaskGenerator on the configured device.

        The device is selected from `self.args.device`; `self.args.conf` is passed
        as `pred_iou_thresh` and `self.args.iou` as `box_nms_thresh`.

        Args:
            model: Object exposing `eval()` and `to(device)`; it is switched to
                eval mode, moved to the device and handed to the mask generator.
        """
        device = select_device(self.args.device)
        model.eval()
        self.model = SamAutomaticMaskGenerator(model.to(device),
                                               pred_iou_thresh=self.args.conf,
                                               box_nms_thresh=self.args.iou)
        self.device = device
        # TODO: Temporary settings for compatibility
        # NOTE(review): presumably these are attributes the base predictor reads
        # from its model object -- confirm against BasePredictor.
        self.model.pt = False
        self.model.triton = False
        self.model.stride = 32
        self.model.fp16 = False
        self.done_warmup = True

    def postprocess(self, preds, path, orig_imgs):
        """Convert mask-generator output into a list of Results objects.

        Args:
            preds: Sequence of per-mask records for a single image; each record
                is indexed with the key 'segmentation' (assumed to hold a NumPy
                array, since the values are stacked with `np.stack`).
            path: Ignored. The image path is taken from `self.batch[0]` instead.
            orig_imgs: The original image, or a list of original images.

        Returns:
            A list with one Results entry. Its `masks` is a tensor stacking every
            'segmentation' along axis 0, or None when `preds` is empty.
        """
        # One class name per predicted mask: index -> index.
        names = {idx: idx for idx in range(len(preds))}
        # Read the path(s) from the current batch; kept in its own local so the
        # `path` parameter is no longer rebound inside the loop.
        batch_path = self.batch[0]
        results = []
        # TODO
        for i, pred in enumerate([preds]):
            segmentations = [p['segmentation'] for p in pred]
            # np.stack raises ValueError on an empty sequence, so an image for
            # which no masks were generated yields a Results without masks.
            masks = torch.from_numpy(np.stack(segmentations, axis=0)) if segmentations else None
            orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
            img_path = batch_path[i] if isinstance(batch_path, list) else batch_path
            results.append(Results(orig_img=orig_img, path=img_path, names=names, masks=masks))
        return results

    # def __call__(self, source=None, model=None, stream=False):
    #     frame = cv2.imread(source)
    #     preds = self.model.generate(frame)
    #     return self.postprocess(preds, source, frame)
|