PKaushik commited on
Commit
dedceac
1 Parent(s): 9067b6a
Files changed (1) hide show
  1. yolov6/data/data_load.py +113 -0
yolov6/data/data_load.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding:utf-8 -*-
3
+ # This code is based on
4
+ # https://github.com/ultralytics/yolov5/blob/master/utils/dataloaders.py
5
+
6
+ import os
7
+ from torch.utils.data import dataloader, distributed
8
+
9
+ from .datasets import TrainValDataset
10
+ from yolov6.utils.events import LOGGER
11
+ from yolov6.utils.torch_utils import torch_distributed_zero_first
12
+
13
+
14
+ def create_dataloader(
15
+ path,
16
+ img_size,
17
+ batch_size,
18
+ stride,
19
+ hyp=None,
20
+ augment=False,
21
+ check_images=False,
22
+ check_labels=False,
23
+ pad=0.0,
24
+ rect=False,
25
+ rank=-1,
26
+ workers=8,
27
+ shuffle=False,
28
+ data_dict=None,
29
+ task="Train",
30
+ ):
31
+ """Create general dataloader.
32
+
33
+ Returns dataloader and dataset
34
+ """
35
+ if rect and shuffle:
36
+ LOGGER.warning(
37
+ "WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False"
38
+ )
39
+ shuffle = False
40
+ with torch_distributed_zero_first(rank):
41
+ dataset = TrainValDataset(
42
+ path,
43
+ img_size,
44
+ batch_size,
45
+ augment=augment,
46
+ hyp=hyp,
47
+ rect=rect,
48
+ check_images=check_images,
49
+ check_labels=check_labels,
50
+ stride=int(stride),
51
+ pad=pad,
52
+ rank=rank,
53
+ data_dict=data_dict,
54
+ task=task,
55
+ )
56
+
57
+ batch_size = min(batch_size, len(dataset))
58
+ workers = min(
59
+ [
60
+ os.cpu_count() // int(os.getenv("WORLD_SIZE", 1)),
61
+ batch_size if batch_size > 1 else 0,
62
+ workers,
63
+ ]
64
+ ) # number of workers
65
+ sampler = (
66
+ None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
67
+ )
68
+ return (
69
+ TrainValDataLoader(
70
+ dataset,
71
+ batch_size=batch_size,
72
+ shuffle=shuffle and sampler is None,
73
+ num_workers=workers,
74
+ sampler=sampler,
75
+ pin_memory=True,
76
+ collate_fn=TrainValDataset.collate_fn,
77
+ ),
78
+ dataset,
79
+ )
80
+
81
+
82
+ class TrainValDataLoader(dataloader.DataLoader):
83
+ """Dataloader that reuses workers
84
+
85
+ Uses same syntax as vanilla DataLoader
86
+ """
87
+
88
+ def __init__(self, *args, **kwargs):
89
+ super().__init__(*args, **kwargs)
90
+ object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
91
+ self.iterator = super().__iter__()
92
+
93
+ def __len__(self):
94
+ return len(self.batch_sampler.sampler)
95
+
96
+ def __iter__(self):
97
+ for i in range(len(self)):
98
+ yield next(self.iterator)
99
+
100
+
101
+ class _RepeatSampler:
102
+ """Sampler that repeats forever
103
+
104
+ Args:
105
+ sampler (Sampler)
106
+ """
107
+
108
+ def __init__(self, sampler):
109
+ self.sampler = sampler
110
+
111
+ def __iter__(self):
112
+ while True:
113
+ yield from iter(self.sampler)