|
import onnxruntime |
|
import numpy as np |
|
import onnx |
|
import copy |
|
import cv2 |
|
from pathlib import Path |
|
import matplotlib.pyplot as plt |
|
import torch |
|
import time |
|
import torchvision |
|
import re |
|
import glob |
|
from contextlib import contextmanager |
|
from torch.utils.data import Dataset |
|
import yaml |
|
import os |
|
from multiprocessing.pool import ThreadPool, Pool |
|
from tqdm import tqdm |
|
from itertools import repeat |
|
import logging |
|
from PIL import Image, ExifTags |
|
import hashlib |
|
import shutil |
|
import sys |
|
import pathlib |
|
CURRENT_DIR = pathlib.Path(__file__).parent  # directory containing this file

sys.path.append(str(CURRENT_DIR))  # make sibling modules importable

IMG_FORMATS = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo']  # accepted image suffixes
# os.cpu_count() returns None when the CPU count cannot be determined; fall back to 1
# instead of letting min() raise TypeError at import time.
NUM_THREADS = min(8, os.cpu_count() or 1)  # worker threads/processes for dataset scanning and caching
img_formats = IMG_FORMATS  # lowercase alias of IMG_FORMATS
vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv']  # accepted video suffixes
|
|
|
|
|
# Find the numeric EXIF tag id whose name is 'Orientation'. exif_size() reads it through
# the module-level name `orientation` left behind by this loop.
# NOTE(review): if no tag matched, `orientation` would be left at the last key of
# ExifTags.TAGS — presumably unreachable with PIL's tag table; confirm.
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break
|
|
|
|
|
def make_dirs(dir='./datasets/coco'):
    """Ensure ``<dir>/labels`` exists (creating parents as needed) and return ``dir`` as a Path."""
    root = Path(dir)
    (root / 'labels').mkdir(parents=True, exist_ok=True)
    return root
|
|
|
|
|
def coco91_to_coco80_class():
    """Return a 91-entry list mapping COCO-91 ids to 80-class indices.

    Entry ``k`` corresponds to COCO-91 id ``k + 1``; it holds the 0-based 80-class
    index, or ``None`` for ids that have no 80-class counterpart.
    """
    # 1-based COCO-91 ids that map to None in the returned list.
    unmapped_ids = {12, 26, 29, 30, 45, 66, 68, 69, 71, 83, 91}
    mapping = []
    next_class = 0
    for coco91_id in range(1, 92):
        if coco91_id in unmapped_ids:
            mapping.append(None)
        else:
            mapping.append(next_class)
            next_class += 1
    return mapping
|
|
|
|
|
def is_ascii(s=""):
    """Return True if ``str(s)`` consists only of ASCII characters.

    Uses ``str.isascii``: unlike the previous ``encode()``/``decode()`` round-trip it
    does not raise UnicodeEncodeError on strings containing lone surrogates.
    """
    return str(s).isascii()
|
|
|
|
|
def is_chinese(s="人工智能"):
    """Return True if ``str(s)`` contains any character in U+4E00–U+9FFF (CJK Unified Ideographs).

    Returns a real bool (not a ``re.Match``), and accepts non-string input, which is
    converted with ``str()`` instead of making ``re.search`` raise TypeError.
    """
    return bool(re.search("[\u4e00-\u9fff]", str(s)))
|
|
|
|
|
def letterbox(
    im,
    new_shape=(640, 640),
    color=(114, 114, 114),
    auto=True,
    scaleFill=False,
    scaleup=True,
    stride=32,
):
    """Resize ``im`` to fit ``new_shape`` (h, w) keeping its aspect ratio, then pad with ``color``.

    Args:
        im: image array indexed as (height, width, ...).
        new_shape: target (height, width), or a single int for a square target.
        color: border fill value passed to cv2.copyMakeBorder.
        auto: pad only up to the next multiple of ``stride`` instead of to ``new_shape``.
        scaleFill: stretch to exactly ``new_shape`` with no padding (ignored when ``auto``).
        scaleup: allow enlarging; when False the scale is capped at 1.0.
        stride: stride used by ``auto`` padding.

    Returns:
        (padded image, (width ratio, height ratio), (pad_w, pad_h)) where the pads are
        the per-side (halved) padding amounts.
    """
    src_h, src_w = im.shape[:2]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)
    dst_h, dst_w = new_shape[0], new_shape[1]

    # Uniform scale factor that fits the image inside the target.
    scale = min(dst_h / src_h, dst_w / src_w)
    if not scaleup:
        scale = min(scale, 1.0)  # only shrink, never enlarge
    ratio = (scale, scale)

    resized_wh = (int(round(src_w * scale)), int(round(src_h * scale)))
    pad_w = dst_w - resized_wh[0]
    pad_h = dst_h - resized_wh[1]
    if auto:
        # Minimal padding: just enough to reach a multiple of the stride.
        pad_w, pad_h = np.mod(pad_w, stride), np.mod(pad_h, stride)
    elif scaleFill:
        # Stretch to the exact target size; no padding needed.
        pad_w, pad_h = 0.0, 0.0
        resized_wh = (dst_w, dst_h)
        ratio = (dst_w / src_w, dst_h / src_h)

    # Split the padding evenly between the two sides.
    pad_w /= 2
    pad_h /= 2

    if (src_w, src_h) != resized_wh:
        im = cv2.resize(im, resized_wh, interpolation=cv2.INTER_LINEAR)
    # The +-0.1 nudges put the odd leftover pixel on the bottom/right side.
    top, bottom = int(round(pad_h - 0.1)), int(round(pad_h + 0.1))
    left, right = int(round(pad_w - 0.1)), int(round(pad_w + 0.1))
    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
    return im, ratio, (pad_w, pad_h)
|
|
|
|
|
def xyxy2xywh(x):
    """Convert boxes from corner form [x1, y1, x2, y2] to [x_center, y_center, w, h].

    Accepts a numpy array or torch tensor whose LAST dimension holds the box
    coordinates (e.g. (n, 4) or (batch, n, 4)); columns beyond the first four are
    copied through unchanged. The input is not modified.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    # Ellipsis indexing supports any number of leading dimensions; for (n, 4) input
    # this is identical to the former [:, k] indexing.
    y[..., 0] = (x[..., 0] + x[..., 2]) / 2  # x center
    y[..., 1] = (x[..., 1] + x[..., 3]) / 2  # y center
    y[..., 2] = x[..., 2] - x[..., 0]  # width
    y[..., 3] = x[..., 3] - x[..., 1]  # height
    return y
|
|
|
|
|
def xywh2xyxy(x):
    """Convert boxes from [x_center, y_center, w, h] to corner form [x1, y1, x2, y2].

    Accepts a numpy array or torch tensor whose LAST dimension holds the box
    coordinates (e.g. (n, 4) or (batch, n, 4)); columns beyond the first four are
    copied through unchanged. The input is not modified.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    # Ellipsis indexing supports any number of leading dimensions; for (n, 4) input
    # this is identical to the former [:, k] indexing.
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top-left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top-left y
    y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom-right x
    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom-right y
    return y
|
|
|
|
|
def non_max_suppression(
    prediction,
    conf_thres=0.25,
    iou_thres=0.45,
    classes=None,
    agnostic=False,
    multi_label=False,
    labels=(),
    max_det=300,
):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Arguments:
        prediction: batched raw predictions indexed as prediction[image, box, field],
            with fields [x_center, y_center, w, h, obj_conf, cls_conf_0, ...]; the number
            of classes is taken as prediction.shape[2] - 5.
        conf_thres: minimum objectness, and then minimum obj_conf * cls_conf, kept.
        iou_thres: IoU threshold handed to torchvision.ops.nms.
        classes: optional iterable of class indices to keep.
        agnostic: if True, boxes of different classes suppress each other.
        multi_label: emit one row per (box, class) pair above conf_thres
            (forced off when there is only one class).
        labels: optional per-image tensors of rows [cls, x_center, y_center, w, h]
            appended to that image's candidates with objectness 1.0.
        max_det: maximum number of detections kept per image.

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates: objectness above threshold

    # Checks
    assert (
        0 <= conf_thres <= 1
    ), f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
    assert (
        0 <= iou_thres <= 1
    ), f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"

    # Settings
    min_wh, max_wh = 2, 4096  # min_wh is unused below; max_wh is the per-class box offset
    max_nms = 30000  # maximum number of boxes passed into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # (merge branch only) keep merged boxes that have >1 overlapping box
    multi_label &= nc > 1  # multiple labels per box only with more than one class
    merge = False  # weighted box merging is disabled

    t = time.time()
    # Every entry starts as the same empty tensor object; entries are replaced below, never mutated.
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        x = x[xc[xi]]  # keep only candidates above the objectness threshold

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # one-hot class
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # keep the max_nms highest-confidence boxes

        # Batched NMS: shifting each box by class * max_wh keeps boxes of different classes
        # apart (assuming coordinates below max_wh), so one nms() call works per class.
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3e3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(
                1, keepdim=True
            )  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f"WARNING: NMS time limit {time_limit}s exceeded")
            break  # time limit exceeded; remaining images keep empty outputs

    return output
|
|
|
|
|
def clip_coords(boxes, shape):
    """Clamp xyxy ``boxes`` in place to an image of ``shape`` (height, width, ...).

    x coordinates are limited to [0, width] and y coordinates to [0, height].
    Works on torch tensors (via in-place ``clamp_``) and numpy arrays.
    """
    height, width = shape[0], shape[1]
    if not isinstance(boxes, torch.Tensor):
        boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, width)
        boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, height)
        return
    for column, upper in ((0, width), (1, height), (2, width), (3, height)):
        boxes[:, column].clamp_(0, upper)
|
|
|
|
|
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    """Map xyxy ``coords`` from a letterboxed image of ``img1_shape`` back to ``img0_shape``.

    Modifies ``coords`` in place (pad removed, divided by the gain, then clipped with
    clip_coords) and returns it. ``ratio_pad`` = ((gain, ...), (pad_x, pad_y)) overrides
    the gain/padding otherwise recomputed from the two shapes.
    """
    if ratio_pad is not None:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]
        pad_x, pad_y = pad[0], pad[1]
    else:
        # Same fit-inside scale letterbox() uses, with padding split evenly per side.
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
        pad_x = (img1_shape[1] - img0_shape[1] * gain) / 2
        pad_y = (img1_shape[0] - img0_shape[0] * gain) / 2

    coords[:, [0, 2]] -= pad_x
    coords[:, [1, 3]] -= pad_y
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords
|
|
|
|
|
class Annotator:
    """Draw detection boxes and labels on an image with OpenCV or PIL.

    The PIL backend is used when ``pil=True`` or when ``example`` (a sample label)
    is non-ASCII or contains CJK characters; otherwise boxes are drawn with OpenCV
    directly into the given numpy image.
    """

    def __init__(
        self,
        im,
        line_width=None,
        font_size=None,
        font="Arial.ttf",
        pil=False,
        example="abc",
    ):
        """
        Args:
            im: image to annotate — a numpy array, or a PIL image for the PIL backend.
            line_width: box line width; derived from the image size when None.
            font_size: PIL font size; derived from the image size when None.
            font: PIL font file name ("Arial.Unicode.ttf" is used for CJK examples).
            pil: force the PIL backend.
            example: sample label text used to choose the backend and font.
        """
        # The contiguity probe applies to numpy arrays (memoryview flag); other inputs,
        # such as PIL images accepted by the PIL backend below, skip it.
        assert not isinstance(im, np.ndarray) or im.data.contiguous, (
            "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images."
        )
        # bool() so self.pil is never a re.Match object returned by is_chinese().
        self.pil = bool(pil or not is_ascii(example) or is_chinese(example))
        if self.pil:  # use PIL
            from PIL import ImageDraw  # not among this file's top-level PIL imports

            self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
            self.draw = ImageDraw.Draw(self.im)
            # NOTE(review): check_font is not defined in the visible part of this file — confirm it exists.
            self.font = check_font(
                font="Arial.Unicode.ttf" if is_chinese(example) else font,
                size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12),
            )
        else:  # use cv2
            self.im = im
        # numpy inputs keep the sum(im.shape) rule; PIL images (no .shape) use their (w, h) size.
        dims = im.shape if hasattr(im, "shape") else self.im.size
        self.lw = line_width or max(round(sum(dims) / 2 * 0.003), 2)  # line width

    def _text_size(self, text):
        """Return (width, height) of ``text`` rendered in ``self.font`` (PIL backend only).

        Pillow 10 removed ``FreeTypeFont.getsize``; when ``getbbox`` exists its right and
        bottom edges are used as the width and height.
        """
        if hasattr(self.font, "getbbox"):
            _, _, w, h = self.font.getbbox(text)
            return w, h
        return self.font.getsize(text)

    def box_label(
        self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255)
    ):
        """Draw an xyxy ``box`` and, if given, a filled ``label`` tag at its top-left corner.

        The tag is placed above the box when it fits inside the image, otherwise just
        inside the box's top edge.
        """
        # Only the PIL backend has self.draw; routing non-ASCII labels there while
        # self.pil is False used to raise AttributeError. cv2.putText cannot render
        # non-ASCII text, so create the Annotator with pil=True (or a non-ASCII
        # `example`) for such labels.
        if self.pil:
            self.draw.rectangle(box, width=self.lw, outline=color)  # box
            if label:
                w, h = self._text_size(label)  # text width, height
                outside = box[1] - h >= 0  # label fits above the box
                self.draw.rectangle(
                    [
                        box[0],
                        box[1] - h if outside else box[1],
                        box[0] + w + 1,
                        box[1] + 1 if outside else box[1] + h + 1,
                    ],
                    fill=color,
                )
                self.draw.text(
                    (box[0], box[1] - h if outside else box[1]),
                    label,
                    fill=txt_color,
                    font=self.font,
                )
        else:  # cv2
            p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
            cv2.rectangle(
                self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA
            )
            if label:
                tf = max(self.lw - 1, 1)  # font thickness
                w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0]  # text width, height
                outside = p1[1] - h - 3 >= 0  # label fits above the box
                p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
                cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA)  # filled tag background
                cv2.putText(
                    self.im,
                    label,
                    (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
                    0,
                    self.lw / 3,
                    txt_color,
                    thickness=tf,
                    lineType=cv2.LINE_AA,
                )

    def rectangle(self, xy, fill=None, outline=None, width=1):
        """Draw a rectangle with PIL (PIL backend only; self.draw exists only there)."""
        self.draw.rectangle(xy, fill, outline, width)

    def text(self, xy, text, txt_color=(255, 255, 255)):
        """Draw ``text`` with PIL (PIL backend only), its top edge placed h - 1 px above xy[1]."""
        _, h = self._text_size(text)  # text height
        self.draw.text((xy[0], xy[1] - h + 1), text, fill=txt_color, font=self.font)

    def result(self):
        """Return the annotated image as a numpy array."""
        return np.asarray(self.im)
|
|
|
|
|
class Colors:
    """Cyclic palette of 20 fixed colours.

    ``colors(i)`` returns the RGB tuple for index ``i`` (wrapping modulo 20);
    ``colors(i, bgr=True)`` returns the same colour in BGR order for OpenCV.
    """

    def __init__(self):
        hex_codes = (
            "FF3838", "FF9D97", "FF701F", "FFB21D", "CFD231",
            "48F90A", "92CC17", "3DDB86", "1A9334", "00D4BB",
            "2C99A8", "00C2FF", "344593", "6473FF", "0018EC",
            "8438FF", "520085", "CB38FF", "FF95C8", "FF37C7",
        )
        self.palette = list(map(self.hex2rgb, ("#" + code for code in hex_codes)))
        self.n = len(self.palette)  # number of colours

    def __call__(self, i, bgr=False):
        color = self.palette[int(i) % self.n]
        if not bgr:
            return color
        red, green, blue = color
        return blue, green, red

    @staticmethod
    def hex2rgb(h):
        """Convert '#RRGGBB' to an (r, g, b) tuple of ints."""
        return tuple(int(h[start:start + 2], 16) for start in range(1, 7, 2))
|
|
|
|
|
@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Context manager that lets local rank 0 run the ``with`` body before the other ranks.

    Ranks other than -1 and 0 block on a barrier before the body; rank 0 runs the
    body first and then joins that barrier, releasing them. Rank -1 (no distributed
    training) runs the body without any synchronisation.
    """
    # `dist` is not among this file's top-level imports; without this import the
    # barrier calls below raised NameError for every rank except -1.
    import torch.distributed as dist

    if local_rank not in [-1, 0]:
        dist.barrier(device_ids=[local_rank])
    yield
    if local_rank == 0:
        dist.barrier(device_ids=[0])
|
|
|
|
|
def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
                      rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
    """Build a LoadImagesAndLabels dataset and wrap it in a data loader.

    Args:
        path, imgsz, batch_size, stride, single_cls, hyp, augment, pad, rect,
        image_weights, prefix: forwarded to LoadImagesAndLabels (``cache`` is passed
            as ``cache_images``).
        rank: distributed rank; -1 disables the DistributedSampler.
        workers: upper bound on loader worker processes.
        quad: use LoadImagesAndLabels.collate_fn4 instead of collate_fn.

    Returns:
        (dataloader, dataset)
    """
    # Rank 0 constructs the dataset first; other ranks wait (presumably so they can
    # reuse the label cache rank 0 writes).
    with torch_distributed_zero_first(rank):
        dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                      augment=augment,
                                      hyp=hyp,
                                      rect=rect,
                                      cache_images=cache,
                                      single_cls=single_cls,
                                      stride=int(stride),
                                      pad=pad,
                                      image_weights=image_weights,
                                      prefix=prefix)

    batch_size = min(batch_size, len(dataset))
    # os.cpu_count() can return None; treat that as a single CPU instead of failing in min().
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, workers])  # number of workers
    sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
    # NOTE(review): InfiniteDataLoader is not defined in the visible part of this file — confirm it exists.
    loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader

    dataloader = loader(dataset,
                        batch_size=batch_size,
                        num_workers=nw,
                        sampler=sampler,
                        pin_memory=True,
                        collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
    return dataloader, dataset
|
|
|
|
|
class LoadImagesAndLabels(Dataset):
    """Detection dataset of images and YOLO-format ``*.txt`` label files.

    Image paths come from a directory (searched recursively), a text file listing one
    path per line, or a list of either; each image's label path is derived with
    img2label_paths(). Per-image label metadata is cached in a ``.cache`` file and
    reused while its version and file hash still match.

    NOTE(review): several helpers used here (Albumentations, load_image, load_mosaic,
    mixup, random_perspective, augment_hsv, xywhn2xyxy, xyxy2xywhn, get_hash, HELP_URL)
    are not defined in the visible part of this file — confirm they exist at module level.
    """

    cache_version = 0.5  # stored in each .cache file; a mismatch forces a rebuild

    def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                 cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
        """
        Args:
            path: image directory, image-list text file, or a list of those.
            img_size: target (letterboxed) image size.
            batch_size: batch size used to group images into rectangular batches.
            augment: enable training augmentation (mosaic, mixup, perspective, HSV, flips).
            hyp: hyperparameter dict read by the augmentation code.
            rect: use per-batch rectangular shapes (forced off when image_weights is set).
            image_weights: stored on the dataset; also disables rect.
            cache_images: False, 'disk' (save .npy files into a sibling '<image dir>_npy'
                directory) or any other truthy value (keep decoded images in memory).
            single_cls: set every label's class to 0.
            stride: value rectangular batch shapes are rounded up to a multiple of.
            pad: extra padding, in units of stride, added to rectangular batch shapes.
            prefix: string prepended to log and error messages.
        """
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect  # mosaic only when augmenting without rect batches
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride
        self.path = path
        self.albumentations = Albumentations() if augment else None

        try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = Path(p)  # os-agnostic
                if p.is_dir():  # directory: every file below it
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                elif p.is_file():  # text file with one image path per line
                    with open(p, 'r') as fh:
                        lines = fh.read().strip().splitlines()
                    parent = str(p.parent) + os.sep
                    # Only the leading './' is made relative to the list file; a plain
                    # str.replace would also rewrite the './' inside any '../' segment.
                    f += [parent + x[2:] if x.startswith('./') else x for x in lines]
                else:
                    raise Exception(f'{prefix}{p} does not exist')
            self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS])
            assert self.img_files, f'{prefix}No images found'
        except Exception as e:
            raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {HELP_URL}') from e

        # Check cache
        self.label_files = img2label_paths(self.img_files)  # labels
        # `p` is the last entry of `path` from the loop above.
        cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
        try:
            # SECURITY: allow_pickle=True unpickles the .cache file; only use trusted dataset directories.
            cache, exists = np.load(cache_path, allow_pickle=True).item(), True  # load dict
            assert cache['version'] == self.cache_version  # same version
            assert cache['hash'] == get_hash(self.label_files + self.img_files)  # same files
        except Exception:  # missing, unreadable or stale cache: rebuild it
            cache, exists = self.cache_labels(cache_path, prefix), False

        # Display cache
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupted, total
        if exists:
            d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
            tqdm(None, desc=prefix + d, total=n, initial=n)  # display cache results
            if cache['msgs']:
                logging.info('\n'.join(cache['msgs']))  # display warnings
        assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {HELP_URL}'

        # Read cache
        for k in ('hash', 'version', 'msgs'):  # leave only the per-image entries
            cache.pop(k)
        labels, shapes, self.segments = zip(*cache.values())
        self.labels = list(labels)
        self.shapes = np.array(shapes, dtype=np.float64)  # (width, height) per image, from exif_size()
        self.img_files = list(cache.keys())  # update: corrupt images were dropped
        self.label_files = img2label_paths(cache.keys())  # update
        if single_cls:
            for x in self.labels:
                x[:, 0] = 0

        n = len(shapes)  # number of images
        bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches
        self.batch = bi  # batch index of image
        self.n = n
        self.indices = range(n)

        # Rectangular Training
        if self.rect:
            # Sort by aspect ratio (height / width)
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.img_files = [self.img_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.labels = [self.labels[i] for i in irect]
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes: one (h, w) ratio per batch covering all its images
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

        # Cache images into memory or onto disk (large datasets may exceed system resources)
        self.imgs, self.img_npy = [None] * n, [None] * n
        if cache_images:
            if cache_images == 'disk':
                self.im_cache_dir = Path(Path(self.img_files[0]).parent.as_posix() + '_npy')
                self.img_npy = [self.im_cache_dir / Path(f).with_suffix('.npy').name for f in self.img_files]
                self.im_cache_dir.mkdir(parents=True, exist_ok=True)
            gb = 0  # Gigabytes of cached images
            self.img_hw0, self.img_hw = [None] * n, [None] * n
            # The context manager shuts the worker threads down once all results are consumed.
            with ThreadPool(NUM_THREADS) as pool:
                results = pool.imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
                pbar = tqdm(enumerate(results), total=n)
                for i, x in pbar:
                    if cache_images == 'disk':
                        if not self.img_npy[i].exists():
                            np.save(self.img_npy[i].as_posix(), x[0])
                        gb += self.img_npy[i].stat().st_size
                    else:
                        self.imgs[i], self.img_hw0[i], self.img_hw[i] = x  # im, hw_orig, hw_resized
                        gb += self.imgs[i].nbytes
                    pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB {cache_images})'
            pbar.close()

    def cache_labels(self, path=Path('./labels.cache'), prefix=''):
        """Verify every image/label pair in parallel and save the results to ``path``.

        Returns:
            dict mapping image file -> [labels, shape, segments], plus the keys 'hash',
            'results' (found, missing, empty, corrupted, total), 'msgs' and 'version'.
        """
        x = {}  # dict
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
        with Pool(NUM_THREADS) as pool:
            pbar = tqdm(pool.imap(verify_image_label, zip(self.img_files, self.label_files, repeat(prefix))),
                        desc=desc, total=len(self.img_files))
            for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:  # falsy for corrupt images, which are left out of the cache
                    x[im_file] = [lb, shape, segments]
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupted"

        pbar.close()
        if msgs:
            logging.info('\n'.join(msgs))
        if nf == 0:
            logging.info(f'{prefix}WARNING: No labels found in {path}. See {HELP_URL}')
        x['hash'] = get_hash(self.label_files + self.img_files)
        x['results'] = nf, nm, ne, nc, len(self.img_files)
        x['msgs'] = msgs  # warnings
        x['version'] = self.cache_version  # cache version
        try:
            np.save(path, x)  # np.save appends '.npy' to the file name
            # Path.replace (unlike rename) also overwrites a stale cache on Windows.
            path.with_suffix('.cache.npy').replace(path)  # drop the .npy suffix
            logging.info(f'{prefix}New cache created: {path}')
        except Exception as e:
            logging.info(f'{prefix}WARNING: Cache directory {path.parent} is not writeable: {e}')  # not writeable
        return x

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, index):
        """Return (image tensor CHW, labels (nl, 6), image path, shapes) for one sample.

        Label rows are [0, cls, x, y, w, h] with the box converted by xyxy2xywhn for the
        output image size; collate_fn later writes the batch index into column 0.
        ``shapes`` is None for mosaic samples, otherwise ((h0, w0), ((h / h0, w / w0), pad)),
        the (ratio, pad) part matching scale_coords(ratio_pad=...).
        """
        import random  # `random` is not among this file's top-level imports

        index = self.indices[index]  # linear, shuffled, or image_weights

        hyp = self.hyp
        mosaic = self.mosaic and random.random() < hyp['mosaic']
        if mosaic:
            # Load mosaic
            img, labels = load_mosaic(self, index)
            shapes = None

            # MixUp augmentation
            if random.random() < hyp['mixup']:
                img, labels = mixup(img, labels, *load_mosaic(self, random.randint(0, self.n - 1)))

        else:
            # Load image
            img, (h0, w0), (h, w) = load_image(self, index)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)

            labels = self.labels[index].copy()
            if labels.size:  # normalized xywh to pixel xyxy format
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

            if self.augment:
                img, labels = random_perspective(img, labels,
                                                 degrees=hyp['degrees'],
                                                 translate=hyp['translate'],
                                                 scale=hyp['scale'],
                                                 shear=hyp['shear'],
                                                 perspective=hyp['perspective'])

        nl = len(labels)  # number of labels
        if nl:
            labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1E-3)

        if self.augment:
            # Albumentations
            img, labels = self.albumentations(img, labels)
            nl = len(labels)  # update after albumentations

            # HSV color-space
            augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

            # Flip up-down
            if random.random() < hyp['flipud']:
                img = np.flipud(img)
                if nl:
                    labels[:, 2] = 1 - labels[:, 2]

            # Flip left-right
            if random.random() < hyp['fliplr']:
                img = np.fliplr(img)
                if nl:
                    labels[:, 1] = 1 - labels[:, 1]

        labels_out = torch.zeros((nl, 6))
        if nl:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert: HWC to CHW, then reverse the channel order (presumably OpenCV BGR to RGB)
        img = img.transpose((2, 0, 1))[::-1]
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.img_files[index], shapes

    @staticmethod
    def collate_fn(batch):
        """Stack images and concatenate labels, writing each label's image index into column 0."""
        img, label, path, shapes = zip(*batch)  # transposed
        for i, lb in enumerate(label):
            lb[:, 0] = i  # image index within the batch
        return torch.stack(img, 0), torch.cat(label, 0), path, shapes

    @staticmethod
    def collate_fn4(batch):
        """Collate groups of 4 samples into one double-size sample each.

        For every group, with probability 0.5 the first image is upsampled 2x (keeping
        only its labels); otherwise the 4 images are tiled 2x2 and their labels shifted
        and scaled by 0.5 to match.
        """
        import random  # `random` and `F` are not among this file's top-level imports
        import torch.nn.functional as F

        img, label, path, shapes = zip(*batch)  # transposed
        n = len(shapes) // 4
        img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]

        ho = torch.tensor([[0., 0, 0, 1, 0, 0]])  # shift one tile down (y + 1)
        wo = torch.tensor([[0., 0, 1, 0, 0, 0]])  # shift one tile right (x + 1)
        s = torch.tensor([[1, 1, .5, .5, .5, .5]])  # halve normalised xywh
        for g in range(n):
            k = 4 * g  # first sample of this group
            if random.random() < 0.5:
                im = F.interpolate(img[k].unsqueeze(0).float(), scale_factor=2., mode='bilinear',
                                   align_corners=False)[0].type(img[k].type())
                lb = label[k]
            else:
                im = torch.cat((torch.cat((img[k], img[k + 1]), 1), torch.cat((img[k + 2], img[k + 3]), 1)), 2)
                lb = torch.cat((label[k], label[k + 1] + ho, label[k + 2] + wo, label[k + 3] + ho + wo), 0) * s
            img4.append(im)
            label4.append(lb)

        for i, lb in enumerate(label4):
            lb[:, 0] = i  # image index within the batch

        return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4
|
|
|
|
|
def coco80_to_coco91_class():
    """Return the list of 1-based COCO-91 ids used by the 80-class label set.

    Index ``k`` of the returned list is the COCO-91 id of 80-class index ``k``.
    """
    skipped_ids = {12, 26, 29, 30, 45, 66, 68, 69, 71, 83}  # ids in 1..90 absent from the result
    return [coco91_id for coco91_id in range(1, 91) if coco91_id not in skipped_ids]
|
|
|
|
|
def check_dataset(data, autodownload=True):
    """Load a dataset description, resolve its split paths and make sure 'val' exists.

    Args:
        data: a dataset dict, a path to a YAML file, or a path to a ``.zip`` that is
            downloaded/extracted into ``../datasets`` and searched for a YAML file.
        autodownload: when 'val' is missing, run the dataset's 'download' entry
            (a .zip URL, a 'bash ...' command, or Python code).

    Returns:
        The dataset dict, with 'train'/'val'/'test' prefixed by its 'path' and default
        'names' ('class0', 'class1', ...) filled in when absent.

    Raises:
        AssertionError: the 'nc' key is missing.
        Exception: 'val' paths are missing and no download was attempted.
    """
    from zipfile import ZipFile  # ZipFile is not among this file's top-level imports

    # Download (optional)
    extract_dir = ''
    if isinstance(data, (str, Path)) and str(data).endswith('.zip'):
        # NOTE(review): `download` is not defined in the visible part of this file — confirm it exists.
        download(data, dir='../datasets', unzip=True, delete=False, curl=False, threads=1)
        data = next((Path('../datasets') / Path(data).stem).rglob('*.yaml'))
        extract_dir, autodownload = data.parent, False

    # Read yaml (optional)
    if isinstance(data, (str, Path)):
        with open(data, errors='ignore') as f:
            data = yaml.safe_load(f)  # dictionary

    # Parse yaml
    path = extract_dir or Path(data.get('path') or '')  # optional 'path', defaults to ''
    for k in 'train', 'val', 'test':
        if data.get(k):  # prepend path
            data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]

    assert 'nc' in data, "Dataset 'nc' key missing."
    if 'names' not in data:
        data['names'] = [f'class{i}' for i in range(data['nc'])]  # assign class names if missing
    val, s = data.get('val'), data.get('download')
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
            if s and autodownload:  # download script
                root = path.parent if 'path' in data else '..'  # unzip directory
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    print(f'Downloading {s} to {f}...')
                    torch.hub.download_url_to_file(s, f)
                    Path(root).mkdir(parents=True, exist_ok=True)  # create root
                    with ZipFile(f) as zf:  # closes the archive before it is deleted
                        zf.extractall(path=root)  # unzip
                    Path(f).unlink()  # remove zip
                    r = None  # success
                elif s.startswith('bash '):  # bash script
                    # SECURITY: runs a shell command taken from the dataset file; only use trusted datasets.
                    print(f'Running {s} ...')
                    r = os.system(s)
                else:  # python script
                    # SECURITY: executes Python code taken from the dataset file; only use trusted datasets.
                    r = exec(s, {'yaml': data})  # return None
                print(f"Dataset autodownload {f'success, saved to {root}' if r in (0, None) else 'failure'}\n")
            else:
                raise Exception('Dataset not found.')

    return data
|
|
|
|
|
def box_iou(box1, box2):
    """
    Return the pairwise intersection-over-union (Jaccard index) of two box sets.

    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): iou[i, j] is the IoU of box1[i] and box2[j]
    """
    area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])

    # Broadcast (N, 1, 2) against (M, 2) to get (N, M, 2) corner coordinates.
    top_left = torch.max(box1[:, None, :2], box2[:, :2])
    bottom_right = torch.min(box1[:, None, 2:], box2[:, 2:])
    inter = (bottom_right - top_left).clamp(0).prod(2)  # zero for non-overlapping pairs
    union = area1[:, None] + area2 - inter
    return inter / union
|
|
|
|
|
def increment_path(path, exist_ok=False, sep='', mkdir=False):
    """Return ``path``, or the next free ``path{sep}{n}`` if ``path`` already exists.

    e.g. runs/exp -> runs/exp2, runs/exp3, ...: ``n`` is one more than the largest
    existing numeric suffix, starting at 2. A file suffix is kept (a.txt -> a2.txt).

    Args:
        path: candidate file or directory path.
        exist_ok: return ``path`` unchanged even if it exists.
        sep: separator between stem and number, e.g. '_' gives exp_2.
        mkdir: create the resulting directory (the parent directory for file paths).
    """
    path = Path(path)  # os-agnostic
    if path.exists() and not exist_ok:
        suffix = path.suffix
        path = path.with_suffix('')
        siblings = glob.glob(f"{path}{sep}*")  # similar paths
        # Match each sibling's own name (not the full path, so parent directories such
        # as 'exp7/' cannot match) with the stem and sep escaped, so regex
        # metacharacters like '+' or '.' in them are taken literally.
        pattern = re.compile(rf"{re.escape(path.stem)}{re.escape(sep)}(\d+)")
        matches = [pattern.match(Path(d).name) for d in siblings]
        indices = [int(m.group(1)) for m in matches if m]
        n = max(indices) + 1 if indices else 2  # increment number
        path = Path(f"{path}{sep}{n}{suffix}")  # increment path
    target_dir = path if path.suffix == '' else path.parent  # directory
    if not target_dir.exists() and mkdir:
        target_dir.mkdir(parents=True, exist_ok=True)  # make directory
    return path
|
|
|
|
|
def colorstr(*input):
    """Wrap the last argument in ANSI escape codes for the preceding style names.

    e.g. ``colorstr('red', 'bold', 'hi')``; a single argument is styled blue + bold.
    """
    if len(input) > 1:
        *styles, text = input
    else:
        styles, text = ('blue', 'bold'), input[0]
    codes = {
        'black': '\033[30m',  # basic colors
        'red': '\033[31m',
        'green': '\033[32m',
        'yellow': '\033[33m',
        'blue': '\033[34m',
        'magenta': '\033[35m',
        'cyan': '\033[36m',
        'white': '\033[37m',
        'bright_black': '\033[90m',  # bright colors
        'bright_red': '\033[91m',
        'bright_green': '\033[92m',
        'bright_yellow': '\033[93m',
        'bright_blue': '\033[94m',
        'bright_magenta': '\033[95m',
        'bright_cyan': '\033[96m',
        'bright_white': '\033[97m',
        'end': '\033[0m',  # reset
        'bold': '\033[1m',
        'underline': '\033[4m',
    }
    prefix = ''.join(map(codes.__getitem__, styles))
    return prefix + f'{text}' + codes['end']
|
|
|
|
|
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=()):
    """ Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
    # Arguments
        tp:  True positives (nparray, nx1 or nx10).
        conf:  Objectness value from 0-1 (nparray).
        pred_cls:  Predicted object classes (nparray).
        target_cls:  True object classes (nparray).
        plot:  Plot precision-recall curve at mAP@0.5
        save_dir:  Plot save directory
        names:  class names passed through to the plotting helpers
    # Returns
        The average precision as computed in py-faster-rcnn.
        Concretely a tuple (p, r, ap, f1, classes): per-class precision, recall and F1
        at the single confidence (on a 1000-point grid) that maximises the class-mean F1,
        the (nc, tp.shape[1]) AP array, and the sorted unique target classes as int32.
    """

    # Sort by objectness (descending)
    i = np.argsort(-conf)
    tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]

    # Find unique classes
    unique_classes = np.unique(target_cls)
    nc = unique_classes.shape[0]  # number of classes

    # Create Precision-Recall curve and compute AP for each class
    px, py = np.linspace(0, 1, 1000), []  # px: confidence grid; py: PR curves collected for plotting
    ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
    for ci, c in enumerate(unique_classes):
        i = pred_cls == c
        n_l = (target_cls == c).sum()  # number of labels
        n_p = i.sum()  # number of predictions

        if n_p == 0 or n_l == 0:
            continue  # this class's rows stay zero
        else:
            # Accumulate FPs and TPs
            fpc = (1 - tp[i]).cumsum(0)
            tpc = tp[i].cumsum(0)

            # Recall
            recall = tpc / (n_l + 1e-16)  # recall curve
            # Negated axes: np.interp needs increasing xp and conf is sorted descending.
            r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0)  # recall vs confidence, first tp column

            # Precision
            precision = tpc / (tpc + fpc)  # precision curve
            p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1)  # precision vs confidence

            # AP from recall-precision curve
            for j in range(tp.shape[1]):
                ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
                if plot and j == 0:
                    py.append(np.interp(px, mrec, mpre))  # precision at the first tp column (mAP@0.5)

    # Compute F1 (harmonic mean of precision and recall)
    f1 = 2 * p * r / (p + r + 1e-16)
    if plot:
        plot_pr_curve(px, py, ap, Path(save_dir) / 'PR_curve.png', names)
        plot_mc_curve(px, f1, Path(save_dir) / 'F1_curve.png', names, ylabel='F1')
        plot_mc_curve(px, p, Path(save_dir) / 'P_curve.png', names, ylabel='Precision')
        plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall')

    i = f1.mean(0).argmax()  # grid index of max class-mean F1
    return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32')
|
|
|
|
|
def compute_ap(recall, precision, method='interp'):
    """ Compute the average precision, given the recall and precision curves
    # Arguments
        recall:    The recall curve (list)
        precision: The precision curve (list)
        method:    'interp' (default): integrate the precision envelope sampled at 101
                   evenly spaced recall points; any other value: sum the envelope over
                   the recall steps ('continuous').
    # Returns
        Average precision, precision curve, recall curve
    """

    # Append sentinel values to beginning and end
    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([1.0], precision, [0.0]))

    # Compute the precision envelope (running maximum taken from the right)
    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))

    # Integrate area under curve
    if method == 'interp':
        x = np.linspace(0, 1, 101)  # 101-point interp
        # NumPy 2.0 renamed np.trapz to np.trapezoid and deprecated the old name.
        trapezoid = getattr(np, 'trapezoid', None) or np.trapz
        ap = trapezoid(np.interp(x, mrec, mpre), x)  # integrate
    else:  # 'continuous'
        i = np.where(mrec[1:] != mrec[:-1])[0]  # points where x axis (recall) changes
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])  # area under curve

    return ap, mpre, mrec
|
|
|
|
|
class ConfusionMatrix:
    """Detection confusion matrix with an extra background row and column.

    ``matrix[predicted_class, true_class]`` counts matched pairs. Index ``nc`` is the
    background: column ``nc`` counts detections that matched no label (background FP)
    and row ``nc`` counts labels that matched no detection (background FN).
    """

    def __init__(self, nc, conf=0.25, iou_thres=0.45):
        self.matrix = np.zeros((nc + 1, nc + 1))
        self.nc = nc  # number of classes
        self.conf = conf  # detections with conf <= this are ignored
        self.iou_thres = iou_thres  # IoU above which a label and detection can match

    def process_batch(self, detections, labels):
        """
        Update the matrix with one image's detections and labels.
        Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
        Arguments:
            detections (Array[N, 6]), x1, y1, x2, y2, conf, class
            labels (Array[M, 5]), class, x1, y1, x2, y2
        Returns:
            None, updates confusion matrix accordingly
        """
        detections = detections[detections[:, 4] > self.conf]
        gt_classes = labels[:, 0].int()
        detection_classes = detections[:, 5].int()
        iou = box_iou(labels[:, 1:], detections[:, :4])

        x = torch.where(iou > self.iou_thres)
        if x[0].shape[0]:
            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()  # [label, det, iou]
            if x[0].shape[0] > 1:
                # Greedy one-to-one matching: keep the highest-IoU pair per detection, then per label.
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
        else:
            matches = np.zeros((0, 3))

        n = matches.shape[0] > 0
        # Plain int rather than int16, so indices cannot overflow past 32767 boxes.
        m0, m1, _ = matches.transpose().astype(int)
        for i, gc in enumerate(gt_classes):
            j = m0 == i
            if n and j.sum() == 1:
                self.matrix[detection_classes[m1[j]], gc] += 1  # matched (correct or wrong class)
            else:
                self.matrix[self.nc, gc] += 1  # label without a detection (background FN)

        # NOTE(review): unmatched detections are only counted when at least one match exists.
        if n:
            for i, dc in enumerate(detection_classes):
                if not (m1 == i).any():
                    self.matrix[dc, self.nc] += 1  # detection without a label (background FP)

    def matrix(self):
        # NOTE: shadowed on instances by the `self.matrix` array set in __init__.
        return self.matrix

    def plot(self, normalize=True, save_dir='', names=()):
        """Save a heatmap of the matrix to ``save_dir/confusion_matrix.png``.

        Any failure (e.g. seaborn not installed) is printed as a warning, not raised.
        """
        try:
            import warnings  # not among this file's top-level imports
            import seaborn as sn

            array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-6) if normalize else 1)  # column-normalize
            array[array < 0.005] = np.nan  # don't annotate (would appear as 0.00)

            fig = plt.figure(figsize=(12, 9), tight_layout=True)
            try:
                sn.set(font_scale=1.0 if self.nc < 50 else 0.8)  # for label size
                labels = (0 < len(names) < 99) and len(names) == self.nc  # apply names to ticklabels
                tick_names = list(names)  # accept tuples as well as lists
                with warnings.catch_warnings():
                    warnings.simplefilter('ignore')
                    sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
                               xticklabels=tick_names + ['background FP'] if labels else "auto",
                               yticklabels=tick_names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
                fig.axes[0].set_xlabel('True')
                fig.axes[0].set_ylabel('Predicted')
                fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
            finally:
                plt.close(fig)  # release the figure even if plotting failed
        except Exception as e:
            print(f'WARNING: ConfusionMatrix plot failure: {e}')

    def print(self):
        """Print the matrix one row per line, space-separated."""
        for i in range(self.nc + 1):
            print(' '.join(map(str, self.matrix[i])))
|
|
|
|
|
def output_to_target(output):
    """Flatten per-image detections into a numpy array of rows [image_index, cls, x, y, w, h, conf].

    ``output`` holds one tensor per image whose rows end in ``conf, cls`` after the
    xyxy box (the non_max_suppression layout); boxes are converted with xyxy2xywh().
    """
    rows = []
    for image_index, detections in enumerate(output):
        for *corners, score, cls in detections.cpu().numpy():
            xywh = xyxy2xywh(np.array(corners)[None])[0]
            rows.append([image_index, cls, *xywh, score])
    return np.array(rows)
|
|
|
|
|
def plot_val_study(file='', dir='', x=None):
    """Plot inference time vs. mAP@.5:.95 for every ``study*.txt`` file and save ``study.png``.

    Args:
        file: a study file; its parent directory is searched (takes precedence over ``dir``).
        dir: directory searched for ``study*.txt`` when ``file`` is empty.
        x: x values for the disabled per-metric subplots (``plot2``); defaults to 0..N-1.
    """
    save_dir = Path(file).parent if file else Path(dir)
    plot2 = False  # plot additional results
    if plot2:
        ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)[1].ravel()

    fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
    for f in sorted(save_dir.glob('study*.txt')):
        # After .T the rows are: P, R, mAP@.5, mAP@.5:.95, t_preprocess, t_inference, t_NMS (see `s`)
        y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
        x = np.arange(y.shape[1]) if x is None else np.array(x)
        if plot2:
            s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_preprocess (ms/img)', 't_inference (ms/img)', 't_NMS (ms/img)']
            for i in range(7):
                ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
                ax[i].set_title(s[i])

        # Plot points 1..argmax of mAP@.5:.95 (the first row is skipped), AP scaled to percent.
        j = y[3].argmax() + 1
        ax2.plot(y[5, 1:j], y[3, 1:j] * 1E2, '.-', linewidth=2, markersize=8,
                 label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO'))

    # Hard-coded EfficientDet reference curve; x is 1000 / value.
    ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
             'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')

    ax2.grid(alpha=0.2)
    ax2.set_yticks(np.arange(20, 60, 5))
    ax2.set_xlim(0, 57)
    ax2.set_ylim(25, 55)
    ax2.set_xlabel('GPU Speed (ms/img)')
    ax2.set_ylabel('COCO AP val')
    ax2.legend(loc='lower right')
    f = save_dir / 'study.png'
    print(f'Saving {f}...')
    # Saves the current figure, which is fig2 (the last figure created above).
    # NOTE(review): the figure is never closed.
    plt.savefig(f, dpi=300)
|
|
|
|
|
def check_yaml(file, suffix=('.yaml', '.yml')):
    """Resolve a YAML file with check_file(), accepting only .yaml/.yml suffixes by default."""
    return check_file(file, suffix=suffix)
|
|
|
|
|
def check_file(file, suffix=''):
    """Resolve ``file`` to a usable local path.

    Lookup order:
      1. an existing local file (or '') is returned as-is;
      2. an http(s) URL is downloaded into the current directory;
      3. otherwise ROOT/{data,models,utils} is searched recursively.

    Raises AssertionError for a wrong suffix, a failed download, or a search yielding
    zero or several matches.
    """
    import urllib.parse  # `urllib` is not among this file's top-level imports

    check_suffix(file, suffix)  # optional
    file = str(file)  # convert to str()
    if Path(file).is_file() or file == '':  # exists
        return file
    elif file.startswith(('http:/', 'https:/')):  # download
        url = str(Path(file)).replace(':/', '://')  # Path() collapses '//' (so ':// -> :/'); restore it
        file = Path(urllib.parse.unquote(file).split('?')[0]).name  # decode '%2F' etc., drop '?query'
        print(f'Downloading {url} to {file}...')
        torch.hub.download_url_to_file(url, file)
        assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}'  # check
        return file
    else:  # search
        files = []
        # NOTE(review): ROOT is not defined in the visible part of this file — confirm it exists.
        for d in 'data', 'models', 'utils':  # search directories
            files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True))  # find file
        assert len(files), f'File not found: {file}'  # assert file was found
        assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}"  # assert unique
        return files[0]  # return file
|
|
|
|
|
def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
    """Assert that ``file`` (or every entry of a list/tuple of files) has an allowed suffix.

    The comparison lower-cases the file's suffix. Nothing is checked when either
    ``file`` or ``suffix`` is falsy.

    Args:
        file: a path, or a list/tuple of paths.
        suffix: one allowed suffix string, or a collection of them.
        msg: text prepended to the assertion message.

    Raises:
        AssertionError: if a file's suffix is not allowed.
    """
    if not (file and suffix):
        return
    allowed = [suffix] if isinstance(suffix, str) else suffix
    candidates = file if isinstance(file, (list, tuple)) else [file]
    for candidate in candidates:
        assert Path(candidate).suffix.lower() in allowed, f"{msg}{candidate} acceptable suffix is {allowed}"
|
|
|
|
|
def img2label_paths(img_paths):
    """Map image paths to label ``.txt`` paths.

    The last ``/images/`` directory component is replaced with ``/labels/``, and
    everything after the final '.' is replaced with ``txt``.

    Args:
        img_paths: iterable of image path strings.

    Returns:
        list[str]: one label path per input path, in the same order.
    """
    images_dir = f'{os.sep}images{os.sep}'
    labels_dir = f'{os.sep}labels{os.sep}'
    label_paths = []
    for img_path in img_paths:
        swapped = labels_dir.join(img_path.rsplit(images_dir, 1))
        stem = swapped.rsplit('.', 1)[0]
        label_paths.append(stem + '.txt')
    return label_paths
|
|
|
|
|
def exif_size(img):
    """Return ``img.size`` with width and height swapped when EXIF says the image is rotated.

    Args:
        img: a PIL image; its private ``_getexif()`` is read for the tag id stored in
            the module-level ``orientation`` (the EXIF 'Orientation' key).

    Returns:
        tuple: ``img.size``, swapped when the orientation tag is 6 or 8 (the two
        90-degree rotations in the EXIF spec). Best effort: any failure reading EXIF
        (no EXIF data, no orientation tag, malformed data) returns ``img.size`` unchanged.
    """
    s = img.size
    try:
        rotation = dict(img._getexif().items())[orientation]
    except Exception:
        # Deliberate best-effort fallback. Narrower than a bare `except:` so that
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        return s
    if rotation in (6, 8):
        s = (s[1], s[0])
    return s
|
|
|
|
|
def verify_image_label(args):
    """Check one image/label pair and load its labels (dataset-cache worker).

    Args:
        args: tuple ``(im_file, lb_file, prefix)``: image path, label ``.txt`` path,
            and a string prepended to any warning message.

    Returns:
        On success: ``(im_file, labels, shape, segments, nm, nf, ne, nc, msg)``.
          - ``labels`` is a float32 ``(n, 5)`` array.
          - ``shape`` is the EXIF-corrected size from ``exif_size``.
          - ``segments`` is a list of ``(k, 2)`` polygon arrays (empty unless the
            label file holds polygons).
          - ``nm``/``nf``/``ne``/``nc`` are 0/1 flags for label missing / found /
            empty / corrupt.
          - ``msg`` is a warning ('' if none).
        On any failure: ``[None, None, None, None, nm, nf, ne, nc, msg]`` with
        ``nc == 1`` and ``msg`` describing the error.
    """
    im_file, lb_file, prefix = args
    nm, nf, ne, nc, msg, segments = 0, 0, 0, 0, '', []
    try:
        # --- image checks ---
        # Context managers release the file handles even when a check below fails
        # (the original never closed the PIL images it opened).
        with Image.open(im_file) as im:
            im.verify()
            shape = exif_size(im)
            assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
            assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
            is_jpeg = im.format.lower() in ('jpg', 'jpeg')
        if is_jpeg:
            with open(im_file, 'rb') as f:
                f.seek(-2, 2)
                # A complete JPEG ends with the EOI marker FF D9.
                truncated = f.read() != b'\xff\xd9'
            if truncated:
                # The read handle is closed above before the same file is overwritten.
                with Image.open(im_file) as src:
                    src.save(im_file, format='JPEG', subsampling=0, quality=100)
                msg = f'{prefix}WARNING: corrupt JPEG restored and saved {im_file}'

        # --- label checks ---
        if os.path.isfile(lb_file):
            nf = 1
            with open(lb_file, 'r') as f:
                lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
            if any(len(x) > 8 for x in lb):
                # If any row has more than 8 fields, every row is parsed as
                # "class x1 y1 x2 y2 ..." polygon and converted to a box.
                classes = np.array([x[0] for x in lb], dtype=np.float32)
                segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb]
                lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)
            lb = np.array(lb, dtype=np.float32)
            if len(lb):
                assert lb.shape[1] == 5, 'labels require 5 columns each'
                assert (lb >= 0).all(), 'negative labels'
                assert (lb[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
                assert np.unique(lb, axis=0).shape[0] == lb.shape[0], 'duplicate labels'
            else:
                ne = 1
                lb = np.zeros((0, 5), dtype=np.float32)
        else:
            nm = 1
            lb = np.zeros((0, 5), dtype=np.float32)
        return im_file, lb, shape, segments, nm, nf, ne, nc, msg
    except Exception as e:
        nc = 1
        msg = f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}'
        return [None, None, None, None, nm, nf, ne, nc, msg]
|
|
|
|
|
def segments2boxes(segments):
    """Convert polygon segments to boxes via their axis-aligned bounding rectangles.

    Args:
        segments: iterable of ``(k, 2)`` arrays of polygon vertices (x, y).

    Returns:
        The ``(n, 4)`` array of ``[x_min, y_min, x_max, y_max]`` rows passed through
        ``xyxy2xywh``. An empty ``segments`` yields an empty ``(0, 4)`` input.
    """
    boxes = []
    for s in segments:
        x, y = s.T  # raises if a segment does not have exactly two columns
        boxes.append([x.min(), y.min(), x.max(), y.max()])
    # reshape keeps a 2-D (0, 4) array for empty input; a bare np.array([]) is 1-D
    # and cannot be column-indexed by the xywh conversion.
    return xyxy2xywh(np.array(boxes).reshape(-1, 4))
|
|
|
|
|
def get_hash(paths):
    """Return an MD5 hex digest fingerprinting a collection of file paths.

    The digest covers the total byte size of the paths that currently exist plus
    the concatenated path strings, so it changes when files are added, removed,
    renamed or resized. MD5 is used as a change-detection fingerprint, not for security.

    Args:
        paths: iterable of path strings or ``os.PathLike`` objects (a generator is fine).

    Returns:
        str: 32-character hex digest.
    """
    # Materialize once: a generator would otherwise be exhausted by the size pass
    # and contribute nothing to the name pass. str() also lets Path objects be joined.
    paths = [str(p) for p in paths]
    size = sum(os.path.getsize(p) for p in paths if os.path.exists(p))
    h = hashlib.md5(str(size).encode())
    h.update(''.join(paths).encode())
    return h.hexdigest()
|
|
|
|
|
class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
    """DataLoader that reuses its workers across epochs.

    Takes exactly the same arguments as ``torch.utils.data.DataLoader``. The batch
    sampler is wrapped in an endlessly repeating sampler, and one underlying
    iterator is created once and drawn from for every epoch.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # object.__setattr__ presumably bypasses DataLoader's guard against
        # reassigning batch_sampler after construction.
        endless_sampler = _RepeatSampler(self.batch_sampler)
        object.__setattr__(self, 'batch_sampler', endless_sampler)
        self.iterator = super().__iter__()

    def __len__(self):
        # batch_sampler is now the wrapper; its .sampler is the original batch sampler.
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        # One "epoch" = len(self) batches pulled from the shared, never-ending iterator.
        remaining = len(self)
        while remaining:
            remaining -= 1
            yield next(self.iterator)
|
|
|
|
|
class _RepeatSampler:
    """Wrap a sampler and iterate over it forever, restarting it each time it is exhausted.

    Args:
        sampler (Sampler): the sampler (or any re-iterable) to repeat.
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            for item in self.sampler:
                yield item
|
|
|
|
|
def load_image(self, i):
    """Load image ``i`` of a dataset, resized so its longer side equals ``self.img_size``.

    Args:
        self: dataset object providing ``imgs``, ``img_npy``, ``img_files``,
            ``img_size``, ``augment`` and, for cached images, ``img_hw0`` / ``img_hw``.
        i: image index.

    Returns:
        ``(image, (h0, w0), (h, w))``: the image array, its original height/width and
        its resized height/width. A cached image (``self.imgs[i]`` not None) is
        returned as-is with its stored shapes.

    Raises:
        AssertionError: if the image file cannot be read.
    """
    im = self.imgs[i]
    if im is not None:
        return self.imgs[i], self.img_hw0[i], self.img_hw[i]

    npy = self.img_npy[i]
    if npy and npy.exists():
        im = np.load(npy)
    else:
        path = self.img_files[i]
        im = cv2.imread(path)
        # f-string so a pathlib.Path value does not raise TypeError here.
        assert im is not None, f'Image Not Found {path}'
    h0, w0 = im.shape[:2]
    r = self.img_size / max(h0, w0)
    if r != 1:
        # INTER_AREA when shrinking without augmentation, INTER_LINEAR otherwise.
        interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR
        # Clamp to >= 1 px: for extreme aspect ratios int() can truncate the short
        # side to 0, which cv2.resize rejects.
        new_size = (max(1, int(w0 * r)), max(1, int(h0 * r)))
        im = cv2.resize(im, new_size, interpolation=interp)
    return im, (h0, w0), im.shape[:2]
|
|
|
|
|
def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    """Convert normalized ``[cx, cy, bw, bh]`` rows to pixel ``[x1, y1, x2, y2]``.

    Args:
        x: ``(n, >=4)`` numpy array or torch tensor; not modified.
        w, h: image width / height that the normalized values are scaled by.
        padw, padh: pixel offsets added to x and y coordinates respectively.

    Returns:
        A copy of ``x`` (same type) with the first four columns converted; any
        extra columns are left untouched.
    """
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    cx, cy = x[:, 0], x[:, 1]
    half_bw, half_bh = x[:, 2] / 2, x[:, 3] / 2
    out[:, 0] = w * (cx - half_bw) + padw
    out[:, 1] = h * (cy - half_bh) + padh
    out[:, 2] = w * (cx + half_bw) + padw
    out[:, 3] = h * (cy + half_bh) + padh
    return out
|
|
|
|
|
def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
    """Convert pixel ``[x1, y1, x2, y2]`` rows to normalized ``[cx, cy, bw, bh]``.

    Args:
        x: ``(n, >=4)`` numpy array or torch tensor.
        w, h: image width / height used for normalization.
        clip: if True, ``clip_coords(x, (h - eps, w - eps))`` is applied to ``x``
            first. Its return value is discarded, so it presumably clips ``x`` in
            place (the caller's data is affected).
        eps: margin subtracted from ``h`` / ``w`` when clipping.

    Returns:
        A copy of ``x`` (same type) with the first four columns converted.
    """
    if clip:
        clip_coords(x, (h - eps, w - eps))
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    x1, y1, x2, y2 = x[:, 0], x[:, 1], x[:, 2], x[:, 3]
    out[:, 0] = ((x1 + x2) / 2) / w
    out[:, 1] = ((y1 + y2) / 2) / h
    out[:, 2] = (x2 - x1) / w
    out[:, 3] = (y2 - y1) / h
    return out
|
|
|
|
|
def post_process(x, grid=None, anchor_grid=None, stride=(8, 16, 32), na=3):
    """Decode raw detection-head outputs (e.g. from an ONNX model) into box predictions.

    Args:
        x: sequence with one array per detection layer, each of shape
            ``(bs, na * no, ny, nx)``. ``no`` (values per anchor) is inferred from the
            channel count; the original hard-coded value was 85.
        grid: per-layer cell-offset grids, indexed by layer. If None, loaded from
            ``./grid.npy`` (relative to the current working directory), as before.
        anchor_grid: per-layer anchor sizes, indexed by layer. If None, loaded from
            ``./anchor_grid.npy``. Each ``grid[i]`` / ``anchor_grid[i]`` is assumed
            to broadcast against ``(bs, na, ny, nx, 2)``; confirm against the
            export script.
        stride: per-layer stride multiplier for the xy term; ``len(stride)`` layers
            are decoded.
        na: anchors per layer.

    Returns:
        ``(z, x)``:
          - ``z`` is a tensor of shape ``(bs, sum_i(na * ny_i * nx_i), no)``. Each
            row holds the decoded xy, decoded wh, then the sigmoid of the remaining
            values.
          - ``x`` is the list of per-layer tensors reshaped to
            ``(bs, na, ny, nx, no)``, without the sigmoid applied.
    """
    # NOTE(review): allow_pickle=True lets a crafted .npy execute arbitrary code on
    # load; only use grid files from a trusted location, or pass grids explicitly
    # (which also avoids re-reading them on every call).
    if grid is None:
        grid = np.load("./grid.npy", allow_pickle=True)
    if anchor_grid is None:
        anchor_grid = np.load("./anchor_grid.npy", allow_pickle=True)
    x = list(x)
    z = []
    for i in range(len(stride)):
        bs, c, ny, nx = x[i].shape
        no = c // na  # outputs per anchor (85 for the default 80-class head)
        x[i] = (
            torch.tensor(x[i])
            .view(bs, na, no, ny, nx)
            .permute(0, 1, 3, 4, 2)
            .contiguous()
        )
        y = x[i].sigmoid()
        xy = (y[..., 0:2] * 2.0 - 0.5 + grid[i]) * stride[i]
        wh = (y[..., 2:4] * 2) ** 2 * anchor_grid[i]
        y = torch.cat((xy, wh, y[..., 4:]), -1)
        z.append(y.view(bs, -1, no))
    return (torch.cat(z, 1), x)