# sapiens-pose / detector_utils.py
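"""Utilities for running an MMDetection detector and post-processing its
bounding-box predictions (greedy NMS, config adaptation, batched inference)."""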
from typing import List, Optional, Sequence, Union
import torch
import cv2
import numpy as np
from mmcv.ops import RoIPool
from mmengine.dataset import Compose, pseudo_collate
from mmengine.device import get_device
from mmengine.registry import init_default_scope
from mmdet.apis import init_detector
from mmdet.structures import DetDataSample, SampleList
from mmdet.utils import get_test_pipeline_cfg
ImagesType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
def nms(dets: np.ndarray, thr: float):
"""Greedily select boxes with high confidence and overlap <= thr.
Args:
dets (np.ndarray): [[x1, y1, x2, y2, score]].
thr (float): Retain overlap < thr.
Returns:
list: Indexes to keep.
"""
if len(dets) == 0:
return []
x1 = dets[:, 0]
y1 = dets[:, 1]
x2 = dets[:, 2]
y2 = dets[:, 3]
scores = dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
keep = []
while len(order) > 0:
i = order[0]
keep.append(i)
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(ovr <= thr)[0]
order = order[inds + 1]
return keep
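
# A quick sketch of how `nms` behaves on toy detections (the coordinates and
# scores below are made up purely for illustration):
#
#   dets = np.array([
#       [10., 10., 50., 50., 0.9],
#       [12., 12., 52., 52., 0.8],      # heavily overlaps the first box
#       [100., 100., 150., 150., 0.7],
#   ])
#   keep = nms(dets, thr=0.5)   # keeps indices 0 and 2; the second box is suppressed
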
def adapt_mmdet_pipeline(cfg):
"""Converts pipeline types in MMDetection's test dataloader to use the
'mmdet' namespace.
Args:
cfg (ConfigDict): Configuration dictionary for MMDetection.
Returns:
ConfigDict: Configuration dictionary with updated pipeline types.
"""
# use lazy import to avoid hard dependence on mmdet
from mmdet.datasets import transforms
if 'test_dataloader' not in cfg:
return cfg
pipeline = cfg.test_dataloader.dataset.pipeline
for trans in pipeline:
if trans['type'] in dir(transforms):
trans['type'] = 'mmdet.' + trans['type']
return cfg
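
# Typical usage (a sketch; the config and checkpoint paths below are
# placeholders, not files shipped with this repo):
#
#   from mmengine.config import Config
#   cfg = adapt_mmdet_pipeline(Config.fromfile('path/to/detector_config.py'))
#   detector = init_detector(cfg, 'path/to/detector_checkpoint.pth', device='cuda:0')
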
def inference_detector(
model: torch.nn.Module,
imgs: ImagesType,
test_pipeline: Optional[Compose] = None,
text_prompt: Optional[str] = None,
custom_entities: bool = False,
) -> Union[DetDataSample, SampleList]:
"""Inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str, ndarray, Sequence[str/ndarray]):
            Either image files or loaded images.
        test_pipeline (:obj:`Compose`, optional): Test pipeline. If not
            given, it is built from the model config.
        text_prompt (str, optional): Text prompt forwarded to text-conditioned
            (e.g. open-vocabulary) detectors. Defaults to None.
        custom_entities (bool): Whether the text prompt describes custom
            entities. Defaults to False.

    Returns:
        :obj:`DetDataSample` or list[:obj:`DetDataSample`]:
            If imgs is a list or tuple, a list of results with the same
            length is returned; otherwise the single detection result is
            returned directly.
    """
if isinstance(imgs, torch.Tensor):
if imgs.is_cuda:
imgs = imgs.cpu()
# Remove batch dimension and transpose
imgs = imgs.squeeze(0).permute(1, 2, 0).numpy()
        # Scale from a normalized [0, 1] float tensor to uint8, which most
        # image processing functions expect
        imgs = (imgs * 255).astype(np.uint8)
if isinstance(imgs, (list, tuple)) or (isinstance(imgs, np.ndarray) and len(imgs.shape) == 4):
is_batch = True
else:
imgs = [imgs]
is_batch = False
cfg = model.cfg
if test_pipeline is None:
cfg = cfg.copy()
test_pipeline = get_test_pipeline_cfg(cfg)
if isinstance(imgs[0], np.ndarray):
# Calling this method across libraries will result
# in module unregistered error if not prefixed with mmdet.
test_pipeline[0].type = "mmdet.LoadImageFromNDArray"
test_pipeline = Compose(test_pipeline)
if model.data_preprocessor.device.type == "cpu":
for m in model.modules():
assert not isinstance(
m, RoIPool
), "CPU inference with RoIPool is not supported currently."
result_list = []
for i, img in enumerate(imgs):
# prepare data
if isinstance(img, np.ndarray):
# TODO: remove img_id.
data_ = dict(img=img, img_id=0)
else:
# TODO: remove img_id.
data_ = dict(img_path=img, img_id=0)
if text_prompt:
data_["text"] = text_prompt
data_["custom_entities"] = custom_entities
# build the data pipeline
data_ = test_pipeline(data_)
data_["inputs"] = [data_["inputs"]]
data_["data_samples"] = [data_["data_samples"]]
        # forward the model (under bfloat16 autocast to save memory/compute)
with torch.no_grad(), torch.autocast(device_type=get_device(), dtype=torch.bfloat16):
results = model.test_step(data_)[0]
result_list.append(results)
if not is_batch:
return result_list[0]
else:
return result_list
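
# Example calls (a sketch; the image path and arrays are placeholders). A
# single input returns one DetDataSample, a list/tuple of inputs returns a
# list of DetDataSamples:
#
#   result = inference_detector(detector, 'path/to/image.jpg')
#   results = inference_detector(detector, [img_a, img_b])  # img_a/img_b: ndarrays
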
def process_one_image_bbox(pred_instance, det_cat_id, bbox_thr, nms_thr):
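    """Keep detections of class `det_cat_id` with score > `bbox_thr`, apply
    greedy NMS with threshold `nms_thr`, and return the surviving boxes as an
    (N, 4) array of [x1, y1, x2, y2]."""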
bboxes = np.concatenate(
(pred_instance.bboxes, pred_instance.scores[:, None]), axis=1
)
bboxes = bboxes[
np.logical_and(
pred_instance.labels == det_cat_id,
pred_instance.scores > bbox_thr,
)
]
bboxes = bboxes[nms(bboxes, nms_thr), :4]
return bboxes
def process_images_detector(imgs, detector):
"""Visualize predicted keypoints (and heatmaps) of one image."""
# predict bbox
det_results = inference_detector(detector, imgs)
pred_instances = list(
map(lambda det_result: det_result.pred_instances.numpy(), det_results)
)
bboxes_batch = list(
map(
lambda pred_instance: process_one_image_bbox(
                pred_instance, 0, 0.3, 0.3  # det_cat_id=0, bbox_thr=0.3, nms_thr=0.3
),
pred_instances,
)
)
return bboxes_batch
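
# End-to-end sketch (paths are placeholders; assumes `detector` was built with
# init_detector on a config adapted via adapt_mmdet_pipeline, and that images
# are BGR ndarrays such as those returned by cv2.imread):
#
#   imgs = [cv2.imread('path/to/image.jpg')]
#   bboxes_batch = process_images_detector(imgs, detector)
#   # bboxes_batch[0] is an (N, 4) array of person boxes in [x1, y1, x2, y2]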