|
|
|
|
|
from itertools import repeat |
|
from multiprocessing.pool import ThreadPool |
|
from pathlib import Path |
|
|
|
import cv2 |
|
import numpy as np |
|
import torch |
|
import torchvision |
|
from tqdm import tqdm |
|
|
|
from ..utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT, is_dir_writeable |
|
from .augment import Compose, Format, Instances, LetterBox, classify_albumentations, classify_transforms, v8_transforms |
|
from .base import BaseDataset |
|
from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image_label |
|
|
|
|
|
class YOLODataset(BaseDataset):
    """
    Dataset class for loading object detection and/or segmentation labels in YOLO format.

    Args:
        data (dict, optional): A dataset YAML dictionary. Defaults to None.
        use_segments (bool, optional): If True, segmentation masks are used as labels. Defaults to False.
        use_keypoints (bool, optional): If True, keypoints are used as labels. Defaults to False.

    Returns:
        (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
    """

    # Written into every label cache; a cache with a different version is discarded and rebuilt.
    cache_version = '1.0.2'
    rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]

    def __init__(self, *args, data=None, use_segments=False, use_keypoints=False, **kwargs):
        """Store the label-type options, then delegate remaining arguments to BaseDataset.

        The attributes are set before ``super().__init__`` — presumably because BaseDataset's initializer
        calls get_labels()/build_transforms(), which read them (TODO confirm against BaseDataset).
        """
        self.use_segments = use_segments
        self.use_keypoints = use_keypoints
        self.data = data
        assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.'
        super().__init__(*args, **kwargs)

    def cache_labels(self, path=Path('./labels.cache')):
        """Cache dataset labels, check images and read shapes.

        Every (image, label) pair is verified in a thread pool via ``verify_image_label``; the collected
        labels plus bookkeeping (hash, counts, messages, version) are saved to ``path`` when its directory
        is writeable.

        Args:
            path (Path): path where to save the cache file (default: Path('./labels.cache')).

        Returns:
            (dict): labels.

        Raises:
            ValueError: If keypoints are requested but ``kpt_shape`` in ``self.data`` is missing or invalid.
        """
        x = {'labels': []}
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # per-category counters accumulated from verify_image_label
        desc = f'{self.prefix}Scanning {path.parent / path.stem}...'
        total = len(self.im_files)
        nkpt, ndim = self.data.get('kpt_shape', (0, 0))
        if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
            raise ValueError("'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
                             "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'")
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(func=verify_image_label,
                                iterable=zip(self.im_files, self.label_files, repeat(self.prefix),
                                             repeat(self.use_keypoints), repeat(len(self.data['names'])), repeat(nkpt),
                                             repeat(ndim)))
            pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
            for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:  # a falsy im_file means the sample was rejected; it gets no label entry
                    x['labels'].append(
                        dict(
                            im_file=im_file,
                            shape=shape,
                            cls=lb[:, 0:1],  # first column holds the class id
                            bboxes=lb[:, 1:],  # remaining columns are the box
                            segments=segments,
                            keypoints=keypoint,
                            normalized=True,
                            bbox_format='xywh'))
                if msg:
                    msgs.append(msg)
                pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt'
            pbar.close()

        if msgs:
            LOGGER.info('\n'.join(msgs))
        if nf == 0:
            LOGGER.warning(f'{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
        x['hash'] = get_hash(self.label_files + self.im_files)
        x['results'] = nf, nm, ne, nc, len(self.im_files)
        x['msgs'] = msgs
        x['version'] = self.cache_version
        if is_dir_writeable(path.parent):
            if path.exists():
                path.unlink()  # remove a stale cache first
            np.save(str(path), x)
            # np.save appends '.npy'; rename so the file lives at exactly `path`
            path.with_suffix('.cache.npy').rename(path)
            LOGGER.info(f'{self.prefix}New cache created: {path}')
        else:
            LOGGER.warning(f'{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
        return x

    def _load_cache(self, cache_path):
        """Return the label cache stored at ``cache_path`` if it is usable for the current files, else None.

        The cache is rejected (None) when the file is missing or not a pickled dict, when it was written by a
        different ``cache_version``, or when its hash no longer matches the current label/image files.
        """
        import gc

        gc_was_enabled = gc.isenabled()
        gc.disable()  # presumably to speed up unpickling the large labels dict
        try:
            # NOTE(review): allow_pickle=True unpickles the file; acceptable only because this cache is produced
            # locally by cache_labels(). Do not point cache_path at untrusted files.
            cache = np.load(str(cache_path), allow_pickle=True).item()
        except (FileNotFoundError, AttributeError):
            return None
        finally:
            # Restore GC on every path — previously a failed load left GC disabled for the whole process.
            if gc_was_enabled:
                gc.enable()
        if not isinstance(cache, dict):
            return None
        # Explicit comparisons rather than `assert`, so stale caches are still rejected under `python -O`;
        # `.get` also tolerates caches missing these keys.
        if cache.get('version') != self.cache_version:
            return None
        if cache.get('hash') != get_hash(self.label_files + self.im_files):
            return None
        return cache

    def get_labels(self):
        """Returns the list of per-image label dicts for YOLO training.

        Labels come from the ``.cache`` file next to the label directory when it is valid; otherwise the
        dataset is rescanned with cache_labels(). Side effects: sets ``self.label_files`` and narrows
        ``self.im_files`` to the images that have a label entry.

        Returns:
            (list[dict]): One dict per image with keys such as 'im_file', 'shape', 'cls', 'bboxes', 'segments'.

        Raises:
            FileNotFoundError: If the scan found no labels (``nf == 0``).
            ValueError: If all labels are empty (including when no image produced a label entry).
        """
        self.label_files = img2label_paths(self.im_files)
        cache_path = Path(self.label_files[0]).parent.with_suffix('.cache')
        cache = self._load_cache(cache_path)
        exists = cache is not None
        if not exists:
            cache = self.cache_labels(cache_path)

        # Display cache summary (only when it was loaded, since cache_labels() already showed a progress bar)
        nf, nm, ne, nc, n = cache.pop('results')  # counters from cache_labels() plus total image count
        if exists and LOCAL_RANK in (-1, 0):
            d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
            tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display results
            if cache['msgs']:
                LOGGER.info('\n'.join(cache['msgs']))
        if nf == 0:
            raise FileNotFoundError(f'{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}')

        # Drop bookkeeping entries; only the labels are needed from here on
        for k in ('hash', 'version', 'msgs'):
            cache.pop(k)
        labels = cache['labels']
        self.im_files = [lb['im_file'] for lb in labels]

        # Check whether the dataset is all boxes or all segments. Summing separately (instead of unpacking
        # zip(*...)) keeps this well-defined when `labels` is empty.
        len_cls = sum(len(lb['cls']) for lb in labels)
        len_boxes = sum(len(lb['bboxes']) for lb in labels)
        len_segments = sum(len(lb['segments']) for lb in labels)
        if len_segments and len_boxes != len_segments:
            LOGGER.warning(
                f'WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, '
                f'len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. '
                'To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.')
            for lb in labels:
                lb['segments'] = []
        if len_cls == 0:
            raise ValueError(f'All labels empty in {cache_path}, can not start training without labels. {HELP_URL}')
        return labels

    def build_transforms(self, hyp=None):
        """Builds and appends transforms to the list.

        Args:
            hyp: Hyperparameter namespace. Despite the ``None`` default it is required: ``mask_ratio`` and
                ``overlap_mask`` are always read, and when ``self.augment`` is set ``mosaic``/``mixup`` are read
                and overwritten in place.

        Returns:
            The composed transform pipeline, ending with a ``Format`` step.
        """
        if self.augment:
            # Mosaic and mixup are switched off for rectangular batches (this mutates `hyp`).
            hyp.mosaic = hyp.mosaic if not self.rect else 0.0
            hyp.mixup = hyp.mixup if not self.rect else 0.0
            transforms = v8_transforms(self, self.imgsz, hyp)
        else:
            transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
        transforms.append(
            Format(bbox_format='xywh',
                   normalize=True,
                   return_mask=self.use_segments,
                   return_keypoint=self.use_keypoints,
                   batch_idx=True,
                   mask_ratio=hyp.mask_ratio,
                   mask_overlap=hyp.overlap_mask))
        return transforms

    def close_mosaic(self, hyp):
        """Sets mosaic, copy_paste and mixup options to 0.0 (mutating ``hyp``) and rebuilds transformations."""
        hyp.mosaic = 0.0
        hyp.copy_paste = 0.0
        hyp.mixup = 0.0
        self.transforms = self.build_transforms(hyp)

    def update_labels_info(self, label):
        """Custom your label format here.

        Replaces the raw 'bboxes', 'segments', 'keypoints', 'bbox_format' and 'normalized' entries of ``label``
        with a single ``Instances`` object under 'instances'. ``label`` is modified in place and returned.
        """
        bboxes = label.pop('bboxes')
        segments = label.pop('segments')
        keypoints = label.pop('keypoints', None)
        bbox_format = label.pop('bbox_format')
        normalized = label.pop('normalized')
        label['instances'] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
        return label

    @staticmethod
    def collate_fn(batch):
        """Collates data samples into batches.

        'img' tensors are stacked along a new dim 0; 'masks', 'keypoints', 'bboxes' and 'cls' are concatenated
        along dim 0; other keys are kept as tuples. Each sample's 'batch_idx' is offset by the sample's position
        in the batch and the results are concatenated. Assumes every sample dict has the same keys in the same
        order (values are matched positionally).
        """
        new_batch = {}
        keys = batch[0].keys()
        values = list(zip(*[list(b.values()) for b in batch]))
        for k, value in zip(keys, values):
            if k == 'img':
                value = torch.stack(value, 0)
            if k in ('masks', 'keypoints', 'bboxes', 'cls'):
                value = torch.cat(value, 0)
            new_batch[k] = value
        # Out-of-place add so the samples' own batch_idx tensors are not mutated.
        new_batch['batch_idx'] = torch.cat([idx + i for i, idx in enumerate(new_batch['batch_idx'])], 0)
        return new_batch
|
|
|
|
|
|
|
class ClassificationDataset(torchvision.datasets.ImageFolder):
    """
    YOLO Classification Dataset.

    Args:
        root (str): Dataset path.

    Attributes:
        cache_ram (bool): True if images should be cached in RAM, False otherwise.
        cache_disk (bool): True if images should be cached on disk, False otherwise.
        samples (list): List of samples containing file, index, npy, and im.
        torch_transforms (callable): torchvision transforms applied to the dataset.
        album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True.
    """

    def __init__(self, root, args, augment=False, cache=False):
        """
        Initialize YOLO object with root, image size, augmentations, and cache settings.

        Args:
            root (str): Dataset path.
            args (Namespace): Argument parser containing dataset related settings.
            augment (bool, optional): True if dataset should be augmented, False otherwise. Defaults to False.
            cache (Union[bool, str], optional): Cache setting, can be True, False, 'ram' or 'disk'. Defaults to False.
        """
        super().__init__(root=root)
        if augment and args.fraction < 1.0:  # keep only the leading fraction of samples when augmenting
            self.samples = self.samples[:round(len(self.samples) * args.fraction)]
        self.cache_ram = cache is True or cache == 'ram'
        self.cache_disk = cache == 'disk'
        # Each sample becomes a mutable [file, class_index, npy_path, cached_image_or_None] list.
        self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples]
        self.torch_transforms = classify_transforms(args.imgsz)
        self.album_transforms = classify_albumentations(
            augment=augment,
            size=args.imgsz,
            scale=(1.0 - args.scale, 1.0),
            hflip=args.fliplr,
            vflip=args.flipud,
            hsv_h=args.hsv_h,
            hsv_s=args.hsv_s,
            hsv_v=args.hsv_v,
            mean=(0.0, 0.0, 0.0),
            std=(1.0, 1.0, 1.0),
            auto_aug=False) if augment else None

    @staticmethod
    def _imread(f):
        """Read image file ``f`` with OpenCV (BGR), raising instead of returning None when it cannot be read."""
        im = cv2.imread(f)
        if im is None:
            raise FileNotFoundError(f'Image Not Found {f}')
        return im

    def __getitem__(self, i):
        """Returns subset of data and targets corresponding to given indices.

        Returns:
            (dict): {'img': transformed image, 'cls': class index}.

        Raises:
            FileNotFoundError: If the image file cannot be read by OpenCV.
        """
        f, j, fn, im = self.samples[i]
        if self.cache_ram:
            # Nested check on purpose: with `cache_ram and im is None` combined, an already-cached image
            # would fall through to the final `else` and be re-read from disk on every access.
            if im is None:
                im = self.samples[i][3] = self._imread(f)
        elif self.cache_disk:
            if not fn.exists():
                # Validate before saving so an unreadable image never leaves an unloadable .npy behind.
                np.save(fn.as_posix(), self._imread(f))
            im = np.load(fn)
        else:
            im = self._imread(f)
        if self.album_transforms:
            sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))['image']
        else:
            sample = self.torch_transforms(im)
        return {'img': sample, 'cls': j}

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return len(self.samples)
|
|
|
|
|
|
|
class SemanticDataset(BaseDataset):
    """Semantic segmentation dataset stub; nothing beyond the BaseDataset initializer is defined here."""

    def __init__(self):
        """Initialize a SemanticDataset object by running the parent initializer with no arguments."""
        super(SemanticDataset, self).__init__()
|
|