Spaces:
Build error
Build error
File size: 2,777 Bytes
dedceac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# This code is based on
# https://github.com/ultralytics/yolov5/blob/master/utils/dataloaders.py
import os
from torch.utils.data import dataloader, distributed
from .datasets import TrainValDataset
from yolov6.utils.events import LOGGER
from yolov6.utils.torch_utils import torch_distributed_zero_first
def create_dataloader(
    path,
    img_size,
    batch_size,
    stride,
    hyp=None,
    augment=False,
    check_images=False,
    check_labels=False,
    pad=0.0,
    rect=False,
    rank=-1,
    workers=8,
    shuffle=False,
    data_dict=None,
    task="Train",
):
    """Create general dataloader.

    Builds a ``TrainValDataset`` and wraps it in a ``TrainValDataLoader``.

    Args:
        path: Dataset location, forwarded to ``TrainValDataset``.
        img_size: Image size, forwarded to ``TrainValDataset``.
        batch_size: Requested batch size; clamped to the dataset length.
        stride: Model stride; converted with ``int()`` before being forwarded.
        hyp, augment, check_images, check_labels, pad, data_dict, task:
            Forwarded unchanged to ``TrainValDataset``.
        rect: Forwarded to the dataset; when set, shuffling is disabled.
        rank: Process rank. ``-1`` means no sampler is attached; any other
            value attaches a ``DistributedSampler``.
        workers: Upper bound on the number of DataLoader worker processes.
        shuffle: Whether to shuffle. It is applied by the loader itself when
            ``rank == -1``, otherwise by the ``DistributedSampler``.

    Returns:
        tuple: ``(TrainValDataLoader, TrainValDataset)``.

    Raises:
        ValueError: If the constructed dataset contains no samples.
    """
    if rect and shuffle:
        LOGGER.warning(
            "WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False"
        )
        shuffle = False
    # Presumably lets rank 0 build the dataset (and any caches) before the
    # other ranks do; semantics come from torch_distributed_zero_first.
    with torch_distributed_zero_first(rank):
        dataset = TrainValDataset(
            path,
            img_size,
            batch_size,
            augment=augment,
            hyp=hyp,
            rect=rect,
            check_images=check_images,
            check_labels=check_labels,
            stride=int(stride),
            pad=pad,
            rank=rank,
            data_dict=data_dict,
            task=task,
        )
    num_samples = len(dataset)
    if num_samples == 0:
        # Otherwise batch_size would clamp to 0 and DataLoader would fail
        # with a much less descriptive error.
        raise ValueError(
            f"No samples found for task {task!r} in dataset path {path!r}"
        )
    batch_size = min(batch_size, num_samples)
    # os.cpu_count() returns None when the CPU count cannot be determined.
    cpu_count = os.cpu_count() or 1
    world_size = int(os.getenv("WORLD_SIZE", 1))
    workers = min(
        cpu_count // world_size,  # share CPUs between distributed processes
        batch_size if batch_size > 1 else 0,  # load in-process for batch size 1
        workers,
    )  # number of workers
    sampler = (
        None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
    )
    return (
        TrainValDataLoader(
            dataset,
            batch_size=batch_size,
            # A sampler and shuffle=True are mutually exclusive in DataLoader.
            shuffle=shuffle and sampler is None,
            num_workers=workers,
            sampler=sampler,
            pin_memory=True,
            collate_fn=TrainValDataset.collate_fn,
        ),
        dataset,
    )
class TrainValDataLoader(dataloader.DataLoader):
    """DataLoader that keeps a single underlying iterator alive across epochs.

    The constructor accepts exactly the same arguments as
    ``torch.utils.data.DataLoader``. The batch sampler is wrapped so that it
    never runs out, and one iterator is created up front. Each call to
    ``__iter__`` then draws ``len(self)`` batches from that shared iterator,
    so worker processes are reused instead of being recreated every epoch.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        endless_batches = _RepeatSampler(self.batch_sampler)
        # Go through object.__setattr__ because DataLoader's own __setattr__
        # refuses to change batch_sampler once construction has finished.
        object.__setattr__(self, "batch_sampler", endless_batches)
        self.iterator = super().__iter__()

    def __len__(self):
        # The real (finite) batch sampler sits inside the repeating wrapper.
        finite_batches = self.batch_sampler.sampler
        return len(finite_batches)

    def __iter__(self):
        remaining = len(self)
        while remaining > 0:
            remaining -= 1
            yield next(self.iterator)
class _RepeatSampler:
    """Cycle through a wrapped sampler without end.

    Every pass calls ``iter`` on the wrapped sampler again. As a result, a
    sampler that reorders itself on each iteration produces a new order on
    every cycle.

    Args:
        sampler (Sampler): The finite sampler to repeat.
    """

    def __init__(self, sampler):
        # Stored under a public name: TrainValDataLoader.__len__ reads it.
        self.sampler = sampler

    def __iter__(self):
        while 1:
            yield from self.sampler
|