Spaces:
Sleeping
Sleeping
上传改进后的yolov8
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- ultralytics/__init__.py +13 -0
- ultralytics/__pycache__/__init__.cpython-39.pyc +0 -0
- ultralytics/cfg/__init__.py +441 -0
- ultralytics/cfg/__pycache__/__init__.cpython-39.pyc +0 -0
- ultralytics/cfg/default.yaml +114 -0
- ultralytics/cfg/models/v8/yolov8.yaml +46 -0
- ultralytics/cfg/models/v8/yolov8_ECA.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8_GAM.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8_ResBlock_CBAM.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8_SA.yaml +50 -0
- ultralytics/cfg/trackers/botsort.yaml +18 -0
- ultralytics/cfg/trackers/bytetrack.yaml +11 -0
- ultralytics/data/__init__.py +8 -0
- ultralytics/data/__pycache__/__init__.cpython-39.pyc +0 -0
- ultralytics/data/__pycache__/augment.cpython-39.pyc +0 -0
- ultralytics/data/__pycache__/base.cpython-39.pyc +0 -0
- ultralytics/data/__pycache__/build.cpython-39.pyc +0 -0
- ultralytics/data/__pycache__/dataset.cpython-39.pyc +0 -0
- ultralytics/data/__pycache__/loaders.cpython-39.pyc +0 -0
- ultralytics/data/__pycache__/utils.cpython-39.pyc +0 -0
- ultralytics/data/annotator.py +39 -0
- ultralytics/data/augment.py +906 -0
- ultralytics/data/base.py +287 -0
- ultralytics/data/build.py +170 -0
- ultralytics/data/converter.py +230 -0
- ultralytics/data/dataloaders/__init__.py +0 -0
- ultralytics/data/dataset.py +275 -0
- ultralytics/data/loaders.py +407 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +60 -0
- ultralytics/data/scripts/get_coco128.sh +17 -0
- ultralytics/data/scripts/get_imagenet.sh +51 -0
- ultralytics/data/utils.py +557 -0
- ultralytics/engine/__init__.py +0 -0
- ultralytics/engine/__pycache__/__init__.cpython-39.pyc +0 -0
- ultralytics/engine/__pycache__/exporter.cpython-39.pyc +0 -0
- ultralytics/engine/__pycache__/model.cpython-39.pyc +0 -0
- ultralytics/engine/__pycache__/predictor.cpython-39.pyc +0 -0
- ultralytics/engine/__pycache__/results.cpython-39.pyc +0 -0
- ultralytics/engine/__pycache__/trainer.cpython-39.pyc +0 -0
- ultralytics/engine/__pycache__/validator.cpython-39.pyc +0 -0
- ultralytics/engine/exporter.py +969 -0
- ultralytics/engine/model.py +465 -0
- ultralytics/engine/predictor.py +359 -0
- ultralytics/engine/results.py +604 -0
- ultralytics/engine/trainer.py +664 -0
- ultralytics/engine/validator.py +279 -0
- ultralytics/hub/__init__.py +121 -0
- ultralytics/hub/__pycache__/__init__.cpython-39.pyc +0 -0
- ultralytics/hub/__pycache__/auth.cpython-39.pyc +0 -0
ultralytics/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
__version__ = '8.0.147'
|
4 |
+
|
5 |
+
from ultralytics.hub import start
|
6 |
+
from ultralytics.models import RTDETR, SAM, YOLO
|
7 |
+
from ultralytics.models.fastsam import FastSAM
|
8 |
+
from ultralytics.models.nas import NAS
|
9 |
+
from ultralytics.utils import SETTINGS as settings
|
10 |
+
from ultralytics.utils.checks import check_yolo as checks
|
11 |
+
from ultralytics.utils.downloads import download
|
12 |
+
|
13 |
+
__all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'FastSAM', 'RTDETR', 'checks', 'download', 'start', 'settings' # allow simpler import
|
ultralytics/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (681 Bytes). View file
|
|
ultralytics/cfg/__init__.py
ADDED
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
import contextlib
|
4 |
+
import re
|
5 |
+
import shutil
|
6 |
+
import sys
|
7 |
+
from difflib import get_close_matches
|
8 |
+
from pathlib import Path
|
9 |
+
from types import SimpleNamespace
|
10 |
+
from typing import Dict, List, Union
|
11 |
+
|
12 |
+
from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, ROOT, SETTINGS, SETTINGS_YAML,
|
13 |
+
IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn, yaml_load,
|
14 |
+
yaml_print)
|
15 |
+
|
16 |
+
# Define valid tasks and modes
|
17 |
+
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
|
18 |
+
TASKS = 'detect', 'segment', 'classify', 'pose'
|
19 |
+
TASK2DATA = {'detect': 'coco8.yaml', 'segment': 'coco8-seg.yaml', 'classify': 'imagenet100', 'pose': 'coco8-pose.yaml'}
|
20 |
+
TASK2MODEL = {
|
21 |
+
'detect': 'yolov8n.pt',
|
22 |
+
'segment': 'yolov8n-seg.pt',
|
23 |
+
'classify': 'yolov8n-cls.pt',
|
24 |
+
'pose': 'yolov8n-pose.pt'}
|
25 |
+
TASK2METRIC = {
|
26 |
+
'detect': 'metrics/mAP50-95(B)',
|
27 |
+
'segment': 'metrics/mAP50-95(M)',
|
28 |
+
'classify': 'metrics/accuracy_top1',
|
29 |
+
'pose': 'metrics/mAP50-95(P)'}
|
30 |
+
|
31 |
+
CLI_HELP_MSG = \
|
32 |
+
f"""
|
33 |
+
Arguments received: {str(['yolo'] + sys.argv[1:])}. Ultralytics 'yolo' commands use the following syntax:
|
34 |
+
|
35 |
+
yolo TASK MODE ARGS
|
36 |
+
|
37 |
+
Where TASK (optional) is one of {TASKS}
|
38 |
+
MODE (required) is one of {MODES}
|
39 |
+
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
|
40 |
+
See all ARGS at https://docs.ultralytics.com/usage/cfg or with 'yolo cfg'
|
41 |
+
|
42 |
+
1. Train a detection model for 10 epochs with an initial learning_rate of 0.01
|
43 |
+
yolo train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
|
44 |
+
|
45 |
+
2. Predict a YouTube video using a pretrained segmentation model at image size 320:
|
46 |
+
yolo predict model=yolov8n-seg.pt source='https://youtu.be/Zgi9g1ksQHc' imgsz=320
|
47 |
+
|
48 |
+
3. Val a pretrained detection model at batch-size 1 and image size 640:
|
49 |
+
yolo val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
|
50 |
+
|
51 |
+
4. Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
|
52 |
+
yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
|
53 |
+
|
54 |
+
5. Run special commands:
|
55 |
+
yolo help
|
56 |
+
yolo checks
|
57 |
+
yolo version
|
58 |
+
yolo settings
|
59 |
+
yolo copy-cfg
|
60 |
+
yolo cfg
|
61 |
+
|
62 |
+
Docs: https://docs.ultralytics.com
|
63 |
+
Community: https://community.ultralytics.com
|
64 |
+
GitHub: https://github.com/ultralytics/ultralytics
|
65 |
+
"""
|
66 |
+
|
67 |
+
# Define keys for arg type checks
|
68 |
+
CFG_FLOAT_KEYS = 'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear'
|
69 |
+
CFG_FRACTION_KEYS = ('dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr',
|
70 |
+
'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud',
|
71 |
+
'fliplr', 'mosaic', 'mixup', 'copy_paste', 'conf', 'iou', 'fraction') # fraction floats 0.0 - 1.0
|
72 |
+
CFG_INT_KEYS = ('epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
|
73 |
+
'line_width', 'workspace', 'nbs', 'save_period')
|
74 |
+
CFG_BOOL_KEYS = ('save', 'exist_ok', 'verbose', 'deterministic', 'single_cls', 'rect', 'cos_lr', 'overlap_mask', 'val',
|
75 |
+
'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', 'save_conf', 'save_crop',
|
76 |
+
'show_labels', 'show_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks', 'boxes', 'keras',
|
77 |
+
'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'profile')
|
78 |
+
|
79 |
+
|
80 |
+
def cfg2dict(cfg):
|
81 |
+
"""
|
82 |
+
Convert a configuration object to a dictionary, whether it is a file path, a string, or a SimpleNamespace object.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
cfg (str | Path | SimpleNamespace): Configuration object to be converted to a dictionary.
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
cfg (dict): Configuration object in dictionary format.
|
89 |
+
"""
|
90 |
+
if isinstance(cfg, (str, Path)):
|
91 |
+
cfg = yaml_load(cfg) # load dict
|
92 |
+
elif isinstance(cfg, SimpleNamespace):
|
93 |
+
cfg = vars(cfg) # convert to dict
|
94 |
+
return cfg
|
95 |
+
|
96 |
+
|
97 |
+
def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, overrides: Dict = None):
|
98 |
+
"""
|
99 |
+
Load and merge configuration data from a file or dictionary.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
cfg (str | Path | Dict | SimpleNamespace): Configuration data.
|
103 |
+
overrides (str | Dict | optional): Overrides in the form of a file name or a dictionary. Default is None.
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
(SimpleNamespace): Training arguments namespace.
|
107 |
+
"""
|
108 |
+
cfg = cfg2dict(cfg)
|
109 |
+
|
110 |
+
# Merge overrides
|
111 |
+
if overrides:
|
112 |
+
overrides = cfg2dict(overrides)
|
113 |
+
check_dict_alignment(cfg, overrides)
|
114 |
+
cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides)
|
115 |
+
|
116 |
+
# Special handling for numeric project/name
|
117 |
+
for k in 'project', 'name':
|
118 |
+
if k in cfg and isinstance(cfg[k], (int, float)):
|
119 |
+
cfg[k] = str(cfg[k])
|
120 |
+
if cfg.get('name') == 'model': # assign model to 'name' arg
|
121 |
+
cfg['name'] = cfg.get('model', '').split('.')[0]
|
122 |
+
LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.")
|
123 |
+
|
124 |
+
# Type and Value checks
|
125 |
+
for k, v in cfg.items():
|
126 |
+
if v is not None: # None values may be from optional args
|
127 |
+
if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
|
128 |
+
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
129 |
+
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
|
130 |
+
elif k in CFG_FRACTION_KEYS:
|
131 |
+
if not isinstance(v, (int, float)):
|
132 |
+
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
133 |
+
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
|
134 |
+
if not (0.0 <= v <= 1.0):
|
135 |
+
raise ValueError(f"'{k}={v}' is an invalid value. "
|
136 |
+
f"Valid '{k}' values are between 0.0 and 1.0.")
|
137 |
+
elif k in CFG_INT_KEYS and not isinstance(v, int):
|
138 |
+
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
139 |
+
f"'{k}' must be an int (i.e. '{k}=8')")
|
140 |
+
elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
|
141 |
+
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
142 |
+
f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')")
|
143 |
+
|
144 |
+
# Return instance
|
145 |
+
return IterableSimpleNamespace(**cfg)
|
146 |
+
|
147 |
+
|
148 |
+
def _handle_deprecation(custom):
|
149 |
+
"""Hardcoded function to handle deprecated config keys"""
|
150 |
+
|
151 |
+
for key in custom.copy().keys():
|
152 |
+
if key == 'hide_labels':
|
153 |
+
deprecation_warn(key, 'show_labels')
|
154 |
+
custom['show_labels'] = custom.pop('hide_labels') == 'False'
|
155 |
+
if key == 'hide_conf':
|
156 |
+
deprecation_warn(key, 'show_conf')
|
157 |
+
custom['show_conf'] = custom.pop('hide_conf') == 'False'
|
158 |
+
if key == 'line_thickness':
|
159 |
+
deprecation_warn(key, 'line_width')
|
160 |
+
custom['line_width'] = custom.pop('line_thickness')
|
161 |
+
|
162 |
+
return custom
|
163 |
+
|
164 |
+
|
165 |
+
def check_dict_alignment(base: Dict, custom: Dict, e=None):
|
166 |
+
"""
|
167 |
+
This function checks for any mismatched keys between a custom configuration list and a base configuration list.
|
168 |
+
If any mismatched keys are found, the function prints out similar keys from the base list and exits the program.
|
169 |
+
|
170 |
+
Args:
|
171 |
+
custom (dict): a dictionary of custom configuration options
|
172 |
+
base (dict): a dictionary of base configuration options
|
173 |
+
"""
|
174 |
+
custom = _handle_deprecation(custom)
|
175 |
+
base_keys, custom_keys = (set(x.keys()) for x in (base, custom))
|
176 |
+
mismatched = [k for k in custom_keys if k not in base_keys]
|
177 |
+
if mismatched:
|
178 |
+
string = ''
|
179 |
+
for x in mismatched:
|
180 |
+
matches = get_close_matches(x, base_keys) # key list
|
181 |
+
matches = [f'{k}={base[k]}' if base.get(k) is not None else k for k in matches]
|
182 |
+
match_str = f'Similar arguments are i.e. {matches}.' if matches else ''
|
183 |
+
string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n"
|
184 |
+
raise SyntaxError(string + CLI_HELP_MSG) from e
|
185 |
+
|
186 |
+
|
187 |
+
def merge_equals_args(args: List[str]) -> List[str]:
|
188 |
+
"""
|
189 |
+
Merges arguments around isolated '=' args in a list of strings.
|
190 |
+
The function considers cases where the first argument ends with '=' or the second starts with '=',
|
191 |
+
as well as when the middle one is an equals sign.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
args (List[str]): A list of strings where each element is an argument.
|
195 |
+
|
196 |
+
Returns:
|
197 |
+
List[str]: A list of strings where the arguments around isolated '=' are merged.
|
198 |
+
"""
|
199 |
+
new_args = []
|
200 |
+
for i, arg in enumerate(args):
|
201 |
+
if arg == '=' and 0 < i < len(args) - 1: # merge ['arg', '=', 'val']
|
202 |
+
new_args[-1] += f'={args[i + 1]}'
|
203 |
+
del args[i + 1]
|
204 |
+
elif arg.endswith('=') and i < len(args) - 1 and '=' not in args[i + 1]: # merge ['arg=', 'val']
|
205 |
+
new_args.append(f'{arg}{args[i + 1]}')
|
206 |
+
del args[i + 1]
|
207 |
+
elif arg.startswith('=') and i > 0: # merge ['arg', '=val']
|
208 |
+
new_args[-1] += arg
|
209 |
+
else:
|
210 |
+
new_args.append(arg)
|
211 |
+
return new_args
|
212 |
+
|
213 |
+
|
214 |
+
def handle_yolo_hub(args: List[str]) -> None:
|
215 |
+
"""
|
216 |
+
Handle Ultralytics HUB command-line interface (CLI) commands.
|
217 |
+
|
218 |
+
This function processes Ultralytics HUB CLI commands such as login and logout.
|
219 |
+
It should be called when executing a script with arguments related to HUB authentication.
|
220 |
+
|
221 |
+
Args:
|
222 |
+
args (List[str]): A list of command line arguments
|
223 |
+
|
224 |
+
Example:
|
225 |
+
```python
|
226 |
+
python my_script.py hub login your_api_key
|
227 |
+
```
|
228 |
+
"""
|
229 |
+
from ultralytics import hub
|
230 |
+
|
231 |
+
if args[0] == 'login':
|
232 |
+
key = args[1] if len(args) > 1 else ''
|
233 |
+
# Log in to Ultralytics HUB using the provided API key
|
234 |
+
hub.login(key)
|
235 |
+
elif args[0] == 'logout':
|
236 |
+
# Log out from Ultralytics HUB
|
237 |
+
hub.logout()
|
238 |
+
|
239 |
+
|
240 |
+
def handle_yolo_settings(args: List[str]) -> None:
|
241 |
+
"""
|
242 |
+
Handle YOLO settings command-line interface (CLI) commands.
|
243 |
+
|
244 |
+
This function processes YOLO settings CLI commands such as reset.
|
245 |
+
It should be called when executing a script with arguments related to YOLO settings management.
|
246 |
+
|
247 |
+
Args:
|
248 |
+
args (List[str]): A list of command line arguments for YOLO settings management.
|
249 |
+
|
250 |
+
Example:
|
251 |
+
```python
|
252 |
+
python my_script.py yolo settings reset
|
253 |
+
```
|
254 |
+
"""
|
255 |
+
if any(args):
|
256 |
+
if args[0] == 'reset':
|
257 |
+
SETTINGS_YAML.unlink() # delete the settings file
|
258 |
+
SETTINGS.reset() # create new settings
|
259 |
+
LOGGER.info('Settings reset successfully') # inform the user that settings have been reset
|
260 |
+
else: # save a new setting
|
261 |
+
new = dict(parse_key_value_pair(a) for a in args)
|
262 |
+
check_dict_alignment(SETTINGS, new)
|
263 |
+
SETTINGS.update(new)
|
264 |
+
|
265 |
+
yaml_print(SETTINGS_YAML) # print the current settings
|
266 |
+
|
267 |
+
|
268 |
+
def parse_key_value_pair(pair):
|
269 |
+
"""Parse one 'key=value' pair and return key and value."""
|
270 |
+
re.sub(r' *= *', '=', pair) # remove spaces around equals sign
|
271 |
+
k, v = pair.split('=', 1) # split on first '=' sign
|
272 |
+
assert v, f"missing '{k}' value"
|
273 |
+
return k, smart_value(v)
|
274 |
+
|
275 |
+
|
276 |
+
def smart_value(v):
|
277 |
+
"""Convert a string to an underlying type such as int, float, bool, etc."""
|
278 |
+
if v.lower() == 'none':
|
279 |
+
return None
|
280 |
+
elif v.lower() == 'true':
|
281 |
+
return True
|
282 |
+
elif v.lower() == 'false':
|
283 |
+
return False
|
284 |
+
else:
|
285 |
+
with contextlib.suppress(Exception):
|
286 |
+
return eval(v)
|
287 |
+
return v
|
288 |
+
|
289 |
+
|
290 |
+
def entrypoint(debug=''):
|
291 |
+
"""
|
292 |
+
This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
|
293 |
+
to the package.
|
294 |
+
|
295 |
+
This function allows for:
|
296 |
+
- passing mandatory YOLO args as a list of strings
|
297 |
+
- specifying the task to be performed, either 'detect', 'segment' or 'classify'
|
298 |
+
- specifying the mode, either 'train', 'val', 'test', or 'predict'
|
299 |
+
- running special modes like 'checks'
|
300 |
+
- passing overrides to the package's configuration
|
301 |
+
|
302 |
+
It uses the package's default cfg and initializes it using the passed overrides.
|
303 |
+
Then it calls the CLI function with the composed cfg
|
304 |
+
"""
|
305 |
+
args = (debug.split(' ') if debug else sys.argv)[1:]
|
306 |
+
if not args: # no arguments passed
|
307 |
+
LOGGER.info(CLI_HELP_MSG)
|
308 |
+
return
|
309 |
+
|
310 |
+
special = {
|
311 |
+
'help': lambda: LOGGER.info(CLI_HELP_MSG),
|
312 |
+
'checks': checks.check_yolo,
|
313 |
+
'version': lambda: LOGGER.info(__version__),
|
314 |
+
'settings': lambda: handle_yolo_settings(args[1:]),
|
315 |
+
'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
|
316 |
+
'hub': lambda: handle_yolo_hub(args[1:]),
|
317 |
+
'login': lambda: handle_yolo_hub(args),
|
318 |
+
'copy-cfg': copy_default_cfg}
|
319 |
+
full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}
|
320 |
+
|
321 |
+
# Define common mis-uses of special commands, i.e. -h, -help, --help
|
322 |
+
special.update({k[0]: v for k, v in special.items()}) # singular
|
323 |
+
special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith('s')}) # singular
|
324 |
+
special = {**special, **{f'-{k}': v for k, v in special.items()}, **{f'--{k}': v for k, v in special.items()}}
|
325 |
+
|
326 |
+
overrides = {} # basic overrides, i.e. imgsz=320
|
327 |
+
for a in merge_equals_args(args): # merge spaces around '=' sign
|
328 |
+
if a.startswith('--'):
|
329 |
+
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
|
330 |
+
a = a[2:]
|
331 |
+
if a.endswith(','):
|
332 |
+
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.")
|
333 |
+
a = a[:-1]
|
334 |
+
if '=' in a:
|
335 |
+
try:
|
336 |
+
k, v = parse_key_value_pair(a)
|
337 |
+
if k == 'cfg': # custom.yaml passed
|
338 |
+
LOGGER.info(f'Overriding {DEFAULT_CFG_PATH} with {v}')
|
339 |
+
overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != 'cfg'}
|
340 |
+
else:
|
341 |
+
overrides[k] = v
|
342 |
+
except (NameError, SyntaxError, ValueError, AssertionError) as e:
|
343 |
+
check_dict_alignment(full_args_dict, {a: ''}, e)
|
344 |
+
|
345 |
+
elif a in TASKS:
|
346 |
+
overrides['task'] = a
|
347 |
+
elif a in MODES:
|
348 |
+
overrides['mode'] = a
|
349 |
+
elif a.lower() in special:
|
350 |
+
special[a.lower()]()
|
351 |
+
return
|
352 |
+
elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool):
|
353 |
+
overrides[a] = True # auto-True for default bool args, i.e. 'yolo show' sets show=True
|
354 |
+
elif a in DEFAULT_CFG_DICT:
|
355 |
+
raise SyntaxError(f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign "
|
356 |
+
f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}")
|
357 |
+
else:
|
358 |
+
check_dict_alignment(full_args_dict, {a: ''})
|
359 |
+
|
360 |
+
# Check keys
|
361 |
+
check_dict_alignment(full_args_dict, overrides)
|
362 |
+
|
363 |
+
# Mode
|
364 |
+
mode = overrides.get('mode')
|
365 |
+
if mode is None:
|
366 |
+
mode = DEFAULT_CFG.mode or 'predict'
|
367 |
+
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")
|
368 |
+
elif mode not in MODES:
|
369 |
+
if mode not in ('checks', checks):
|
370 |
+
raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}")
|
371 |
+
LOGGER.warning("WARNING ⚠️ 'yolo mode=checks' is deprecated. Use 'yolo checks' instead.")
|
372 |
+
checks.check_yolo()
|
373 |
+
return
|
374 |
+
|
375 |
+
# Task
|
376 |
+
task = overrides.pop('task', None)
|
377 |
+
if task:
|
378 |
+
if task not in TASKS:
|
379 |
+
raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
|
380 |
+
if 'model' not in overrides:
|
381 |
+
overrides['model'] = TASK2MODEL[task]
|
382 |
+
|
383 |
+
# Model
|
384 |
+
model = overrides.pop('model', DEFAULT_CFG.model)
|
385 |
+
if model is None:
|
386 |
+
model = 'yolov8n.pt'
|
387 |
+
LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
|
388 |
+
overrides['model'] = model
|
389 |
+
if 'rtdetr' in model.lower(): # guess architecture
|
390 |
+
from ultralytics import RTDETR
|
391 |
+
model = RTDETR(model) # no task argument
|
392 |
+
elif 'fastsam' in model.lower():
|
393 |
+
from ultralytics import FastSAM
|
394 |
+
model = FastSAM(model)
|
395 |
+
elif 'sam' in model.lower():
|
396 |
+
from ultralytics import SAM
|
397 |
+
model = SAM(model)
|
398 |
+
else:
|
399 |
+
from ultralytics import YOLO
|
400 |
+
model = YOLO(model, task=task)
|
401 |
+
if isinstance(overrides.get('pretrained'), str):
|
402 |
+
model.load(overrides['pretrained'])
|
403 |
+
|
404 |
+
# Task Update
|
405 |
+
if task != model.task:
|
406 |
+
if task:
|
407 |
+
LOGGER.warning(f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. "
|
408 |
+
f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model.")
|
409 |
+
task = model.task
|
410 |
+
|
411 |
+
# Mode
|
412 |
+
if mode in ('predict', 'track') and 'source' not in overrides:
|
413 |
+
overrides['source'] = DEFAULT_CFG.source or ROOT / 'assets' if (ROOT / 'assets').exists() \
|
414 |
+
else 'https://ultralytics.com/images/bus.jpg'
|
415 |
+
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
|
416 |
+
elif mode in ('train', 'val'):
|
417 |
+
if 'data' not in overrides:
|
418 |
+
overrides['data'] = TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
|
419 |
+
LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
|
420 |
+
elif mode == 'export':
|
421 |
+
if 'format' not in overrides:
|
422 |
+
overrides['format'] = DEFAULT_CFG.format or 'torchscript'
|
423 |
+
LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.")
|
424 |
+
|
425 |
+
# Run command in python
|
426 |
+
# getattr(model, mode)(**vars(get_cfg(overrides=overrides))) # default args using default.yaml
|
427 |
+
getattr(model, mode)(**overrides) # default args from model
|
428 |
+
|
429 |
+
|
430 |
+
# Special modes --------------------------------------------------------------------------------------------------------
|
431 |
+
def copy_default_cfg():
|
432 |
+
"""Copy and create a new default configuration file with '_copy' appended to its name."""
|
433 |
+
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
|
434 |
+
shutil.copy2(DEFAULT_CFG_PATH, new_file)
|
435 |
+
LOGGER.info(f'{DEFAULT_CFG_PATH} copied to {new_file}\n'
|
436 |
+
f"Example YOLO command with this new custom cfg:\n yolo cfg='{new_file}' imgsz=320 batch=8")
|
437 |
+
|
438 |
+
|
439 |
+
if __name__ == '__main__':
|
440 |
+
# Example Usage: entrypoint(debug='yolo predict model=yolov8n.pt')
|
441 |
+
entrypoint(debug='')
|
ultralytics/cfg/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (16.3 kB). View file
|
|
ultralytics/cfg/default.yaml
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# Default training settings and hyperparameters for medium-augmentation COCO training
|
3 |
+
|
4 |
+
task: detect # (str) YOLO task, i.e. detect, segment, classify, pose
|
5 |
+
mode: train # (str) YOLO mode, i.e. train, val, predict, export, track, benchmark
|
6 |
+
|
7 |
+
# Train settings -------------------------------------------------------------------------------------------------------
|
8 |
+
model: # (str, optional) path to model file, i.e. yolov8n.pt, yolov8n.yaml
|
9 |
+
data: # (str, optional) path to data file, i.e. coco128.yaml
|
10 |
+
epochs: 100 # (int) number of epochs to train for
|
11 |
+
patience: 50 # (int) epochs to wait for no observable improvement for early stopping of training
|
12 |
+
batch: -1 # (int) number of images per batch (-1 for AutoBatch)
|
13 |
+
imgsz: 640 # (int | list) input images size as int for train and val modes, or list[w,h] for predict and export modes
|
14 |
+
save: True # (bool) save train checkpoints and predict results
|
15 |
+
save_period: -1 # (int) Save checkpoint every x epochs (disabled if < 1)
|
16 |
+
cache: False # (bool) True/ram, disk or False. Use cache for data loading
|
17 |
+
device: 0 # (int | str | list, optional) device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu
|
18 |
+
workers: 2 # (int) number of worker threads for data loading (per RANK if DDP)
|
19 |
+
project: # (str, optional) project name
|
20 |
+
name: # (str, optional) experiment name, results saved to 'project/name' directory
|
21 |
+
exist_ok: True # (bool) whether to overwrite existing experiment
|
22 |
+
pretrained: True # (bool | str) whether to use a pretrained model (bool) or a model to load weights from (str)
|
23 |
+
optimizer: auto # (str) optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto]
|
24 |
+
verbose: True # (bool) whether to print verbose output
|
25 |
+
seed: 0 # (int) random seed for reproducibility
|
26 |
+
deterministic: True # (bool) whether to enable deterministic mode
|
27 |
+
single_cls: False # (bool) train multi-class data as single-class
|
28 |
+
rect: False # (bool) rectangular training if mode='train' or rectangular validation if mode='val'
|
29 |
+
cos_lr: False # (bool) use cosine learning rate scheduler
|
30 |
+
close_mosaic: 10 # (int) disable mosaic augmentation for final epochs
|
31 |
+
resume: False # (bool) resume training from last checkpoint
|
32 |
+
amp: False # (bool) Automatic Mixed Precision (AMP) training, choices=[True, False], True runs AMP check
|
33 |
+
fraction: 1.0 # (float) dataset fraction to train on (default is 1.0, all images in train set)
|
34 |
+
profile: False # (bool) profile ONNX and TensorRT speeds during training for loggers
|
35 |
+
# Segmentation
|
36 |
+
overlap_mask: True # (bool) masks should overlap during training (segment train only)
|
37 |
+
mask_ratio: 4 # (int) mask downsample ratio (segment train only)
|
38 |
+
# Classification
|
39 |
+
dropout: 0.0 # (float) use dropout regularization (classify train only)
|
40 |
+
|
41 |
+
# Val/Test settings ----------------------------------------------------------------------------------------------------
|
42 |
+
val: True # (bool) validate/test during training
|
43 |
+
split: val # (str) dataset split to use for validation, i.e. 'val', 'test' or 'train'
|
44 |
+
save_json: True # (bool) save results to JSON file
|
45 |
+
save_hybrid: False # (bool) save hybrid version of labels (labels + additional predictions)
|
46 |
+
conf: # (float, optional) object confidence threshold for detection (default 0.25 predict, 0.001 val)
|
47 |
+
iou: 0.7 # (float) intersection over union (IoU) threshold for NMS
|
48 |
+
max_det: 300 # (int) maximum number of detections per image
|
49 |
+
half: False # (bool) use half precision (FP16)
|
50 |
+
dnn: False # (bool) use OpenCV DNN for ONNX inference
|
51 |
+
plots: True # (bool) save plots during train/val
|
52 |
+
|
53 |
+
# Prediction settings --------------------------------------------------------------------------------------------------
|
54 |
+
source: # (str, optional) source directory for images or videos
|
55 |
+
show: False # (bool) show results if possible
|
56 |
+
save_txt: False # (bool) save results as .txt file
|
57 |
+
save_conf: False # (bool) save results with confidence scores
|
58 |
+
save_crop: False # (bool) save cropped images with results
|
59 |
+
show_labels: True # (bool) show object labels in plots
|
60 |
+
show_conf: True # (bool) show object confidence scores in plots
|
61 |
+
vid_stride: 1 # (int) video frame-rate stride
|
62 |
+
line_width: # (int, optional) line width of the bounding boxes, auto if missing
|
63 |
+
visualize: False # (bool) visualize model features
|
64 |
+
augment: False # (bool) apply image augmentation to prediction sources
|
65 |
+
agnostic_nms: False # (bool) class-agnostic NMS
|
66 |
+
classes: # (int | list[int], optional) filter results by class, i.e. class=0, or class=[0,2,3]
|
67 |
+
retina_masks: False # (bool) use high-resolution segmentation masks
|
68 |
+
boxes: True # (bool) Show boxes in segmentation predictions
|
69 |
+
|
70 |
+
# Export settings ------------------------------------------------------------------------------------------------------
|
71 |
+
format: torchscript # (str) format to export to, choices at https://docs.ultralytics.com/modes/export/#export-formats
|
72 |
+
keras: False # (bool) use Kera=s
|
73 |
+
optimize: False # (bool) TorchScript: optimize for mobile
|
74 |
+
int8: False # (bool) CoreML/TF INT8 quantization
|
75 |
+
dynamic: False # (bool) ONNX/TF/TensorRT: dynamic axes
|
76 |
+
simplify: False # (bool) ONNX: simplify model
|
77 |
+
opset: # (int, optional) ONNX: opset version
|
78 |
+
workspace: 4 # (int) TensorRT: workspace size (GB)
|
79 |
+
nms: False # (bool) CoreML: add NMS
|
80 |
+
|
81 |
+
# Hyperparameters ------------------------------------------------------------------------------------------------------
|
82 |
+
lr0: 0.01 # (float) initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
|
83 |
+
lrf: 0.01 # (float) final learning rate (lr0 * lrf)
|
84 |
+
momentum: 0.937 # (float) SGD momentum/Adam beta1
|
85 |
+
weight_decay: 0.0005 # (float) optimizer weight decay 5e-4
|
86 |
+
warmup_epochs: 3.0 # (float) warmup epochs (fractions ok)
|
87 |
+
warmup_momentum: 0.8 # (float) warmup initial momentum
|
88 |
+
warmup_bias_lr: 0.1 # (float) warmup initial bias lr
|
89 |
+
box: 7.5 # (float) box loss gain
|
90 |
+
cls: 0.5 # (float) cls loss gain (scale with pixels)
|
91 |
+
dfl: 1.5 # (float) dfl loss gain
|
92 |
+
pose: 12.0 # (float) pose loss gain
|
93 |
+
kobj: 1.0 # (float) keypoint obj loss gain
|
94 |
+
label_smoothing: 0.0 # (float) label smoothing (fraction)
|
95 |
+
nbs: 64 # (int) nominal batch size
|
96 |
+
hsv_h: 0.015 # (float) image HSV-Hue augmentation (fraction)
|
97 |
+
hsv_s: 0.7 # (float) image HSV-Saturation augmentation (fraction)
|
98 |
+
hsv_v: 0.4 # (float) image HSV-Value augmentation (fraction)
|
99 |
+
degrees: 0.0 # (float) image rotation (+/- deg)
|
100 |
+
translate: 0.1 # (float) image translation (+/- fraction)
|
101 |
+
scale: 0.5 # (float) image scale (+/- gain)
|
102 |
+
shear: 0.0 # (float) image shear (+/- deg)
|
103 |
+
perspective: 0.0 # (float) image perspective (+/- fraction), range 0-0.001
|
104 |
+
flipud: 0.0 # (float) image flip up-down (probability)
|
105 |
+
fliplr: 0.5 # (float) image flip left-right (probability)
|
106 |
+
mosaic: 1.0 # (float) image mosaic (probability)
|
107 |
+
mixup: 0.0 # (float) image mixup (probability)
|
108 |
+
copy_paste: 0.0 # (float) segment copy-paste (probability)
|
109 |
+
|
110 |
+
# Custom config.yaml ---------------------------------------------------------------------------------------------------
|
111 |
+
cfg: # (str, optional) for overriding defaults.yaml
|
112 |
+
save_dir: ./runs/train1 # 自己设置路径
|
113 |
+
# Tracker settings ------------------------------------------------------------------------------------------------------
|
114 |
+
tracker: botsort.yaml # (str) tracker type, choices=[botsort.yaml, bytetrack.yaml]
|
ultralytics/cfg/models/v8/yolov8.yaml
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 1 # number of classes
|
6 |
+
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
|
7 |
+
# [depth, width, max_channels]
|
8 |
+
n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
|
9 |
+
s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
|
10 |
+
m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
|
11 |
+
l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
|
12 |
+
x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
|
13 |
+
|
14 |
+
# YOLOv8.0n backbone
|
15 |
+
backbone:
|
16 |
+
# [from, repeats, module, args]
|
17 |
+
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
18 |
+
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
19 |
+
- [-1, 3, C2f, [128, True]]
|
20 |
+
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
21 |
+
- [-1, 6, C2f, [256, True]]
|
22 |
+
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
23 |
+
- [-1, 6, C2f, [512, True]]
|
24 |
+
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
25 |
+
- [-1, 3, C2f, [1024, True]]
|
26 |
+
- [-1, 1, SPPF, [1024, 5]] # 9
|
27 |
+
|
28 |
+
# YOLOv8.0n head
|
29 |
+
head:
|
30 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
31 |
+
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
32 |
+
- [-1, 3, C2f, [512]] # 12
|
33 |
+
|
34 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
35 |
+
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
36 |
+
- [-1, 3, C2f, [256]] # 15 (P3/8-small)
|
37 |
+
|
38 |
+
- [-1, 1, Conv, [256, 3, 2]]
|
39 |
+
- [[-1, 12], 1, Concat, [1]] # cat head P4
|
40 |
+
- [-1, 3, C2f, [512]] # 18 (P4/16-medium)
|
41 |
+
|
42 |
+
- [-1, 1, Conv, [512, 3, 2]]
|
43 |
+
- [[-1, 9], 1, Concat, [1]] # cat head P5
|
44 |
+
- [-1, 3, C2f, [1024]] # 21 (P5/32-large)
|
45 |
+
|
46 |
+
- [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)
|
ultralytics/cfg/models/v8/yolov8_ECA.yaml
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 9 # number of classes
|
6 |
+
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
|
7 |
+
# [depth, width, max_channels]
|
8 |
+
n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
|
9 |
+
s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
|
10 |
+
m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
|
11 |
+
l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
|
12 |
+
x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
|
13 |
+
|
14 |
+
# YOLOv8.0n backbone
|
15 |
+
backbone:
|
16 |
+
# [from, repeats, module, args]
|
17 |
+
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
18 |
+
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
19 |
+
- [-1, 3, C2f, [128, True]]
|
20 |
+
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
21 |
+
- [-1, 6, C2f, [256, True]]
|
22 |
+
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
23 |
+
- [-1, 6, C2f, [512, True]]
|
24 |
+
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
25 |
+
- [-1, 3, C2f, [1024, True]]
|
26 |
+
- [-1, 1, SPPF, [1024, 5]] # 9
|
27 |
+
|
28 |
+
# YOLOv8.0n head
|
29 |
+
head:
|
30 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
31 |
+
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
32 |
+
- [-1, 3, C2f, [512]] # 12
|
33 |
+
- [-1, 1, ECAAttention, [512]]
|
34 |
+
|
35 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
36 |
+
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
37 |
+
- [-1, 3, C2f, [256]] # 16 (P3/8-small)
|
38 |
+
- [-1, 1, ECAAttention, [256]]
|
39 |
+
|
40 |
+
- [-1, 1, Conv, [256, 3, 2]]
|
41 |
+
- [[-1, 12], 1, Concat, [1]] # cat head P4
|
42 |
+
- [-1, 3, C2f, [512]] # 20 (P4/16-medium)
|
43 |
+
- [-1, 1, ECAAttention, [512]]
|
44 |
+
|
45 |
+
- [-1, 1, Conv, [512, 3, 2]]
|
46 |
+
- [[-1, 9], 1, Concat, [1]] # cat head P5
|
47 |
+
- [-1, 3, C2f, [1024]] # 24 (P5/32-large)
|
48 |
+
- [-1, 1, ECAAttention, [1024]]
|
49 |
+
|
50 |
+
- [[17, 21, 25], 1, Detect, [nc]] # Detect(P3, P4, P5)
|
ultralytics/cfg/models/v8/yolov8_GAM.yaml
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 9 # number of classes
|
6 |
+
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
|
7 |
+
# [depth, width, max_channels]
|
8 |
+
n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
|
9 |
+
s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
|
10 |
+
m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
|
11 |
+
l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
|
12 |
+
x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
|
13 |
+
|
14 |
+
# YOLOv8.0n backbone
|
15 |
+
backbone:
|
16 |
+
# [from, repeats, module, args]
|
17 |
+
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
18 |
+
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
19 |
+
- [-1, 3, C2f, [128, True]]
|
20 |
+
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
21 |
+
- [-1, 6, C2f, [256, True]]
|
22 |
+
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
23 |
+
- [-1, 6, C2f, [512, True]]
|
24 |
+
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
25 |
+
- [-1, 3, C2f, [1024, True]]
|
26 |
+
- [-1, 1, SPPF, [1024, 5]] # 9
|
27 |
+
|
28 |
+
# YOLOv8.0n head
|
29 |
+
head:
|
30 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
31 |
+
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
32 |
+
- [-1, 3, C2f, [512]] # 12
|
33 |
+
- [-1, 1, GAM_Attention, [512,512]]
|
34 |
+
|
35 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
36 |
+
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
37 |
+
- [-1, 3, C2f, [256]] # 16 (P3/8-small)
|
38 |
+
- [-1, 1, GAM_Attention, [256,256]]
|
39 |
+
|
40 |
+
- [-1, 1, Conv, [256, 3, 2]]
|
41 |
+
- [[-1, 12], 1, Concat, [1]] # cat head P4
|
42 |
+
- [-1, 3, C2f, [512]] # 20 (P4/16-medium)
|
43 |
+
- [-1, 1, GAM_Attention, [512,512]]
|
44 |
+
|
45 |
+
- [-1, 1, Conv, [512, 3, 2]]
|
46 |
+
- [[-1, 9], 1, Concat, [1]] # cat head P5
|
47 |
+
- [-1, 3, C2f, [1024]] # 24 (P5/32-large)
|
48 |
+
- [-1, 1, GAM_Attention, [1024,1024]]
|
49 |
+
|
50 |
+
- [[17, 21, 25], 1, Detect, [nc]] # Detect(P3, P4, P5)
|
ultralytics/cfg/models/v8/yolov8_ResBlock_CBAM.yaml
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 9 # number of classes
|
6 |
+
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
|
7 |
+
# [depth, width, max_channels]
|
8 |
+
n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
|
9 |
+
s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
|
10 |
+
m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
|
11 |
+
l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
|
12 |
+
x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
|
13 |
+
|
14 |
+
# YOLOv8.0n backbone
|
15 |
+
backbone:
|
16 |
+
# [from, repeats, module, args]
|
17 |
+
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
18 |
+
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
19 |
+
- [-1, 3, C2f, [128, True]]
|
20 |
+
- [-1, 1, GhostConv, [256, 3, 2]] # 3-P3/8
|
21 |
+
- [-1, 6, C2f, [256, True]]
|
22 |
+
- [-1, 1, GhostConv, [512, 3, 2]] # 5-P4/16
|
23 |
+
- [-1, 6, C2f, [512, True]]
|
24 |
+
- [-1, 1, GhostConv, [1024, 3, 2]] # 7-P5/32
|
25 |
+
- [-1, 3, C2f, [1024, True]]
|
26 |
+
- [-1, 1, SPPF, [1024, 5]] # 9
|
27 |
+
|
28 |
+
# YOLOv8.0n head
|
29 |
+
head:
|
30 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
31 |
+
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
32 |
+
- [-1, 3, C2f, [512]] # 12
|
33 |
+
- [-1, 1, ResBlock_CBAM, [512]]
|
34 |
+
|
35 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
36 |
+
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
37 |
+
- [-1, 3, C2f, [256]] # 16 (P3/8-small)
|
38 |
+
- [-1, 1, ResBlock_CBAM, [256]]
|
39 |
+
|
40 |
+
- [-1, 1, Conv, [256, 3, 2]]
|
41 |
+
- [[-1, 12], 1, Concat, [1]] # cat head P4
|
42 |
+
- [-1, 3, C2f, [512]] # 20 (P4/16-medium)
|
43 |
+
- [-1, 1, ResBlock_CBAM, [512]]
|
44 |
+
|
45 |
+
- [-1, 1, Conv, [512, 3, 2]]
|
46 |
+
- [[-1, 9], 1, Concat, [1]] # cat head P5
|
47 |
+
- [-1, 3, C2f, [1024]] # 24 (P5/32-large)
|
48 |
+
- [-1, 1, ResBlock_CBAM, [1024]]
|
49 |
+
|
50 |
+
- [[17, 21, 25], 1, Detect, [nc]] # Detect(P3, P4, P5)
|
ultralytics/cfg/models/v8/yolov8_SA.yaml
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 9 # number of classes
|
6 |
+
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
|
7 |
+
# [depth, width, max_channels]
|
8 |
+
n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
|
9 |
+
s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
|
10 |
+
m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
|
11 |
+
l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
|
12 |
+
x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
|
13 |
+
|
14 |
+
# YOLOv8.0n backbone
|
15 |
+
backbone:
|
16 |
+
# [from, repeats, module, args]
|
17 |
+
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
18 |
+
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
19 |
+
- [-1, 3, C2f, [128, True]]
|
20 |
+
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
21 |
+
- [-1, 6, C2f, [256, True]]
|
22 |
+
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
23 |
+
- [-1, 6, C2f, [512, True]]
|
24 |
+
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
25 |
+
- [-1, 3, C2f, [1024, True]]
|
26 |
+
- [-1, 1, SPPF, [1024, 5]] # 9
|
27 |
+
|
28 |
+
# YOLOv8.0n head
|
29 |
+
head:
|
30 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
31 |
+
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
32 |
+
- [-1, 3, C2f, [512]] # 12
|
33 |
+
- [-1, 1, ShuffleAttention, [512]]
|
34 |
+
|
35 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
36 |
+
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
37 |
+
- [-1, 3, C2f, [256]] # 16 (P3/8-small)
|
38 |
+
- [-1, 1, ShuffleAttention, [256]]
|
39 |
+
|
40 |
+
- [-1, 1, Conv, [256, 3, 2]]
|
41 |
+
- [[-1, 12], 1, Concat, [1]] # cat head P4
|
42 |
+
- [-1, 3, C2f, [512]] # 20 (P4/16-medium)
|
43 |
+
- [-1, 1, ShuffleAttention, [512]]
|
44 |
+
|
45 |
+
- [-1, 1, Conv, [512, 3, 2]]
|
46 |
+
- [[-1, 9], 1, Concat, [1]] # cat head P5
|
47 |
+
- [-1, 3, C2f, [1024]] # 24 (P5/32-large)
|
48 |
+
- [-1, 1, ShuffleAttention, [1024]]
|
49 |
+
|
50 |
+
- [[17, 21, 25], 1, Detect, [nc]] # Detect(P3, P4, P5)
|
ultralytics/cfg/trackers/botsort.yaml
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# Default YOLO tracker settings for BoT-SORT tracker https://github.com/NirAharon/BoT-SORT
|
3 |
+
|
4 |
+
tracker_type: botsort # tracker type, ['botsort', 'bytetrack']
|
5 |
+
track_high_thresh: 0.5 # threshold for the first association
|
6 |
+
track_low_thresh: 0.1 # threshold for the second association
|
7 |
+
new_track_thresh: 0.6 # threshold for init new track if the detection does not match any tracks
|
8 |
+
track_buffer: 30 # buffer to calculate the time when to remove tracks
|
9 |
+
match_thresh: 0.8 # threshold for matching tracks
|
10 |
+
# min_box_area: 10 # threshold for min box areas(for tracker evaluation, not used for now)
|
11 |
+
# mot20: False # for tracker evaluation(not used for now)
|
12 |
+
|
13 |
+
# BoT-SORT settings
|
14 |
+
cmc_method: sparseOptFlow # method of global motion compensation
|
15 |
+
# ReID model related thresh (not supported yet)
|
16 |
+
proximity_thresh: 0.5
|
17 |
+
appearance_thresh: 0.25
|
18 |
+
with_reid: False
|
ultralytics/cfg/trackers/bytetrack.yaml
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# Default YOLO tracker settings for ByteTrack tracker https://github.com/ifzhang/ByteTrack
|
3 |
+
|
4 |
+
tracker_type: bytetrack # tracker type, ['botsort', 'bytetrack']
|
5 |
+
track_high_thresh: 0.5 # threshold for the first association
|
6 |
+
track_low_thresh: 0.1 # threshold for the second association
|
7 |
+
new_track_thresh: 0.6 # threshold for init new track if the detection does not match any tracks
|
8 |
+
track_buffer: 30 # buffer to calculate the time when to remove tracks
|
9 |
+
match_thresh: 0.8 # threshold for matching tracks
|
10 |
+
# min_box_area: 10 # threshold for min box areas(for tracker evaluation, not used for now)
|
11 |
+
# mot20: False # for tracker evaluation(not used for now)
|
ultralytics/data/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
from .base import BaseDataset
|
4 |
+
from .build import build_dataloader, build_yolo_dataset, load_inference_source
|
5 |
+
from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
|
6 |
+
|
7 |
+
__all__ = ('BaseDataset', 'ClassificationDataset', 'SemanticDataset', 'YOLODataset', 'build_yolo_dataset',
|
8 |
+
'build_dataloader', 'load_inference_source')
|
ultralytics/data/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (473 Bytes). View file
|
|
ultralytics/data/__pycache__/augment.cpython-39.pyc
ADDED
Binary file (31.6 kB). View file
|
|
ultralytics/data/__pycache__/base.cpython-39.pyc
ADDED
Binary file (11.3 kB). View file
|
|
ultralytics/data/__pycache__/build.cpython-39.pyc
ADDED
Binary file (6.2 kB). View file
|
|
ultralytics/data/__pycache__/dataset.cpython-39.pyc
ADDED
Binary file (11.3 kB). View file
|
|
ultralytics/data/__pycache__/loaders.cpython-39.pyc
ADDED
Binary file (15.7 kB). View file
|
|
ultralytics/data/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (24.1 kB). View file
|
|
ultralytics/data/annotator.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
from ultralytics import SAM, YOLO
|
4 |
+
|
5 |
+
|
6 |
+
def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir=None):
|
7 |
+
"""
|
8 |
+
Automatically annotates images using a YOLO object detection model and a SAM segmentation model.
|
9 |
+
Args:
|
10 |
+
data (str): Path to a folder containing images to be annotated.
|
11 |
+
det_model (str, optional): Pre-trained YOLO detection model. Defaults to 'yolov8x.pt'.
|
12 |
+
sam_model (str, optional): Pre-trained SAM segmentation model. Defaults to 'sam_b.pt'.
|
13 |
+
device (str, optional): Device to run the models on. Defaults to an empty string (CPU or GPU, if available).
|
14 |
+
output_dir (str | None | optional): Directory to save the annotated results.
|
15 |
+
Defaults to a 'labels' folder in the same directory as 'data'.
|
16 |
+
"""
|
17 |
+
det_model = YOLO(det_model)
|
18 |
+
sam_model = SAM(sam_model)
|
19 |
+
|
20 |
+
if not output_dir:
|
21 |
+
output_dir = Path(str(data)).parent / 'labels'
|
22 |
+
Path(output_dir).mkdir(exist_ok=True, parents=True)
|
23 |
+
|
24 |
+
det_results = det_model(data, stream=True, device=device)
|
25 |
+
|
26 |
+
for result in det_results:
|
27 |
+
boxes = result.boxes.xyxy # Boxes object for bbox outputs
|
28 |
+
class_ids = result.boxes.cls.int().tolist() # noqa
|
29 |
+
if len(class_ids):
|
30 |
+
sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device=device)
|
31 |
+
segments = sam_results[0].masks.xyn # noqa
|
32 |
+
|
33 |
+
with open(str(Path(output_dir) / Path(result.path).stem) + '.txt', 'w') as f:
|
34 |
+
for i in range(len(segments)):
|
35 |
+
s = segments[i]
|
36 |
+
if len(s) == 0:
|
37 |
+
continue
|
38 |
+
segment = map(str, segments[i].reshape(-1).tolist())
|
39 |
+
f.write(f'{class_ids[i]} ' + ' '.join(segment) + '\n')
|
ultralytics/data/augment.py
ADDED
@@ -0,0 +1,906 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
import math
|
4 |
+
import random
|
5 |
+
from copy import deepcopy
|
6 |
+
|
7 |
+
import cv2
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torchvision.transforms as T
|
11 |
+
|
12 |
+
from ultralytics.utils import LOGGER, colorstr
|
13 |
+
from ultralytics.utils.checks import check_version
|
14 |
+
from ultralytics.utils.instance import Instances
|
15 |
+
from ultralytics.utils.metrics import bbox_ioa
|
16 |
+
from ultralytics.utils.ops import segment2box
|
17 |
+
|
18 |
+
from .utils import polygons2masks, polygons2masks_overlap
|
19 |
+
|
20 |
+
POSE_FLIPLR_INDEX = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
|
21 |
+
|
22 |
+
|
23 |
+
# TODO: we might need a BaseTransform to make all these augments be compatible with both classification and semantic
|
24 |
+
class BaseTransform:
|
25 |
+
|
26 |
+
def __init__(self) -> None:
|
27 |
+
pass
|
28 |
+
|
29 |
+
def apply_image(self, labels):
|
30 |
+
"""Applies image transformation to labels."""
|
31 |
+
pass
|
32 |
+
|
33 |
+
def apply_instances(self, labels):
|
34 |
+
"""Applies transformations to input 'labels' and returns object instances."""
|
35 |
+
pass
|
36 |
+
|
37 |
+
def apply_semantic(self, labels):
|
38 |
+
"""Applies semantic segmentation to an image."""
|
39 |
+
pass
|
40 |
+
|
41 |
+
def __call__(self, labels):
|
42 |
+
"""Applies label transformations to an image, instances and semantic masks."""
|
43 |
+
self.apply_image(labels)
|
44 |
+
self.apply_instances(labels)
|
45 |
+
self.apply_semantic(labels)
|
46 |
+
|
47 |
+
|
48 |
+
class Compose:
|
49 |
+
|
50 |
+
def __init__(self, transforms):
|
51 |
+
"""Initializes the Compose object with a list of transforms."""
|
52 |
+
self.transforms = transforms
|
53 |
+
|
54 |
+
def __call__(self, data):
|
55 |
+
"""Applies a series of transformations to input data."""
|
56 |
+
for t in self.transforms:
|
57 |
+
data = t(data)
|
58 |
+
return data
|
59 |
+
|
60 |
+
def append(self, transform):
|
61 |
+
"""Appends a new transform to the existing list of transforms."""
|
62 |
+
self.transforms.append(transform)
|
63 |
+
|
64 |
+
def tolist(self):
|
65 |
+
"""Converts list of transforms to a standard Python list."""
|
66 |
+
return self.transforms
|
67 |
+
|
68 |
+
def __repr__(self):
|
69 |
+
"""Return string representation of object."""
|
70 |
+
format_string = f'{self.__class__.__name__}('
|
71 |
+
for t in self.transforms:
|
72 |
+
format_string += '\n'
|
73 |
+
format_string += f' {t}'
|
74 |
+
format_string += '\n)'
|
75 |
+
return format_string
|
76 |
+
|
77 |
+
|
78 |
+
class BaseMixTransform:
|
79 |
+
"""This implementation is from mmyolo."""
|
80 |
+
|
81 |
+
def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
|
82 |
+
self.dataset = dataset
|
83 |
+
self.pre_transform = pre_transform
|
84 |
+
self.p = p
|
85 |
+
|
86 |
+
def __call__(self, labels):
|
87 |
+
"""Applies pre-processing transforms and mixup/mosaic transforms to labels data."""
|
88 |
+
if random.uniform(0, 1) > self.p:
|
89 |
+
return labels
|
90 |
+
|
91 |
+
# Get index of one or three other images
|
92 |
+
indexes = self.get_indexes()
|
93 |
+
if isinstance(indexes, int):
|
94 |
+
indexes = [indexes]
|
95 |
+
|
96 |
+
# Get images information will be used for Mosaic or MixUp
|
97 |
+
mix_labels = [self.dataset.get_image_and_label(i) for i in indexes]
|
98 |
+
|
99 |
+
if self.pre_transform is not None:
|
100 |
+
for i, data in enumerate(mix_labels):
|
101 |
+
mix_labels[i] = self.pre_transform(data)
|
102 |
+
labels['mix_labels'] = mix_labels
|
103 |
+
|
104 |
+
# Mosaic or MixUp
|
105 |
+
labels = self._mix_transform(labels)
|
106 |
+
labels.pop('mix_labels', None)
|
107 |
+
return labels
|
108 |
+
|
109 |
+
def _mix_transform(self, labels):
|
110 |
+
"""Applies MixUp or Mosaic augmentation to the label dictionary."""
|
111 |
+
raise NotImplementedError
|
112 |
+
|
113 |
+
def get_indexes(self):
|
114 |
+
"""Gets a list of shuffled indexes for mosaic augmentation."""
|
115 |
+
raise NotImplementedError
|
116 |
+
|
117 |
+
|
118 |
+
class Mosaic(BaseMixTransform):
|
119 |
+
"""
|
120 |
+
Mosaic augmentation.
|
121 |
+
|
122 |
+
This class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image.
|
123 |
+
The augmentation is applied to a dataset with a given probability.
|
124 |
+
|
125 |
+
Attributes:
|
126 |
+
dataset: The dataset on which the mosaic augmentation is applied.
|
127 |
+
imgsz (int, optional): Image size (height and width) after mosaic pipeline of a single image. Default to 640.
|
128 |
+
p (float, optional): Probability of applying the mosaic augmentation. Must be in the range 0-1. Default to 1.0.
|
129 |
+
n (int, optional): The grid size, either 4 (for 2x2) or 9 (for 3x3).
|
130 |
+
"""
|
131 |
+
|
132 |
+
def __init__(self, dataset, imgsz=640, p=1.0, n=4):
|
133 |
+
"""Initializes the object with a dataset, image size, probability, and border."""
|
134 |
+
assert 0 <= p <= 1.0, f'The probability should be in range [0, 1], but got {p}.'
|
135 |
+
assert n in (4, 9), 'grid must be equal to 4 or 9.'
|
136 |
+
super().__init__(dataset=dataset, p=p)
|
137 |
+
self.dataset = dataset
|
138 |
+
self.imgsz = imgsz
|
139 |
+
self.border = (-imgsz // 2, -imgsz // 2) # width, height
|
140 |
+
self.n = n
|
141 |
+
|
142 |
+
def get_indexes(self, buffer=True):
|
143 |
+
"""Return a list of random indexes from the dataset."""
|
144 |
+
if buffer: # select images from buffer
|
145 |
+
return random.choices(list(self.dataset.buffer), k=self.n - 1)
|
146 |
+
else: # select any images
|
147 |
+
return [random.randint(0, len(self.dataset) - 1) for _ in range(self.n - 1)]
|
148 |
+
|
149 |
+
def _mix_transform(self, labels):
|
150 |
+
"""Apply mixup transformation to the input image and labels."""
|
151 |
+
assert labels.get('rect_shape', None) is None, 'rect and mosaic are mutually exclusive.'
|
152 |
+
assert len(labels.get('mix_labels', [])), 'There are no other images for mosaic augment.'
|
153 |
+
return self._mosaic4(labels) if self.n == 4 else self._mosaic9(labels)
|
154 |
+
|
155 |
+
def _mosaic4(self, labels):
|
156 |
+
"""Create a 2x2 image mosaic."""
|
157 |
+
mosaic_labels = []
|
158 |
+
s = self.imgsz
|
159 |
+
yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border) # mosaic center x, y
|
160 |
+
for i in range(4):
|
161 |
+
labels_patch = labels if i == 0 else labels['mix_labels'][i - 1]
|
162 |
+
# Load image
|
163 |
+
img = labels_patch['img']
|
164 |
+
h, w = labels_patch.pop('resized_shape')
|
165 |
+
|
166 |
+
# Place img in img4
|
167 |
+
if i == 0: # top left
|
168 |
+
img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
|
169 |
+
x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
|
170 |
+
x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
|
171 |
+
elif i == 1: # top right
|
172 |
+
x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
|
173 |
+
x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
|
174 |
+
elif i == 2: # bottom left
|
175 |
+
x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
|
176 |
+
x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
|
177 |
+
elif i == 3: # bottom right
|
178 |
+
x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
|
179 |
+
x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
|
180 |
+
|
181 |
+
img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
|
182 |
+
padw = x1a - x1b
|
183 |
+
padh = y1a - y1b
|
184 |
+
|
185 |
+
labels_patch = self._update_labels(labels_patch, padw, padh)
|
186 |
+
mosaic_labels.append(labels_patch)
|
187 |
+
final_labels = self._cat_labels(mosaic_labels)
|
188 |
+
final_labels['img'] = img4
|
189 |
+
return final_labels
|
190 |
+
|
191 |
+
def _mosaic9(self, labels):
|
192 |
+
"""Create a 3x3 image mosaic."""
|
193 |
+
mosaic_labels = []
|
194 |
+
s = self.imgsz
|
195 |
+
hp, wp = -1, -1 # height, width previous
|
196 |
+
for i in range(9):
|
197 |
+
labels_patch = labels if i == 0 else labels['mix_labels'][i - 1]
|
198 |
+
# Load image
|
199 |
+
img = labels_patch['img']
|
200 |
+
h, w = labels_patch.pop('resized_shape')
|
201 |
+
|
202 |
+
# Place img in img9
|
203 |
+
if i == 0: # center
|
204 |
+
img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
|
205 |
+
h0, w0 = h, w
|
206 |
+
c = s, s, s + w, s + h # xmin, ymin, xmax, ymax (base) coordinates
|
207 |
+
elif i == 1: # top
|
208 |
+
c = s, s - h, s + w, s
|
209 |
+
elif i == 2: # top right
|
210 |
+
c = s + wp, s - h, s + wp + w, s
|
211 |
+
elif i == 3: # right
|
212 |
+
c = s + w0, s, s + w0 + w, s + h
|
213 |
+
elif i == 4: # bottom right
|
214 |
+
c = s + w0, s + hp, s + w0 + w, s + hp + h
|
215 |
+
elif i == 5: # bottom
|
216 |
+
c = s + w0 - w, s + h0, s + w0, s + h0 + h
|
217 |
+
elif i == 6: # bottom left
|
218 |
+
c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
|
219 |
+
elif i == 7: # left
|
220 |
+
c = s - w, s + h0 - h, s, s + h0
|
221 |
+
elif i == 8: # top left
|
222 |
+
c = s - w, s + h0 - hp - h, s, s + h0 - hp
|
223 |
+
|
224 |
+
padw, padh = c[:2]
|
225 |
+
x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords
|
226 |
+
|
227 |
+
# Image
|
228 |
+
img9[y1:y2, x1:x2] = img[y1 - padh:, x1 - padw:] # img9[ymin:ymax, xmin:xmax]
|
229 |
+
hp, wp = h, w # height, width previous for next iteration
|
230 |
+
|
231 |
+
# Labels assuming imgsz*2 mosaic size
|
232 |
+
labels_patch = self._update_labels(labels_patch, padw + self.border[0], padh + self.border[1])
|
233 |
+
mosaic_labels.append(labels_patch)
|
234 |
+
final_labels = self._cat_labels(mosaic_labels)
|
235 |
+
|
236 |
+
final_labels['img'] = img9[-self.border[0]:self.border[0], -self.border[1]:self.border[1]]
|
237 |
+
return final_labels
|
238 |
+
|
239 |
+
@staticmethod
|
240 |
+
def _update_labels(labels, padw, padh):
|
241 |
+
"""Update labels."""
|
242 |
+
nh, nw = labels['img'].shape[:2]
|
243 |
+
labels['instances'].convert_bbox(format='xyxy')
|
244 |
+
labels['instances'].denormalize(nw, nh)
|
245 |
+
labels['instances'].add_padding(padw, padh)
|
246 |
+
return labels
|
247 |
+
|
248 |
+
def _cat_labels(self, mosaic_labels):
|
249 |
+
"""Return labels with mosaic border instances clipped."""
|
250 |
+
if len(mosaic_labels) == 0:
|
251 |
+
return {}
|
252 |
+
cls = []
|
253 |
+
instances = []
|
254 |
+
imgsz = self.imgsz * 2 # mosaic imgsz
|
255 |
+
for labels in mosaic_labels:
|
256 |
+
cls.append(labels['cls'])
|
257 |
+
instances.append(labels['instances'])
|
258 |
+
final_labels = {
|
259 |
+
'im_file': mosaic_labels[0]['im_file'],
|
260 |
+
'ori_shape': mosaic_labels[0]['ori_shape'],
|
261 |
+
'resized_shape': (imgsz, imgsz),
|
262 |
+
'cls': np.concatenate(cls, 0),
|
263 |
+
'instances': Instances.concatenate(instances, axis=0),
|
264 |
+
'mosaic_border': self.border} # final_labels
|
265 |
+
final_labels['instances'].clip(imgsz, imgsz)
|
266 |
+
good = final_labels['instances'].remove_zero_area_boxes()
|
267 |
+
final_labels['cls'] = final_labels['cls'][good]
|
268 |
+
return final_labels
|
269 |
+
|
270 |
+
|
271 |
+
class MixUp(BaseMixTransform):
|
272 |
+
|
273 |
+
def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
|
274 |
+
super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)
|
275 |
+
|
276 |
+
def get_indexes(self):
|
277 |
+
"""Get a random index from the dataset."""
|
278 |
+
return random.randint(0, len(self.dataset) - 1)
|
279 |
+
|
280 |
+
def _mix_transform(self, labels):
|
281 |
+
"""Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf."""
|
282 |
+
r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
|
283 |
+
labels2 = labels['mix_labels'][0]
|
284 |
+
labels['img'] = (labels['img'] * r + labels2['img'] * (1 - r)).astype(np.uint8)
|
285 |
+
labels['instances'] = Instances.concatenate([labels['instances'], labels2['instances']], axis=0)
|
286 |
+
labels['cls'] = np.concatenate([labels['cls'], labels2['cls']], 0)
|
287 |
+
return labels
|
288 |
+
|
289 |
+
|
290 |
+
class RandomPerspective:
|
291 |
+
|
292 |
+
def __init__(self,
|
293 |
+
degrees=0.0,
|
294 |
+
translate=0.1,
|
295 |
+
scale=0.5,
|
296 |
+
shear=0.0,
|
297 |
+
perspective=0.0,
|
298 |
+
border=(0, 0),
|
299 |
+
pre_transform=None):
|
300 |
+
self.degrees = degrees
|
301 |
+
self.translate = translate
|
302 |
+
self.scale = scale
|
303 |
+
self.shear = shear
|
304 |
+
self.perspective = perspective
|
305 |
+
# Mosaic border
|
306 |
+
self.border = border
|
307 |
+
self.pre_transform = pre_transform
|
308 |
+
|
309 |
+
def affine_transform(self, img, border):
|
310 |
+
"""Center."""
|
311 |
+
C = np.eye(3, dtype=np.float32)
|
312 |
+
|
313 |
+
C[0, 2] = -img.shape[1] / 2 # x translation (pixels)
|
314 |
+
C[1, 2] = -img.shape[0] / 2 # y translation (pixels)
|
315 |
+
|
316 |
+
# Perspective
|
317 |
+
P = np.eye(3, dtype=np.float32)
|
318 |
+
P[2, 0] = random.uniform(-self.perspective, self.perspective) # x perspective (about y)
|
319 |
+
P[2, 1] = random.uniform(-self.perspective, self.perspective) # y perspective (about x)
|
320 |
+
|
321 |
+
# Rotation and Scale
|
322 |
+
R = np.eye(3, dtype=np.float32)
|
323 |
+
a = random.uniform(-self.degrees, self.degrees)
|
324 |
+
# a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations
|
325 |
+
s = random.uniform(1 - self.scale, 1 + self.scale)
|
326 |
+
# s = 2 ** random.uniform(-scale, scale)
|
327 |
+
R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
|
328 |
+
|
329 |
+
# Shear
|
330 |
+
S = np.eye(3, dtype=np.float32)
|
331 |
+
S[0, 1] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180) # x shear (deg)
|
332 |
+
S[1, 0] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180) # y shear (deg)
|
333 |
+
|
334 |
+
# Translation
|
335 |
+
T = np.eye(3, dtype=np.float32)
|
336 |
+
T[0, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[0] # x translation (pixels)
|
337 |
+
T[1, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[1] # y translation (pixels)
|
338 |
+
|
339 |
+
# Combined rotation matrix
|
340 |
+
M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT
|
341 |
+
# Affine image
|
342 |
+
if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed
|
343 |
+
if self.perspective:
|
344 |
+
img = cv2.warpPerspective(img, M, dsize=self.size, borderValue=(114, 114, 114))
|
345 |
+
else: # affine
|
346 |
+
img = cv2.warpAffine(img, M[:2], dsize=self.size, borderValue=(114, 114, 114))
|
347 |
+
return img, M, s
|
348 |
+
|
349 |
+
def apply_bboxes(self, bboxes, M):
|
350 |
+
"""
|
351 |
+
Apply affine to bboxes only.
|
352 |
+
|
353 |
+
Args:
|
354 |
+
bboxes (ndarray): list of bboxes, xyxy format, with shape (num_bboxes, 4).
|
355 |
+
M (ndarray): affine matrix.
|
356 |
+
|
357 |
+
Returns:
|
358 |
+
new_bboxes (ndarray): bboxes after affine, [num_bboxes, 4].
|
359 |
+
"""
|
360 |
+
n = len(bboxes)
|
361 |
+
if n == 0:
|
362 |
+
return bboxes
|
363 |
+
|
364 |
+
xy = np.ones((n * 4, 3), dtype=bboxes.dtype)
|
365 |
+
xy[:, :2] = bboxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1
|
366 |
+
xy = xy @ M.T # transform
|
367 |
+
xy = (xy[:, :2] / xy[:, 2:3] if self.perspective else xy[:, :2]).reshape(n, 8) # perspective rescale or affine
|
368 |
+
|
369 |
+
# Create new boxes
|
370 |
+
x = xy[:, [0, 2, 4, 6]]
|
371 |
+
y = xy[:, [1, 3, 5, 7]]
|
372 |
+
return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1)), dtype=bboxes.dtype).reshape(4, n).T
|
373 |
+
|
374 |
+
def apply_segments(self, segments, M):
|
375 |
+
"""
|
376 |
+
Apply affine to segments and generate new bboxes from segments.
|
377 |
+
|
378 |
+
Args:
|
379 |
+
segments (ndarray): list of segments, [num_samples, 500, 2].
|
380 |
+
M (ndarray): affine matrix.
|
381 |
+
|
382 |
+
Returns:
|
383 |
+
new_segments (ndarray): list of segments after affine, [num_samples, 500, 2].
|
384 |
+
new_bboxes (ndarray): bboxes after affine, [N, 4].
|
385 |
+
"""
|
386 |
+
n, num = segments.shape[:2]
|
387 |
+
if n == 0:
|
388 |
+
return [], segments
|
389 |
+
|
390 |
+
xy = np.ones((n * num, 3), dtype=segments.dtype)
|
391 |
+
segments = segments.reshape(-1, 2)
|
392 |
+
xy[:, :2] = segments
|
393 |
+
xy = xy @ M.T # transform
|
394 |
+
xy = xy[:, :2] / xy[:, 2:3]
|
395 |
+
segments = xy.reshape(n, -1, 2)
|
396 |
+
bboxes = np.stack([segment2box(xy, self.size[0], self.size[1]) for xy in segments], 0)
|
397 |
+
return bboxes, segments
|
398 |
+
|
399 |
+
def apply_keypoints(self, keypoints, M):
|
400 |
+
"""
|
401 |
+
Apply affine to keypoints.
|
402 |
+
|
403 |
+
Args:
|
404 |
+
keypoints (ndarray): keypoints, [N, 17, 3].
|
405 |
+
M (ndarray): affine matrix.
|
406 |
+
|
407 |
+
Return:
|
408 |
+
new_keypoints (ndarray): keypoints after affine, [N, 17, 3].
|
409 |
+
"""
|
410 |
+
n, nkpt = keypoints.shape[:2]
|
411 |
+
if n == 0:
|
412 |
+
return keypoints
|
413 |
+
xy = np.ones((n * nkpt, 3), dtype=keypoints.dtype)
|
414 |
+
visible = keypoints[..., 2].reshape(n * nkpt, 1)
|
415 |
+
xy[:, :2] = keypoints[..., :2].reshape(n * nkpt, 2)
|
416 |
+
xy = xy @ M.T # transform
|
417 |
+
xy = xy[:, :2] / xy[:, 2:3] # perspective rescale or affine
|
418 |
+
out_mask = (xy[:, 0] < 0) | (xy[:, 1] < 0) | (xy[:, 0] > self.size[0]) | (xy[:, 1] > self.size[1])
|
419 |
+
visible[out_mask] = 0
|
420 |
+
return np.concatenate([xy, visible], axis=-1).reshape(n, nkpt, 3)
|
421 |
+
|
422 |
+
def __call__(self, labels):
|
423 |
+
"""
|
424 |
+
Affine images and targets.
|
425 |
+
|
426 |
+
Args:
|
427 |
+
labels (dict): a dict of `bboxes`, `segments`, `keypoints`.
|
428 |
+
"""
|
429 |
+
if self.pre_transform and 'mosaic_border' not in labels:
|
430 |
+
labels = self.pre_transform(labels)
|
431 |
+
labels.pop('ratio_pad', None) # do not need ratio pad
|
432 |
+
|
433 |
+
img = labels['img']
|
434 |
+
cls = labels['cls']
|
435 |
+
instances = labels.pop('instances')
|
436 |
+
# Make sure the coord formats are right
|
437 |
+
instances.convert_bbox(format='xyxy')
|
438 |
+
instances.denormalize(*img.shape[:2][::-1])
|
439 |
+
|
440 |
+
border = labels.pop('mosaic_border', self.border)
|
441 |
+
self.size = img.shape[1] + border[1] * 2, img.shape[0] + border[0] * 2 # w, h
|
442 |
+
# M is affine matrix
|
443 |
+
# scale for func:`box_candidates`
|
444 |
+
img, M, scale = self.affine_transform(img, border)
|
445 |
+
|
446 |
+
bboxes = self.apply_bboxes(instances.bboxes, M)
|
447 |
+
|
448 |
+
segments = instances.segments
|
449 |
+
keypoints = instances.keypoints
|
450 |
+
# Update bboxes if there are segments.
|
451 |
+
if len(segments):
|
452 |
+
bboxes, segments = self.apply_segments(segments, M)
|
453 |
+
|
454 |
+
if keypoints is not None:
|
455 |
+
keypoints = self.apply_keypoints(keypoints, M)
|
456 |
+
new_instances = Instances(bboxes, segments, keypoints, bbox_format='xyxy', normalized=False)
|
457 |
+
# Clip
|
458 |
+
new_instances.clip(*self.size)
|
459 |
+
|
460 |
+
# Filter instances
|
461 |
+
instances.scale(scale_w=scale, scale_h=scale, bbox_only=True)
|
462 |
+
# Make the bboxes have the same scale with new_bboxes
|
463 |
+
i = self.box_candidates(box1=instances.bboxes.T,
|
464 |
+
box2=new_instances.bboxes.T,
|
465 |
+
area_thr=0.01 if len(segments) else 0.10)
|
466 |
+
labels['instances'] = new_instances[i]
|
467 |
+
labels['cls'] = cls[i]
|
468 |
+
labels['img'] = img
|
469 |
+
labels['resized_shape'] = img.shape[:2]
|
470 |
+
return labels
|
471 |
+
|
472 |
+
def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
|
473 |
+
# Compute box candidates: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
|
474 |
+
w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
|
475 |
+
w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
|
476 |
+
ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
|
477 |
+
return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
|
478 |
+
|
479 |
+
|
480 |
+
class RandomHSV:
|
481 |
+
|
482 |
+
def __init__(self, hgain=0.5, sgain=0.5, vgain=0.5) -> None:
|
483 |
+
self.hgain = hgain
|
484 |
+
self.sgain = sgain
|
485 |
+
self.vgain = vgain
|
486 |
+
|
487 |
+
def __call__(self, labels):
|
488 |
+
"""Applies random horizontal or vertical flip to an image with a given probability."""
|
489 |
+
img = labels['img']
|
490 |
+
if self.hgain or self.sgain or self.vgain:
|
491 |
+
r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1 # random gains
|
492 |
+
hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
|
493 |
+
dtype = img.dtype # uint8
|
494 |
+
|
495 |
+
x = np.arange(0, 256, dtype=r.dtype)
|
496 |
+
lut_hue = ((x * r[0]) % 180).astype(dtype)
|
497 |
+
lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
|
498 |
+
lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
|
499 |
+
|
500 |
+
im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
|
501 |
+
cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=img) # no return needed
|
502 |
+
return labels
|
503 |
+
|
504 |
+
|
505 |
+
class RandomFlip:
|
506 |
+
|
507 |
+
def __init__(self, p=0.5, direction='horizontal', flip_idx=None) -> None:
|
508 |
+
assert direction in ['horizontal', 'vertical'], f'Support direction `horizontal` or `vertical`, got {direction}'
|
509 |
+
assert 0 <= p <= 1.0
|
510 |
+
|
511 |
+
self.p = p
|
512 |
+
self.direction = direction
|
513 |
+
self.flip_idx = flip_idx
|
514 |
+
|
515 |
+
def __call__(self, labels):
|
516 |
+
"""Resize image and padding for detection, instance segmentation, pose."""
|
517 |
+
img = labels['img']
|
518 |
+
instances = labels.pop('instances')
|
519 |
+
instances.convert_bbox(format='xywh')
|
520 |
+
h, w = img.shape[:2]
|
521 |
+
h = 1 if instances.normalized else h
|
522 |
+
w = 1 if instances.normalized else w
|
523 |
+
|
524 |
+
# Flip up-down
|
525 |
+
if self.direction == 'vertical' and random.random() < self.p:
|
526 |
+
img = np.flipud(img)
|
527 |
+
instances.flipud(h)
|
528 |
+
if self.direction == 'horizontal' and random.random() < self.p:
|
529 |
+
img = np.fliplr(img)
|
530 |
+
instances.fliplr(w)
|
531 |
+
# For keypoints
|
532 |
+
if self.flip_idx is not None and instances.keypoints is not None:
|
533 |
+
instances.keypoints = np.ascontiguousarray(instances.keypoints[:, self.flip_idx, :])
|
534 |
+
labels['img'] = np.ascontiguousarray(img)
|
535 |
+
labels['instances'] = instances
|
536 |
+
return labels
|
537 |
+
|
538 |
+
|
539 |
+
class LetterBox:
|
540 |
+
"""Resize image and padding for detection, instance segmentation, pose."""
|
541 |
+
|
542 |
+
def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, center=True, stride=32):
|
543 |
+
"""Initialize LetterBox object with specific parameters."""
|
544 |
+
self.new_shape = new_shape
|
545 |
+
self.auto = auto
|
546 |
+
self.scaleFill = scaleFill
|
547 |
+
self.scaleup = scaleup
|
548 |
+
self.stride = stride
|
549 |
+
self.center = center # Put the image in the middle or top-left
|
550 |
+
|
551 |
+
def __call__(self, labels=None, image=None):
|
552 |
+
"""Return updated labels and image with added border."""
|
553 |
+
if labels is None:
|
554 |
+
labels = {}
|
555 |
+
img = labels.get('img') if image is None else image
|
556 |
+
shape = img.shape[:2] # current shape [height, width]
|
557 |
+
new_shape = labels.pop('rect_shape', self.new_shape)
|
558 |
+
if isinstance(new_shape, int):
|
559 |
+
new_shape = (new_shape, new_shape)
|
560 |
+
|
561 |
+
# Scale ratio (new / old)
|
562 |
+
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
|
563 |
+
if not self.scaleup: # only scale down, do not scale up (for better val mAP)
|
564 |
+
r = min(r, 1.0)
|
565 |
+
|
566 |
+
# Compute padding
|
567 |
+
ratio = r, r # width, height ratios
|
568 |
+
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
|
569 |
+
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
|
570 |
+
if self.auto: # minimum rectangle
|
571 |
+
dw, dh = np.mod(dw, self.stride), np.mod(dh, self.stride) # wh padding
|
572 |
+
elif self.scaleFill: # stretch
|
573 |
+
dw, dh = 0.0, 0.0
|
574 |
+
new_unpad = (new_shape[1], new_shape[0])
|
575 |
+
ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
|
576 |
+
|
577 |
+
if self.center:
|
578 |
+
dw /= 2 # divide padding into 2 sides
|
579 |
+
dh /= 2
|
580 |
+
if labels.get('ratio_pad'):
|
581 |
+
labels['ratio_pad'] = (labels['ratio_pad'], (dw, dh)) # for evaluation
|
582 |
+
|
583 |
+
if shape[::-1] != new_unpad: # resize
|
584 |
+
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
|
585 |
+
top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
|
586 |
+
left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
|
587 |
+
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT,
|
588 |
+
value=(114, 114, 114)) # add border
|
589 |
+
|
590 |
+
if len(labels):
|
591 |
+
labels = self._update_labels(labels, ratio, dw, dh)
|
592 |
+
labels['img'] = img
|
593 |
+
labels['resized_shape'] = new_shape
|
594 |
+
return labels
|
595 |
+
else:
|
596 |
+
return img
|
597 |
+
|
598 |
+
def _update_labels(self, labels, ratio, padw, padh):
|
599 |
+
"""Update labels."""
|
600 |
+
labels['instances'].convert_bbox(format='xyxy')
|
601 |
+
labels['instances'].denormalize(*labels['img'].shape[:2][::-1])
|
602 |
+
labels['instances'].scale(*ratio)
|
603 |
+
labels['instances'].add_padding(padw, padh)
|
604 |
+
return labels
|
605 |
+
|
606 |
+
|
607 |
+
class CopyPaste:
|
608 |
+
|
609 |
+
def __init__(self, p=0.5) -> None:
|
610 |
+
self.p = p
|
611 |
+
|
612 |
+
def __call__(self, labels):
|
613 |
+
"""Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)."""
|
614 |
+
im = labels['img']
|
615 |
+
cls = labels['cls']
|
616 |
+
h, w = im.shape[:2]
|
617 |
+
instances = labels.pop('instances')
|
618 |
+
instances.convert_bbox(format='xyxy')
|
619 |
+
instances.denormalize(w, h)
|
620 |
+
if self.p and len(instances.segments):
|
621 |
+
n = len(instances)
|
622 |
+
_, w, _ = im.shape # height, width, channels
|
623 |
+
im_new = np.zeros(im.shape, np.uint8)
|
624 |
+
|
625 |
+
# Calculate ioa first then select indexes randomly
|
626 |
+
ins_flip = deepcopy(instances)
|
627 |
+
ins_flip.fliplr(w)
|
628 |
+
|
629 |
+
ioa = bbox_ioa(ins_flip.bboxes, instances.bboxes) # intersection over area, (N, M)
|
630 |
+
indexes = np.nonzero((ioa < 0.30).all(1))[0] # (N, )
|
631 |
+
n = len(indexes)
|
632 |
+
for j in random.sample(list(indexes), k=round(self.p * n)):
|
633 |
+
cls = np.concatenate((cls, cls[[j]]), axis=0)
|
634 |
+
instances = Instances.concatenate((instances, ins_flip[[j]]), axis=0)
|
635 |
+
cv2.drawContours(im_new, instances.segments[[j]].astype(np.int32), -1, (1, 1, 1), cv2.FILLED)
|
636 |
+
|
637 |
+
result = cv2.flip(im, 1) # augment segments (flip left-right)
|
638 |
+
i = cv2.flip(im_new, 1).astype(bool)
|
639 |
+
im[i] = result[i] # cv2.imwrite('debug.jpg', im) # debug
|
640 |
+
|
641 |
+
labels['img'] = im
|
642 |
+
labels['cls'] = cls
|
643 |
+
labels['instances'] = instances
|
644 |
+
return labels
|
645 |
+
|
646 |
+
|
647 |
+
class Albumentations:
|
648 |
+
"""YOLOv8 Albumentations class (optional, only used if package is installed)"""
|
649 |
+
|
650 |
+
def __init__(self, p=1.0):
|
651 |
+
"""Initialize the transform object for YOLO bbox formatted params."""
|
652 |
+
self.p = p
|
653 |
+
self.transform = None
|
654 |
+
prefix = colorstr('albumentations: ')
|
655 |
+
try:
|
656 |
+
import albumentations as A
|
657 |
+
|
658 |
+
check_version(A.__version__, '1.0.3', hard=True) # version requirement
|
659 |
+
|
660 |
+
T = [
|
661 |
+
A.Blur(p=0.01),
|
662 |
+
A.MedianBlur(p=0.01),
|
663 |
+
A.ToGray(p=0.01),
|
664 |
+
A.CLAHE(p=0.01),
|
665 |
+
A.RandomBrightnessContrast(p=0.0),
|
666 |
+
A.RandomGamma(p=0.0),
|
667 |
+
A.ImageCompression(quality_lower=75, p=0.0)] # transforms
|
668 |
+
self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
|
669 |
+
|
670 |
+
LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
|
671 |
+
except ImportError: # package not installed, skip
|
672 |
+
pass
|
673 |
+
except Exception as e:
|
674 |
+
LOGGER.info(f'{prefix}{e}')
|
675 |
+
|
676 |
+
def __call__(self, labels):
|
677 |
+
"""Generates object detections and returns a dictionary with detection results."""
|
678 |
+
im = labels['img']
|
679 |
+
cls = labels['cls']
|
680 |
+
if len(cls):
|
681 |
+
labels['instances'].convert_bbox('xywh')
|
682 |
+
labels['instances'].normalize(*im.shape[:2][::-1])
|
683 |
+
bboxes = labels['instances'].bboxes
|
684 |
+
# TODO: add supports of segments and keypoints
|
685 |
+
if self.transform and random.random() < self.p:
|
686 |
+
new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed
|
687 |
+
if len(new['class_labels']) > 0: # skip update if no bbox in new im
|
688 |
+
labels['img'] = new['image']
|
689 |
+
labels['cls'] = np.array(new['class_labels'])
|
690 |
+
bboxes = np.array(new['bboxes'], dtype=np.float32)
|
691 |
+
labels['instances'].update(bboxes=bboxes)
|
692 |
+
return labels
|
693 |
+
|
694 |
+
|
695 |
+
# TODO: technically this is not an augmentation, maybe we should put this to another files
|
696 |
+
class Format:
|
697 |
+
|
698 |
+
def __init__(self,
|
699 |
+
bbox_format='xywh',
|
700 |
+
normalize=True,
|
701 |
+
return_mask=False,
|
702 |
+
return_keypoint=False,
|
703 |
+
mask_ratio=4,
|
704 |
+
mask_overlap=True,
|
705 |
+
batch_idx=True):
|
706 |
+
self.bbox_format = bbox_format
|
707 |
+
self.normalize = normalize
|
708 |
+
self.return_mask = return_mask # set False when training detection only
|
709 |
+
self.return_keypoint = return_keypoint
|
710 |
+
self.mask_ratio = mask_ratio
|
711 |
+
self.mask_overlap = mask_overlap
|
712 |
+
self.batch_idx = batch_idx # keep the batch indexes
|
713 |
+
|
714 |
+
def __call__(self, labels):
|
715 |
+
"""Return formatted image, classes, bounding boxes & keypoints to be used by 'collate_fn'."""
|
716 |
+
img = labels.pop('img')
|
717 |
+
h, w = img.shape[:2]
|
718 |
+
cls = labels.pop('cls')
|
719 |
+
instances = labels.pop('instances')
|
720 |
+
instances.convert_bbox(format=self.bbox_format)
|
721 |
+
instances.denormalize(w, h)
|
722 |
+
nl = len(instances)
|
723 |
+
|
724 |
+
if self.return_mask:
|
725 |
+
if nl:
|
726 |
+
masks, instances, cls = self._format_segments(instances, cls, w, h)
|
727 |
+
masks = torch.from_numpy(masks)
|
728 |
+
else:
|
729 |
+
masks = torch.zeros(1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio,
|
730 |
+
img.shape[1] // self.mask_ratio)
|
731 |
+
labels['masks'] = masks
|
732 |
+
if self.normalize:
|
733 |
+
instances.normalize(w, h)
|
734 |
+
labels['img'] = self._format_img(img)
|
735 |
+
labels['cls'] = torch.from_numpy(cls) if nl else torch.zeros(nl)
|
736 |
+
labels['bboxes'] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
|
737 |
+
if self.return_keypoint:
|
738 |
+
labels['keypoints'] = torch.from_numpy(instances.keypoints)
|
739 |
+
# Then we can use collate_fn
|
740 |
+
if self.batch_idx:
|
741 |
+
labels['batch_idx'] = torch.zeros(nl)
|
742 |
+
return labels
|
743 |
+
|
744 |
+
def _format_img(self, img):
|
745 |
+
"""Format the image for YOLOv5 from Numpy array to PyTorch tensor."""
|
746 |
+
if len(img.shape) < 3:
|
747 |
+
img = np.expand_dims(img, -1)
|
748 |
+
img = np.ascontiguousarray(img.transpose(2, 0, 1)[::-1])
|
749 |
+
img = torch.from_numpy(img)
|
750 |
+
return img
|
751 |
+
|
752 |
+
def _format_segments(self, instances, cls, w, h):
|
753 |
+
"""convert polygon points to bitmap."""
|
754 |
+
segments = instances.segments
|
755 |
+
if self.mask_overlap:
|
756 |
+
masks, sorted_idx = polygons2masks_overlap((h, w), segments, downsample_ratio=self.mask_ratio)
|
757 |
+
masks = masks[None] # (640, 640) -> (1, 640, 640)
|
758 |
+
instances = instances[sorted_idx]
|
759 |
+
cls = cls[sorted_idx]
|
760 |
+
else:
|
761 |
+
masks = polygons2masks((h, w), segments, color=1, downsample_ratio=self.mask_ratio)
|
762 |
+
|
763 |
+
return masks, instances, cls
|
764 |
+
|
765 |
+
|
766 |
+
def v8_transforms(dataset, imgsz, hyp, stretch=False):
|
767 |
+
"""Convert images to a size suitable for YOLOv8 training."""
|
768 |
+
pre_transform = Compose([
|
769 |
+
Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic),
|
770 |
+
CopyPaste(p=hyp.copy_paste),
|
771 |
+
RandomPerspective(
|
772 |
+
degrees=hyp.degrees,
|
773 |
+
translate=hyp.translate,
|
774 |
+
scale=hyp.scale,
|
775 |
+
shear=hyp.shear,
|
776 |
+
perspective=hyp.perspective,
|
777 |
+
pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)),
|
778 |
+
)])
|
779 |
+
flip_idx = dataset.data.get('flip_idx', []) # for keypoints augmentation
|
780 |
+
if dataset.use_keypoints:
|
781 |
+
kpt_shape = dataset.data.get('kpt_shape', None)
|
782 |
+
if len(flip_idx) == 0 and hyp.fliplr > 0.0:
|
783 |
+
hyp.fliplr = 0.0
|
784 |
+
LOGGER.warning("WARNING ⚠️ No 'flip_idx' array defined in data.yaml, setting augmentation 'fliplr=0.0'")
|
785 |
+
elif flip_idx and (len(flip_idx) != kpt_shape[0]):
|
786 |
+
raise ValueError(f'data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}')
|
787 |
+
|
788 |
+
return Compose([
|
789 |
+
pre_transform,
|
790 |
+
MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
|
791 |
+
Albumentations(p=1.0),
|
792 |
+
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
|
793 |
+
RandomFlip(direction='vertical', p=hyp.flipud),
|
794 |
+
RandomFlip(direction='horizontal', p=hyp.fliplr, flip_idx=flip_idx)]) # transforms
|
795 |
+
|
796 |
+
|
797 |
+
# Classification augmentations -----------------------------------------------------------------------------------------
|
798 |
+
def classify_transforms(size=224, mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)): # IMAGENET_MEAN, IMAGENET_STD
|
799 |
+
# Transforms to apply if albumentations not installed
|
800 |
+
if not isinstance(size, int):
|
801 |
+
raise TypeError(f'classify_transforms() size {size} must be integer, not (list, tuple)')
|
802 |
+
if any(mean) or any(std):
|
803 |
+
return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(mean, std, inplace=True)])
|
804 |
+
else:
|
805 |
+
return T.Compose([CenterCrop(size), ToTensor()])
|
806 |
+
|
807 |
+
|
808 |
+
def hsv2colorjitter(h, s, v):
|
809 |
+
"""Map HSV (hue, saturation, value) jitter into ColorJitter values (brightness, contrast, saturation, hue)"""
|
810 |
+
return v, v, s, h
|
811 |
+
|
812 |
+
|
813 |
+
def classify_albumentations(
|
814 |
+
augment=True,
|
815 |
+
size=224,
|
816 |
+
scale=(0.08, 1.0),
|
817 |
+
hflip=0.5,
|
818 |
+
vflip=0.0,
|
819 |
+
hsv_h=0.015, # image HSV-Hue augmentation (fraction)
|
820 |
+
hsv_s=0.7, # image HSV-Saturation augmentation (fraction)
|
821 |
+
hsv_v=0.4, # image HSV-Value augmentation (fraction)
|
822 |
+
mean=(0.0, 0.0, 0.0), # IMAGENET_MEAN
|
823 |
+
std=(1.0, 1.0, 1.0), # IMAGENET_STD
|
824 |
+
auto_aug=False,
|
825 |
+
):
|
826 |
+
"""YOLOv8 classification Albumentations (optional, only used if package is installed)."""
|
827 |
+
prefix = colorstr('albumentations: ')
|
828 |
+
try:
|
829 |
+
import albumentations as A
|
830 |
+
from albumentations.pytorch import ToTensorV2
|
831 |
+
|
832 |
+
check_version(A.__version__, '1.0.3', hard=True) # version requirement
|
833 |
+
if augment: # Resize and crop
|
834 |
+
T = [A.RandomResizedCrop(height=size, width=size, scale=scale)]
|
835 |
+
if auto_aug:
|
836 |
+
# TODO: implement AugMix, AutoAug & RandAug in albumentations
|
837 |
+
LOGGER.info(f'{prefix}auto augmentations are currently not supported')
|
838 |
+
else:
|
839 |
+
if hflip > 0:
|
840 |
+
T += [A.HorizontalFlip(p=hflip)]
|
841 |
+
if vflip > 0:
|
842 |
+
T += [A.VerticalFlip(p=vflip)]
|
843 |
+
if any((hsv_h, hsv_s, hsv_v)):
|
844 |
+
T += [A.ColorJitter(*hsv2colorjitter(hsv_h, hsv_s, hsv_v))] # brightness, contrast, saturation, hue
|
845 |
+
else: # Use fixed crop for eval set (reproducibility)
|
846 |
+
T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
|
847 |
+
T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor
|
848 |
+
LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
|
849 |
+
return A.Compose(T)
|
850 |
+
|
851 |
+
except ImportError: # package not installed, skip
|
852 |
+
pass
|
853 |
+
except Exception as e:
|
854 |
+
LOGGER.info(f'{prefix}{e}')
|
855 |
+
|
856 |
+
|
857 |
+
class ClassifyLetterBox:
|
858 |
+
"""YOLOv8 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])"""
|
859 |
+
|
860 |
+
def __init__(self, size=(640, 640), auto=False, stride=32):
|
861 |
+
"""Resizes image and crops it to center with max dimensions 'h' and 'w'."""
|
862 |
+
super().__init__()
|
863 |
+
self.h, self.w = (size, size) if isinstance(size, int) else size
|
864 |
+
self.auto = auto # pass max size integer, automatically solve for short side using stride
|
865 |
+
self.stride = stride # used with auto
|
866 |
+
|
867 |
+
def __call__(self, im): # im = np.array HWC
|
868 |
+
imh, imw = im.shape[:2]
|
869 |
+
r = min(self.h / imh, self.w / imw) # ratio of new/old
|
870 |
+
h, w = round(imh * r), round(imw * r) # resized image
|
871 |
+
hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else self.h, self.w
|
872 |
+
top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)
|
873 |
+
im_out = np.full((self.h, self.w, 3), 114, dtype=im.dtype)
|
874 |
+
im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
|
875 |
+
return im_out
|
876 |
+
|
877 |
+
|
878 |
+
class CenterCrop:
|
879 |
+
"""YOLOv8 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])"""
|
880 |
+
|
881 |
+
def __init__(self, size=640):
|
882 |
+
"""Converts an image from numpy array to PyTorch tensor."""
|
883 |
+
super().__init__()
|
884 |
+
self.h, self.w = (size, size) if isinstance(size, int) else size
|
885 |
+
|
886 |
+
def __call__(self, im): # im = np.array HWC
|
887 |
+
imh, imw = im.shape[:2]
|
888 |
+
m = min(imh, imw) # min dimension
|
889 |
+
top, left = (imh - m) // 2, (imw - m) // 2
|
890 |
+
return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
|
891 |
+
|
892 |
+
|
893 |
+
class ToTensor:
|
894 |
+
"""YOLOv8 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])."""
|
895 |
+
|
896 |
+
def __init__(self, half=False):
|
897 |
+
"""Initialize YOLOv8 ToTensor object with optional half-precision support."""
|
898 |
+
super().__init__()
|
899 |
+
self.half = half
|
900 |
+
|
901 |
+
def __call__(self, im): # im = np.array HWC in BGR order
|
902 |
+
im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1]) # HWC to CHW -> BGR to RGB -> contiguous
|
903 |
+
im = torch.from_numpy(im) # to torch
|
904 |
+
im = im.half() if self.half else im.float() # uint8 to fp16/32
|
905 |
+
im /= 255.0 # 0-255 to 0.0-1.0
|
906 |
+
return im
|
ultralytics/data/base.py
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
import glob
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
from copy import deepcopy
|
8 |
+
from multiprocessing.pool import ThreadPool
|
9 |
+
from pathlib import Path
|
10 |
+
from typing import Optional
|
11 |
+
|
12 |
+
import cv2
|
13 |
+
import numpy as np
|
14 |
+
import psutil
|
15 |
+
from torch.utils.data import Dataset
|
16 |
+
from tqdm import tqdm
|
17 |
+
|
18 |
+
from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT
|
19 |
+
|
20 |
+
from .utils import HELP_URL, IMG_FORMATS
|
21 |
+
|
22 |
+
|
23 |
+
class BaseDataset(Dataset):
|
24 |
+
"""
|
25 |
+
Base dataset class for loading and processing image data.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
img_path (str): Path to the folder containing images.
|
29 |
+
imgsz (int, optional): Image size. Defaults to 640.
|
30 |
+
cache (bool, optional): Cache images to RAM or disk during training. Defaults to False.
|
31 |
+
augment (bool, optional): If True, data augmentation is applied. Defaults to True.
|
32 |
+
hyp (dict, optional): Hyperparameters to apply data augmentation. Defaults to None.
|
33 |
+
prefix (str, optional): Prefix to print in log messages. Defaults to ''.
|
34 |
+
rect (bool, optional): If True, rectangular training is used. Defaults to False.
|
35 |
+
batch_size (int, optional): Size of batches. Defaults to None.
|
36 |
+
stride (int, optional): Stride. Defaults to 32.
|
37 |
+
pad (float, optional): Padding. Defaults to 0.0.
|
38 |
+
single_cls (bool, optional): If True, single class training is used. Defaults to False.
|
39 |
+
classes (list): List of included classes. Default is None.
|
40 |
+
fraction (float): Fraction of dataset to utilize. Default is 1.0 (use all data).
|
41 |
+
|
42 |
+
Attributes:
|
43 |
+
im_files (list): List of image file paths.
|
44 |
+
labels (list): List of label data dictionaries.
|
45 |
+
ni (int): Number of images in the dataset.
|
46 |
+
ims (list): List of loaded images.
|
47 |
+
npy_files (list): List of numpy file paths.
|
48 |
+
transforms (callable): Image transformation function.
|
49 |
+
"""
|
50 |
+
|
51 |
+
def __init__(self,
|
52 |
+
img_path,
|
53 |
+
imgsz=640,
|
54 |
+
cache=False,
|
55 |
+
augment=True,
|
56 |
+
hyp=DEFAULT_CFG,
|
57 |
+
prefix='',
|
58 |
+
rect=False,
|
59 |
+
batch_size=16,
|
60 |
+
stride=32,
|
61 |
+
pad=0.5,
|
62 |
+
single_cls=False,
|
63 |
+
classes=None,
|
64 |
+
fraction=1.0):
|
65 |
+
super().__init__()
|
66 |
+
self.img_path = img_path
|
67 |
+
self.imgsz = imgsz
|
68 |
+
self.augment = augment
|
69 |
+
self.single_cls = single_cls
|
70 |
+
self.prefix = prefix
|
71 |
+
self.fraction = fraction
|
72 |
+
self.im_files = self.get_img_files(self.img_path)
|
73 |
+
self.labels = self.get_labels()
|
74 |
+
self.update_labels(include_class=classes) # single_cls and include_class
|
75 |
+
self.ni = len(self.labels) # number of images
|
76 |
+
self.rect = rect
|
77 |
+
self.batch_size = batch_size
|
78 |
+
self.stride = stride
|
79 |
+
self.pad = pad
|
80 |
+
if self.rect:
|
81 |
+
assert self.batch_size is not None
|
82 |
+
self.set_rectangle()
|
83 |
+
|
84 |
+
# Buffer thread for mosaic images
|
85 |
+
self.buffer = [] # buffer size = batch size
|
86 |
+
self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0
|
87 |
+
|
88 |
+
# Cache stuff
|
89 |
+
if cache == 'ram' and not self.check_cache_ram():
|
90 |
+
cache = False
|
91 |
+
self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
|
92 |
+
self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
|
93 |
+
if cache:
|
94 |
+
self.cache_images(cache)
|
95 |
+
|
96 |
+
# Transforms
|
97 |
+
self.transforms = self.build_transforms(hyp=hyp)
|
98 |
+
|
99 |
+
def get_img_files(self, img_path):
|
100 |
+
"""Read image files."""
|
101 |
+
try:
|
102 |
+
f = [] # image files
|
103 |
+
for p in img_path if isinstance(img_path, list) else [img_path]:
|
104 |
+
p = Path(p) # os-agnostic
|
105 |
+
if p.is_dir(): # dir
|
106 |
+
f += glob.glob(str(p / '**' / '*.*'), recursive=True)
|
107 |
+
# F = list(p.rglob('*.*')) # pathlib
|
108 |
+
elif p.is_file(): # file
|
109 |
+
with open(p) as t:
|
110 |
+
t = t.read().strip().splitlines()
|
111 |
+
parent = str(p.parent) + os.sep
|
112 |
+
f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path
|
113 |
+
# F += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
|
114 |
+
else:
|
115 |
+
raise FileNotFoundError(f'{self.prefix}{p} does not exist')
|
116 |
+
im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
|
117 |
+
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
|
118 |
+
assert im_files, f'{self.prefix}No images found'
|
119 |
+
except Exception as e:
|
120 |
+
raise FileNotFoundError(f'{self.prefix}Error loading data from {img_path}\n{HELP_URL}') from e
|
121 |
+
if self.fraction < 1:
|
122 |
+
im_files = im_files[:round(len(im_files) * self.fraction)]
|
123 |
+
return im_files
|
124 |
+
|
125 |
+
def update_labels(self, include_class: Optional[list]):
|
126 |
+
"""include_class, filter labels to include only these classes (optional)."""
|
127 |
+
include_class_array = np.array(include_class).reshape(1, -1)
|
128 |
+
for i in range(len(self.labels)):
|
129 |
+
if include_class is not None:
|
130 |
+
cls = self.labels[i]['cls']
|
131 |
+
bboxes = self.labels[i]['bboxes']
|
132 |
+
segments = self.labels[i]['segments']
|
133 |
+
keypoints = self.labels[i]['keypoints']
|
134 |
+
j = (cls == include_class_array).any(1)
|
135 |
+
self.labels[i]['cls'] = cls[j]
|
136 |
+
self.labels[i]['bboxes'] = bboxes[j]
|
137 |
+
if segments:
|
138 |
+
self.labels[i]['segments'] = [segments[si] for si, idx in enumerate(j) if idx]
|
139 |
+
if keypoints is not None:
|
140 |
+
self.labels[i]['keypoints'] = keypoints[j]
|
141 |
+
if self.single_cls:
|
142 |
+
self.labels[i]['cls'][:, 0] = 0
|
143 |
+
|
144 |
+
def load_image(self, i):
|
145 |
+
"""Loads 1 image from dataset index 'i', returns (im, resized hw)."""
|
146 |
+
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
|
147 |
+
if im is None: # not cached in RAM
|
148 |
+
if fn.exists(): # load npy
|
149 |
+
im = np.load(fn)
|
150 |
+
else: # read image
|
151 |
+
im = cv2.imread(f) # BGR
|
152 |
+
if im is None:
|
153 |
+
raise FileNotFoundError(f'Image Not Found {f}')
|
154 |
+
h0, w0 = im.shape[:2] # orig hw
|
155 |
+
r = self.imgsz / max(h0, w0) # ratio
|
156 |
+
if r != 1: # if sizes are not equal
|
157 |
+
interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
|
158 |
+
im = cv2.resize(im, (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz)),
|
159 |
+
interpolation=interp)
|
160 |
+
|
161 |
+
# Add to buffer if training with augmentations
|
162 |
+
if self.augment:
|
163 |
+
self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
|
164 |
+
self.buffer.append(i)
|
165 |
+
if len(self.buffer) >= self.max_buffer_length:
|
166 |
+
j = self.buffer.pop(0)
|
167 |
+
self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None
|
168 |
+
|
169 |
+
return im, (h0, w0), im.shape[:2]
|
170 |
+
|
171 |
+
return self.ims[i], self.im_hw0[i], self.im_hw[i]
|
172 |
+
|
173 |
+
def cache_images(self, cache):
|
174 |
+
"""Cache images to memory or disk."""
|
175 |
+
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
|
176 |
+
fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image
|
177 |
+
with ThreadPool(NUM_THREADS) as pool:
|
178 |
+
results = pool.imap(fcn, range(self.ni))
|
179 |
+
pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
|
180 |
+
for i, x in pbar:
|
181 |
+
if cache == 'disk':
|
182 |
+
b += self.npy_files[i].stat().st_size
|
183 |
+
else: # 'ram'
|
184 |
+
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
|
185 |
+
b += self.ims[i].nbytes
|
186 |
+
pbar.desc = f'{self.prefix}Caching images ({b / gb:.1f}GB {cache})'
|
187 |
+
pbar.close()
|
188 |
+
|
189 |
+
def cache_images_to_disk(self, i):
|
190 |
+
"""Saves an image as an *.npy file for faster loading."""
|
191 |
+
f = self.npy_files[i]
|
192 |
+
if not f.exists():
|
193 |
+
np.save(f.as_posix(), cv2.imread(self.im_files[i]))
|
194 |
+
|
195 |
+
def check_cache_ram(self, safety_margin=0.5):
|
196 |
+
"""Check image caching requirements vs available memory."""
|
197 |
+
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
|
198 |
+
n = min(self.ni, 30) # extrapolate from 30 random images
|
199 |
+
for _ in range(n):
|
200 |
+
im = cv2.imread(random.choice(self.im_files)) # sample image
|
201 |
+
ratio = self.imgsz / max(im.shape[0], im.shape[1]) # max(h, w) # ratio
|
202 |
+
b += im.nbytes * ratio ** 2
|
203 |
+
mem_required = b * self.ni / n * (1 + safety_margin) # GB required to cache dataset into RAM
|
204 |
+
mem = psutil.virtual_memory()
|
205 |
+
cache = mem_required < mem.available # to cache or not to cache, that is the question
|
206 |
+
if not cache:
|
207 |
+
LOGGER.info(f'{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images '
|
208 |
+
f'with {int(safety_margin * 100)}% safety margin but only '
|
209 |
+
f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, '
|
210 |
+
f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
|
211 |
+
return cache
|
212 |
+
|
213 |
+
def set_rectangle(self):
|
214 |
+
"""Sets the shape of bounding boxes for YOLO detections as rectangles."""
|
215 |
+
bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
|
216 |
+
nb = bi[-1] + 1 # number of batches
|
217 |
+
|
218 |
+
s = np.array([x.pop('shape') for x in self.labels]) # hw
|
219 |
+
ar = s[:, 0] / s[:, 1] # aspect ratio
|
220 |
+
irect = ar.argsort()
|
221 |
+
self.im_files = [self.im_files[i] for i in irect]
|
222 |
+
self.labels = [self.labels[i] for i in irect]
|
223 |
+
ar = ar[irect]
|
224 |
+
|
225 |
+
# Set training image shapes
|
226 |
+
shapes = [[1, 1]] * nb
|
227 |
+
for i in range(nb):
|
228 |
+
ari = ar[bi == i]
|
229 |
+
mini, maxi = ari.min(), ari.max()
|
230 |
+
if maxi < 1:
|
231 |
+
shapes[i] = [maxi, 1]
|
232 |
+
elif mini > 1:
|
233 |
+
shapes[i] = [1, 1 / mini]
|
234 |
+
|
235 |
+
self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride
|
236 |
+
self.batch = bi # batch index of image
|
237 |
+
|
238 |
+
def __getitem__(self, index):
|
239 |
+
"""Returns transformed label information for given index."""
|
240 |
+
return self.transforms(self.get_image_and_label(index))
|
241 |
+
|
242 |
+
def get_image_and_label(self, index):
|
243 |
+
"""Get and return label information from the dataset."""
|
244 |
+
label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
|
245 |
+
label.pop('shape', None) # shape is for rect, remove it
|
246 |
+
label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index)
|
247 |
+
label['ratio_pad'] = (label['resized_shape'][0] / label['ori_shape'][0],
|
248 |
+
label['resized_shape'][1] / label['ori_shape'][1]) # for evaluation
|
249 |
+
if self.rect:
|
250 |
+
label['rect_shape'] = self.batch_shapes[self.batch[index]]
|
251 |
+
return self.update_labels_info(label)
|
252 |
+
|
253 |
+
def __len__(self):
|
254 |
+
"""Returns the length of the labels list for the dataset."""
|
255 |
+
return len(self.labels)
|
256 |
+
|
257 |
+
def update_labels_info(self, label):
|
258 |
+
"""custom your label format here."""
|
259 |
+
return label
|
260 |
+
|
261 |
+
def build_transforms(self, hyp=None):
|
262 |
+
"""Users can custom augmentations here
|
263 |
+
like:
|
264 |
+
if self.augment:
|
265 |
+
# Training transforms
|
266 |
+
return Compose([])
|
267 |
+
else:
|
268 |
+
# Val transforms
|
269 |
+
return Compose([])
|
270 |
+
"""
|
271 |
+
raise NotImplementedError
|
272 |
+
|
273 |
+
def get_labels(self):
|
274 |
+
"""Users can custom their own format here.
|
275 |
+
Make sure your output is a list with each element like below:
|
276 |
+
dict(
|
277 |
+
im_file=im_file,
|
278 |
+
shape=shape, # format: (height, width)
|
279 |
+
cls=cls,
|
280 |
+
bboxes=bboxes, # xywh
|
281 |
+
segments=segments, # xy
|
282 |
+
keypoints=keypoints, # xy
|
283 |
+
normalized=True, # or False
|
284 |
+
bbox_format="xyxy", # or xywh, ltwh
|
285 |
+
)
|
286 |
+
"""
|
287 |
+
raise NotImplementedError
|
ultralytics/data/build.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from PIL import Image
|
10 |
+
from torch.utils.data import dataloader, distributed
|
11 |
+
|
12 |
+
from ultralytics.data.loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams, LoadTensor,
|
13 |
+
SourceTypes, autocast_list)
|
14 |
+
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
|
15 |
+
from ultralytics.utils import RANK, colorstr
|
16 |
+
from ultralytics.utils.checks import check_file
|
17 |
+
|
18 |
+
from .dataset import YOLODataset
|
19 |
+
from .utils import PIN_MEMORY
|
20 |
+
|
21 |
+
|
22 |
+
class InfiniteDataLoader(dataloader.DataLoader):
|
23 |
+
"""Dataloader that reuses workers. Uses same syntax as vanilla DataLoader."""
|
24 |
+
|
25 |
+
def __init__(self, *args, **kwargs):
|
26 |
+
"""Dataloader that infinitely recycles workers, inherits from DataLoader."""
|
27 |
+
super().__init__(*args, **kwargs)
|
28 |
+
object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
|
29 |
+
self.iterator = super().__iter__()
|
30 |
+
|
31 |
+
def __len__(self):
|
32 |
+
"""Returns the length of the batch sampler's sampler."""
|
33 |
+
return len(self.batch_sampler.sampler)
|
34 |
+
|
35 |
+
def __iter__(self):
|
36 |
+
"""Creates a sampler that repeats indefinitely."""
|
37 |
+
for _ in range(len(self)):
|
38 |
+
yield next(self.iterator)
|
39 |
+
|
40 |
+
def reset(self):
|
41 |
+
"""Reset iterator.
|
42 |
+
This is useful when we want to modify settings of dataset while training.
|
43 |
+
"""
|
44 |
+
self.iterator = self._get_iterator()
|
45 |
+
|
46 |
+
|
47 |
+
class _RepeatSampler:
|
48 |
+
"""
|
49 |
+
Sampler that repeats forever.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
sampler (Dataset.sampler): The sampler to repeat.
|
53 |
+
"""
|
54 |
+
|
55 |
+
def __init__(self, sampler):
|
56 |
+
"""Initializes an object that repeats a given sampler indefinitely."""
|
57 |
+
self.sampler = sampler
|
58 |
+
|
59 |
+
def __iter__(self):
|
60 |
+
"""Iterates over the 'sampler' and yields its contents."""
|
61 |
+
while True:
|
62 |
+
yield from iter(self.sampler)
|
63 |
+
|
64 |
+
|
65 |
+
def seed_worker(worker_id): # noqa
|
66 |
+
"""Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader."""
|
67 |
+
worker_seed = torch.initial_seed() % 2 ** 32
|
68 |
+
np.random.seed(worker_seed)
|
69 |
+
random.seed(worker_seed)
|
70 |
+
|
71 |
+
|
72 |
+
def build_yolo_dataset(cfg, img_path, batch, data, mode='train', rect=False, stride=32):
|
73 |
+
"""Build YOLO Dataset"""
|
74 |
+
return YOLODataset(
|
75 |
+
img_path=img_path,
|
76 |
+
imgsz=cfg.imgsz,
|
77 |
+
batch_size=batch,
|
78 |
+
augment=mode == 'train', # augmentation
|
79 |
+
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
|
80 |
+
rect=cfg.rect or rect, # rectangular batches
|
81 |
+
cache=cfg.cache or None,
|
82 |
+
single_cls=cfg.single_cls or False,
|
83 |
+
stride=int(stride),
|
84 |
+
pad=0.0 if mode == 'train' else 0.5,
|
85 |
+
prefix=colorstr(f'{mode}: '),
|
86 |
+
use_segments=cfg.task == 'segment',
|
87 |
+
use_keypoints=cfg.task == 'pose',
|
88 |
+
classes=cfg.classes,
|
89 |
+
data=data,
|
90 |
+
fraction=cfg.fraction if mode == 'train' else 1.0)
|
91 |
+
|
92 |
+
|
93 |
+
def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
|
94 |
+
"""Return an InfiniteDataLoader or DataLoader for training or validation set."""
|
95 |
+
batch = min(batch, len(dataset))
|
96 |
+
nd = torch.cuda.device_count() # number of CUDA devices
|
97 |
+
nw = min([os.cpu_count() // max(nd, 1), batch if batch > 1 else 0, workers]) # number of workers
|
98 |
+
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
99 |
+
generator = torch.Generator()
|
100 |
+
generator.manual_seed(6148914691236517205 + RANK)
|
101 |
+
return InfiniteDataLoader(dataset=dataset,
|
102 |
+
batch_size=batch,
|
103 |
+
shuffle=shuffle and sampler is None,
|
104 |
+
num_workers=nw,
|
105 |
+
sampler=sampler,
|
106 |
+
pin_memory=PIN_MEMORY,
|
107 |
+
collate_fn=getattr(dataset, 'collate_fn', None),
|
108 |
+
worker_init_fn=seed_worker,
|
109 |
+
generator=generator)
|
110 |
+
|
111 |
+
|
112 |
+
def check_source(source):
|
113 |
+
"""Check source type and return corresponding flag values."""
|
114 |
+
webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False
|
115 |
+
if isinstance(source, (str, int, Path)): # int for local usb camera
|
116 |
+
source = str(source)
|
117 |
+
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
|
118 |
+
is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
|
119 |
+
webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
|
120 |
+
screenshot = source.lower() == 'screen'
|
121 |
+
if is_url and is_file:
|
122 |
+
source = check_file(source) # download
|
123 |
+
elif isinstance(source, tuple(LOADERS)):
|
124 |
+
in_memory = True
|
125 |
+
elif isinstance(source, (list, tuple)):
|
126 |
+
source = autocast_list(source) # convert all list elements to PIL or np arrays
|
127 |
+
from_img = True
|
128 |
+
elif isinstance(source, (Image.Image, np.ndarray)):
|
129 |
+
from_img = True
|
130 |
+
elif isinstance(source, torch.Tensor):
|
131 |
+
tensor = True
|
132 |
+
else:
|
133 |
+
raise TypeError('Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict')
|
134 |
+
|
135 |
+
return source, webcam, screenshot, from_img, in_memory, tensor
|
136 |
+
|
137 |
+
|
138 |
+
def load_inference_source(source=None, imgsz=640, vid_stride=1):
|
139 |
+
"""
|
140 |
+
Loads an inference source for object detection and applies necessary transformations.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference.
|
144 |
+
imgsz (int, optional): The size of the image for inference. Default is 640.
|
145 |
+
vid_stride (int, optional): The frame interval for video sources. Default is 1.
|
146 |
+
|
147 |
+
Returns:
|
148 |
+
dataset (Dataset): A dataset object for the specified input source.
|
149 |
+
"""
|
150 |
+
source, webcam, screenshot, from_img, in_memory, tensor = check_source(source)
|
151 |
+
source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img, tensor)
|
152 |
+
|
153 |
+
# Dataloader
|
154 |
+
if tensor:
|
155 |
+
dataset = LoadTensor(source)
|
156 |
+
elif in_memory:
|
157 |
+
dataset = source
|
158 |
+
elif webcam:
|
159 |
+
dataset = LoadStreams(source, imgsz=imgsz, vid_stride=vid_stride)
|
160 |
+
elif screenshot:
|
161 |
+
dataset = LoadScreenshots(source, imgsz=imgsz)
|
162 |
+
elif from_img:
|
163 |
+
dataset = LoadPilAndNumpy(source, imgsz=imgsz)
|
164 |
+
else:
|
165 |
+
dataset = LoadImages(source, imgsz=imgsz, vid_stride=vid_stride)
|
166 |
+
|
167 |
+
# Attach source types to the dataset
|
168 |
+
setattr(dataset, 'source_type', source_type)
|
169 |
+
|
170 |
+
return dataset
|
ultralytics/data/converter.py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from collections import defaultdict
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
from ultralytics.utils.checks import check_requirements
|
10 |
+
from ultralytics.utils.files import make_dirs
|
11 |
+
|
12 |
+
|
13 |
+
def coco91_to_coco80_class():
|
14 |
+
"""Converts 91-index COCO class IDs to 80-index COCO class IDs.
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
(list): A list of 91 class IDs where the index represents the 80-index class ID and the value is the
|
18 |
+
corresponding 91-index class ID.
|
19 |
+
|
20 |
+
"""
|
21 |
+
return [
|
22 |
+
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, None, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, None, 24, 25, None,
|
23 |
+
None, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, None, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
|
24 |
+
51, 52, 53, 54, 55, 56, 57, 58, 59, None, 60, None, None, 61, None, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72,
|
25 |
+
None, 73, 74, 75, 76, 77, 78, 79, None]
|
26 |
+
|
27 |
+
|
28 |
+
def convert_coco(labels_dir='../coco/annotations/', use_segments=False, use_keypoints=False, cls91to80=True):
|
29 |
+
"""Converts COCO dataset annotations to a format suitable for training YOLOv5 models.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
labels_dir (str, optional): Path to directory containing COCO dataset annotation files.
|
33 |
+
use_segments (bool, optional): Whether to include segmentation masks in the output.
|
34 |
+
use_keypoints (bool, optional): Whether to include keypoint annotations in the output.
|
35 |
+
cls91to80 (bool, optional): Whether to map 91 COCO class IDs to the corresponding 80 COCO class IDs.
|
36 |
+
|
37 |
+
Raises:
|
38 |
+
FileNotFoundError: If the labels_dir path does not exist.
|
39 |
+
|
40 |
+
Example Usage:
|
41 |
+
convert_coco(labels_dir='../coco/annotations/', use_segments=True, use_keypoints=True, cls91to80=True)
|
42 |
+
|
43 |
+
Output:
|
44 |
+
Generates output files in the specified output directory.
|
45 |
+
"""
|
46 |
+
|
47 |
+
save_dir = make_dirs('yolo_labels') # output directory
|
48 |
+
coco80 = coco91_to_coco80_class()
|
49 |
+
|
50 |
+
# Import json
|
51 |
+
for json_file in sorted(Path(labels_dir).resolve().glob('*.json')):
|
52 |
+
fn = Path(save_dir) / 'labels' / json_file.stem.replace('instances_', '') # folder name
|
53 |
+
fn.mkdir(parents=True, exist_ok=True)
|
54 |
+
with open(json_file) as f:
|
55 |
+
data = json.load(f)
|
56 |
+
|
57 |
+
# Create image dict
|
58 |
+
images = {f'{x["id"]:d}': x for x in data['images']}
|
59 |
+
# Create image-annotations dict
|
60 |
+
imgToAnns = defaultdict(list)
|
61 |
+
for ann in data['annotations']:
|
62 |
+
imgToAnns[ann['image_id']].append(ann)
|
63 |
+
|
64 |
+
# Write labels file
|
65 |
+
for img_id, anns in tqdm(imgToAnns.items(), desc=f'Annotations {json_file}'):
|
66 |
+
img = images[f'{img_id:d}']
|
67 |
+
h, w, f = img['height'], img['width'], img['file_name']
|
68 |
+
|
69 |
+
bboxes = []
|
70 |
+
segments = []
|
71 |
+
keypoints = []
|
72 |
+
for ann in anns:
|
73 |
+
if ann['iscrowd']:
|
74 |
+
continue
|
75 |
+
# The COCO box format is [top left x, top left y, width, height]
|
76 |
+
box = np.array(ann['bbox'], dtype=np.float64)
|
77 |
+
box[:2] += box[2:] / 2 # xy top-left corner to center
|
78 |
+
box[[0, 2]] /= w # normalize x
|
79 |
+
box[[1, 3]] /= h # normalize y
|
80 |
+
if box[2] <= 0 or box[3] <= 0: # if w <= 0 and h <= 0
|
81 |
+
continue
|
82 |
+
|
83 |
+
cls = coco80[ann['category_id'] - 1] if cls91to80 else ann['category_id'] - 1 # class
|
84 |
+
box = [cls] + box.tolist()
|
85 |
+
if box not in bboxes:
|
86 |
+
bboxes.append(box)
|
87 |
+
if use_segments and ann.get('segmentation') is not None:
|
88 |
+
if len(ann['segmentation']) == 0:
|
89 |
+
segments.append([])
|
90 |
+
continue
|
91 |
+
if isinstance(ann['segmentation'], dict):
|
92 |
+
ann['segmentation'] = rle2polygon(ann['segmentation'])
|
93 |
+
if len(ann['segmentation']) > 1:
|
94 |
+
s = merge_multi_segment(ann['segmentation'])
|
95 |
+
s = (np.concatenate(s, axis=0) / np.array([w, h])).reshape(-1).tolist()
|
96 |
+
else:
|
97 |
+
s = [j for i in ann['segmentation'] for j in i] # all segments concatenated
|
98 |
+
s = (np.array(s).reshape(-1, 2) / np.array([w, h])).reshape(-1).tolist()
|
99 |
+
s = [cls] + s
|
100 |
+
if s not in segments:
|
101 |
+
segments.append(s)
|
102 |
+
if use_keypoints and ann.get('keypoints') is not None:
|
103 |
+
k = (np.array(ann['keypoints']).reshape(-1, 3) / np.array([w, h, 1])).reshape(-1).tolist()
|
104 |
+
k = box + k
|
105 |
+
keypoints.append(k)
|
106 |
+
|
107 |
+
# Write
|
108 |
+
with open((fn / f).with_suffix('.txt'), 'a') as file:
|
109 |
+
for i in range(len(bboxes)):
|
110 |
+
if use_keypoints:
|
111 |
+
line = *(keypoints[i]), # cls, box, keypoints
|
112 |
+
else:
|
113 |
+
line = *(segments[i]
|
114 |
+
if use_segments and len(segments[i]) > 0 else bboxes[i]), # cls, box or segments
|
115 |
+
file.write(('%g ' * len(line)).rstrip() % line + '\n')
|
116 |
+
|
117 |
+
|
118 |
+
def rle2polygon(segmentation):
|
119 |
+
"""
|
120 |
+
Convert Run-Length Encoding (RLE) mask to polygon coordinates.
|
121 |
+
|
122 |
+
Args:
|
123 |
+
segmentation (dict, list): RLE mask representation of the object segmentation.
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
(list): A list of lists representing the polygon coordinates for each contour.
|
127 |
+
|
128 |
+
Note:
|
129 |
+
Requires the 'pycocotools' package to be installed.
|
130 |
+
"""
|
131 |
+
check_requirements('pycocotools')
|
132 |
+
from pycocotools import mask
|
133 |
+
|
134 |
+
m = mask.decode(segmentation)
|
135 |
+
m[m > 0] = 255
|
136 |
+
contours, _ = cv2.findContours(m, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_KCOS)
|
137 |
+
polygons = []
|
138 |
+
for contour in contours:
|
139 |
+
epsilon = 0.001 * cv2.arcLength(contour, True)
|
140 |
+
contour_approx = cv2.approxPolyDP(contour, epsilon, True)
|
141 |
+
polygon = contour_approx.flatten().tolist()
|
142 |
+
polygons.append(polygon)
|
143 |
+
return polygons
|
144 |
+
|
145 |
+
|
146 |
+
def min_index(arr1, arr2):
|
147 |
+
"""
|
148 |
+
Find a pair of indexes with the shortest distance between two arrays of 2D points.
|
149 |
+
|
150 |
+
Args:
|
151 |
+
arr1 (np.array): A NumPy array of shape (N, 2) representing N 2D points.
|
152 |
+
arr2 (np.array): A NumPy array of shape (M, 2) representing M 2D points.
|
153 |
+
|
154 |
+
Returns:
|
155 |
+
(tuple): A tuple containing the indexes of the points with the shortest distance in arr1 and arr2 respectively.
|
156 |
+
"""
|
157 |
+
dis = ((arr1[:, None, :] - arr2[None, :, :]) ** 2).sum(-1)
|
158 |
+
return np.unravel_index(np.argmin(dis, axis=None), dis.shape)
|
159 |
+
|
160 |
+
|
161 |
+
def merge_multi_segment(segments):
|
162 |
+
"""
|
163 |
+
Merge multiple segments into one list by connecting the coordinates with the minimum distance between each segment.
|
164 |
+
This function connects these coordinates with a thin line to merge all segments into one.
|
165 |
+
|
166 |
+
Args:
|
167 |
+
segments (List[List]): Original segmentations in COCO's JSON file.
|
168 |
+
Each element is a list of coordinates, like [segmentation1, segmentation2,...].
|
169 |
+
|
170 |
+
Returns:
|
171 |
+
s (List[np.ndarray]): A list of connected segments represented as NumPy arrays.
|
172 |
+
"""
|
173 |
+
s = []
|
174 |
+
segments = [np.array(i).reshape(-1, 2) for i in segments]
|
175 |
+
idx_list = [[] for _ in range(len(segments))]
|
176 |
+
|
177 |
+
# record the indexes with min distance between each segment
|
178 |
+
for i in range(1, len(segments)):
|
179 |
+
idx1, idx2 = min_index(segments[i - 1], segments[i])
|
180 |
+
idx_list[i - 1].append(idx1)
|
181 |
+
idx_list[i].append(idx2)
|
182 |
+
|
183 |
+
# use two round to connect all the segments
|
184 |
+
for k in range(2):
|
185 |
+
# forward connection
|
186 |
+
if k == 0:
|
187 |
+
for i, idx in enumerate(idx_list):
|
188 |
+
# middle segments have two indexes
|
189 |
+
# reverse the index of middle segments
|
190 |
+
if len(idx) == 2 and idx[0] > idx[1]:
|
191 |
+
idx = idx[::-1]
|
192 |
+
segments[i] = segments[i][::-1, :]
|
193 |
+
|
194 |
+
segments[i] = np.roll(segments[i], -idx[0], axis=0)
|
195 |
+
segments[i] = np.concatenate([segments[i], segments[i][:1]])
|
196 |
+
# deal with the first segment and the last one
|
197 |
+
if i in [0, len(idx_list) - 1]:
|
198 |
+
s.append(segments[i])
|
199 |
+
else:
|
200 |
+
idx = [0, idx[1] - idx[0]]
|
201 |
+
s.append(segments[i][idx[0]:idx[1] + 1])
|
202 |
+
|
203 |
+
else:
|
204 |
+
for i in range(len(idx_list) - 1, -1, -1):
|
205 |
+
if i not in [0, len(idx_list) - 1]:
|
206 |
+
idx = idx_list[i]
|
207 |
+
nidx = abs(idx[1] - idx[0])
|
208 |
+
s.append(segments[i][nidx:])
|
209 |
+
return s
|
210 |
+
|
211 |
+
|
212 |
+
def delete_dsstore(path='../datasets'):
|
213 |
+
"""Delete Apple .DS_Store files in the specified directory and its subdirectories."""
|
214 |
+
from pathlib import Path
|
215 |
+
|
216 |
+
files = list(Path(path).rglob('.DS_store'))
|
217 |
+
print(files)
|
218 |
+
for f in files:
|
219 |
+
f.unlink()
|
220 |
+
|
221 |
+
|
222 |
+
if __name__ == '__main__':
|
223 |
+
source = 'COCO'
|
224 |
+
|
225 |
+
if source == 'COCO':
|
226 |
+
convert_coco(
|
227 |
+
'../datasets/coco/annotations', # directory with *.json
|
228 |
+
use_segments=False,
|
229 |
+
use_keypoints=True,
|
230 |
+
cls91to80=False)
|
ultralytics/data/dataloaders/__init__.py
ADDED
File without changes
|
ultralytics/data/dataset.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
from itertools import repeat
|
4 |
+
from multiprocessing.pool import ThreadPool
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import cv2
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torchvision
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT, is_dir_writeable
|
14 |
+
|
15 |
+
from .augment import Compose, Format, Instances, LetterBox, classify_albumentations, classify_transforms, v8_transforms
|
16 |
+
from .base import BaseDataset
|
17 |
+
from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image_label
|
18 |
+
|
19 |
+
|
20 |
+
class YOLODataset(BaseDataset):
|
21 |
+
"""
|
22 |
+
Dataset class for loading object detection and/or segmentation labels in YOLO format.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
data (dict, optional): A dataset YAML dictionary. Defaults to None.
|
26 |
+
use_segments (bool, optional): If True, segmentation masks are used as labels. Defaults to False.
|
27 |
+
use_keypoints (bool, optional): If True, keypoints are used as labels. Defaults to False.
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
(torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
|
31 |
+
"""
|
32 |
+
cache_version = '1.0.2' # dataset labels *.cache version, >= 1.0.0 for YOLOv8
|
33 |
+
rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
|
34 |
+
|
35 |
+
def __init__(self, *args, data=None, use_segments=False, use_keypoints=False, **kwargs):
|
36 |
+
self.use_segments = use_segments
|
37 |
+
self.use_keypoints = use_keypoints
|
38 |
+
self.data = data
|
39 |
+
assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.'
|
40 |
+
super().__init__(*args, **kwargs)
|
41 |
+
|
42 |
+
def cache_labels(self, path=Path('./labels.cache')):
|
43 |
+
"""Cache dataset labels, check images and read shapes.
|
44 |
+
Args:
|
45 |
+
path (Path): path where to save the cache file (default: Path('./labels.cache')).
|
46 |
+
Returns:
|
47 |
+
(dict): labels.
|
48 |
+
"""
|
49 |
+
x = {'labels': []}
|
50 |
+
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
51 |
+
desc = f'{self.prefix}Scanning {path.parent / path.stem}...'
|
52 |
+
total = len(self.im_files)
|
53 |
+
nkpt, ndim = self.data.get('kpt_shape', (0, 0))
|
54 |
+
if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
|
55 |
+
raise ValueError("'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
|
56 |
+
"keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'")
|
57 |
+
with ThreadPool(NUM_THREADS) as pool:
|
58 |
+
results = pool.imap(func=verify_image_label,
|
59 |
+
iterable=zip(self.im_files, self.label_files, repeat(self.prefix),
|
60 |
+
repeat(self.use_keypoints), repeat(len(self.data['names'])), repeat(nkpt),
|
61 |
+
repeat(ndim)))
|
62 |
+
pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
|
63 |
+
for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
|
64 |
+
nm += nm_f
|
65 |
+
nf += nf_f
|
66 |
+
ne += ne_f
|
67 |
+
nc += nc_f
|
68 |
+
if im_file:
|
69 |
+
x['labels'].append(
|
70 |
+
dict(
|
71 |
+
im_file=im_file,
|
72 |
+
shape=shape,
|
73 |
+
cls=lb[:, 0:1], # n, 1
|
74 |
+
bboxes=lb[:, 1:], # n, 4
|
75 |
+
segments=segments,
|
76 |
+
keypoints=keypoint,
|
77 |
+
normalized=True,
|
78 |
+
bbox_format='xywh'))
|
79 |
+
if msg:
|
80 |
+
msgs.append(msg)
|
81 |
+
pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt'
|
82 |
+
pbar.close()
|
83 |
+
|
84 |
+
if msgs:
|
85 |
+
LOGGER.info('\n'.join(msgs))
|
86 |
+
if nf == 0:
|
87 |
+
LOGGER.warning(f'{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
|
88 |
+
x['hash'] = get_hash(self.label_files + self.im_files)
|
89 |
+
x['results'] = nf, nm, ne, nc, len(self.im_files)
|
90 |
+
x['msgs'] = msgs # warnings
|
91 |
+
x['version'] = self.cache_version # cache version
|
92 |
+
if is_dir_writeable(path.parent):
|
93 |
+
if path.exists():
|
94 |
+
path.unlink() # remove *.cache file if exists
|
95 |
+
np.save(str(path), x) # save cache for next time
|
96 |
+
path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
|
97 |
+
LOGGER.info(f'{self.prefix}New cache created: {path}')
|
98 |
+
else:
|
99 |
+
LOGGER.warning(f'{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
|
100 |
+
return x
|
101 |
+
|
102 |
+
def get_labels(self):
|
103 |
+
"""Returns dictionary of labels for YOLO training."""
|
104 |
+
self.label_files = img2label_paths(self.im_files)
|
105 |
+
cache_path = Path(self.label_files[0]).parent.with_suffix('.cache')
|
106 |
+
try:
|
107 |
+
import gc
|
108 |
+
gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
|
109 |
+
cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True # load dict
|
110 |
+
gc.enable()
|
111 |
+
assert cache['version'] == self.cache_version # matches current version
|
112 |
+
assert cache['hash'] == get_hash(self.label_files + self.im_files) # identical hash
|
113 |
+
except (FileNotFoundError, AssertionError, AttributeError):
|
114 |
+
cache, exists = self.cache_labels(cache_path), False # run cache ops
|
115 |
+
|
116 |
+
# Display cache
|
117 |
+
nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total
|
118 |
+
if exists and LOCAL_RANK in (-1, 0):
|
119 |
+
d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
|
120 |
+
tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display cache results
|
121 |
+
if cache['msgs']:
|
122 |
+
LOGGER.info('\n'.join(cache['msgs'])) # display warnings
|
123 |
+
if nf == 0: # number of labels found
|
124 |
+
raise FileNotFoundError(f'{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}')
|
125 |
+
|
126 |
+
# Read cache
|
127 |
+
[cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
|
128 |
+
labels = cache['labels']
|
129 |
+
self.im_files = [lb['im_file'] for lb in labels] # update im_files
|
130 |
+
|
131 |
+
# Check if the dataset is all boxes or all segments
|
132 |
+
lengths = ((len(lb['cls']), len(lb['bboxes']), len(lb['segments'])) for lb in labels)
|
133 |
+
len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
|
134 |
+
if len_segments and len_boxes != len_segments:
|
135 |
+
LOGGER.warning(
|
136 |
+
f'WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, '
|
137 |
+
f'len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. '
|
138 |
+
'To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.')
|
139 |
+
for lb in labels:
|
140 |
+
lb['segments'] = []
|
141 |
+
if len_cls == 0:
|
142 |
+
raise ValueError(f'All labels empty in {cache_path}, can not start training without labels. {HELP_URL}')
|
143 |
+
return labels
|
144 |
+
|
145 |
+
# TODO: use hyp config to set all these augmentations
|
146 |
+
def build_transforms(self, hyp=None):
|
147 |
+
"""Builds and appends transforms to the list."""
|
148 |
+
if self.augment:
|
149 |
+
hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
|
150 |
+
hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
|
151 |
+
transforms = v8_transforms(self, self.imgsz, hyp)
|
152 |
+
else:
|
153 |
+
transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
|
154 |
+
transforms.append(
|
155 |
+
Format(bbox_format='xywh',
|
156 |
+
normalize=True,
|
157 |
+
return_mask=self.use_segments,
|
158 |
+
return_keypoint=self.use_keypoints,
|
159 |
+
batch_idx=True,
|
160 |
+
mask_ratio=hyp.mask_ratio,
|
161 |
+
mask_overlap=hyp.overlap_mask))
|
162 |
+
return transforms
|
163 |
+
|
164 |
+
def close_mosaic(self, hyp):
|
165 |
+
"""Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations."""
|
166 |
+
hyp.mosaic = 0.0 # set mosaic ratio=0.0
|
167 |
+
hyp.copy_paste = 0.0 # keep the same behavior as previous v8 close-mosaic
|
168 |
+
hyp.mixup = 0.0 # keep the same behavior as previous v8 close-mosaic
|
169 |
+
self.transforms = self.build_transforms(hyp)
|
170 |
+
|
171 |
+
def update_labels_info(self, label):
|
172 |
+
"""custom your label format here."""
|
173 |
+
# NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label
|
174 |
+
# we can make it also support classification and semantic segmentation by add or remove some dict keys there.
|
175 |
+
bboxes = label.pop('bboxes')
|
176 |
+
segments = label.pop('segments')
|
177 |
+
keypoints = label.pop('keypoints', None)
|
178 |
+
bbox_format = label.pop('bbox_format')
|
179 |
+
normalized = label.pop('normalized')
|
180 |
+
label['instances'] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
|
181 |
+
return label
|
182 |
+
|
183 |
+
@staticmethod
|
184 |
+
def collate_fn(batch):
|
185 |
+
"""Collates data samples into batches."""
|
186 |
+
new_batch = {}
|
187 |
+
keys = batch[0].keys()
|
188 |
+
values = list(zip(*[list(b.values()) for b in batch]))
|
189 |
+
for i, k in enumerate(keys):
|
190 |
+
value = values[i]
|
191 |
+
if k == 'img':
|
192 |
+
value = torch.stack(value, 0)
|
193 |
+
if k in ['masks', 'keypoints', 'bboxes', 'cls']:
|
194 |
+
value = torch.cat(value, 0)
|
195 |
+
new_batch[k] = value
|
196 |
+
new_batch['batch_idx'] = list(new_batch['batch_idx'])
|
197 |
+
for i in range(len(new_batch['batch_idx'])):
|
198 |
+
new_batch['batch_idx'][i] += i # add target image index for build_targets()
|
199 |
+
new_batch['batch_idx'] = torch.cat(new_batch['batch_idx'], 0)
|
200 |
+
return new_batch
|
201 |
+
|
202 |
+
|
203 |
+
# Classification dataloaders -------------------------------------------------------------------------------------------
|
204 |
+
class ClassificationDataset(torchvision.datasets.ImageFolder):
|
205 |
+
"""
|
206 |
+
YOLO Classification Dataset.
|
207 |
+
|
208 |
+
Args:
|
209 |
+
root (str): Dataset path.
|
210 |
+
|
211 |
+
Attributes:
|
212 |
+
cache_ram (bool): True if images should be cached in RAM, False otherwise.
|
213 |
+
cache_disk (bool): True if images should be cached on disk, False otherwise.
|
214 |
+
samples (list): List of samples containing file, index, npy, and im.
|
215 |
+
torch_transforms (callable): torchvision transforms applied to the dataset.
|
216 |
+
album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True.
|
217 |
+
"""
|
218 |
+
|
219 |
+
def __init__(self, root, args, augment=False, cache=False):
|
220 |
+
"""
|
221 |
+
Initialize YOLO object with root, image size, augmentations, and cache settings.
|
222 |
+
|
223 |
+
Args:
|
224 |
+
root (str): Dataset path.
|
225 |
+
args (Namespace): Argument parser containing dataset related settings.
|
226 |
+
augment (bool, optional): True if dataset should be augmented, False otherwise. Defaults to False.
|
227 |
+
cache (bool | str | optional): Cache setting, can be True, False, 'ram' or 'disk'. Defaults to False.
|
228 |
+
"""
|
229 |
+
super().__init__(root=root)
|
230 |
+
if augment and args.fraction < 1.0: # reduce training fraction
|
231 |
+
self.samples = self.samples[:round(len(self.samples) * args.fraction)]
|
232 |
+
self.cache_ram = cache is True or cache == 'ram'
|
233 |
+
self.cache_disk = cache == 'disk'
|
234 |
+
self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
|
235 |
+
self.torch_transforms = classify_transforms(args.imgsz)
|
236 |
+
self.album_transforms = classify_albumentations(
|
237 |
+
augment=augment,
|
238 |
+
size=args.imgsz,
|
239 |
+
scale=(1.0 - args.scale, 1.0), # (0.08, 1.0)
|
240 |
+
hflip=args.fliplr,
|
241 |
+
vflip=args.flipud,
|
242 |
+
hsv_h=args.hsv_h, # HSV-Hue augmentation (fraction)
|
243 |
+
hsv_s=args.hsv_s, # HSV-Saturation augmentation (fraction)
|
244 |
+
hsv_v=args.hsv_v, # HSV-Value augmentation (fraction)
|
245 |
+
mean=(0.0, 0.0, 0.0), # IMAGENET_MEAN
|
246 |
+
std=(1.0, 1.0, 1.0), # IMAGENET_STD
|
247 |
+
auto_aug=False) if augment else None
|
248 |
+
|
249 |
+
def __getitem__(self, i):
|
250 |
+
"""Returns subset of data and targets corresponding to given indices."""
|
251 |
+
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
|
252 |
+
if self.cache_ram and im is None:
|
253 |
+
im = self.samples[i][3] = cv2.imread(f)
|
254 |
+
elif self.cache_disk:
|
255 |
+
if not fn.exists(): # load npy
|
256 |
+
np.save(fn.as_posix(), cv2.imread(f))
|
257 |
+
im = np.load(fn)
|
258 |
+
else: # read image
|
259 |
+
im = cv2.imread(f) # BGR
|
260 |
+
if self.album_transforms:
|
261 |
+
sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))['image']
|
262 |
+
else:
|
263 |
+
sample = self.torch_transforms(im)
|
264 |
+
return {'img': sample, 'cls': j}
|
265 |
+
|
266 |
+
def __len__(self) -> int:
|
267 |
+
return len(self.samples)
|
268 |
+
|
269 |
+
|
270 |
+
# TODO: support semantic segmentation
|
271 |
+
class SemanticDataset(BaseDataset):
|
272 |
+
|
273 |
+
def __init__(self):
|
274 |
+
"""Initialize a SemanticDataset object."""
|
275 |
+
super().__init__()
|
ultralytics/data/loaders.py
ADDED
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
import glob
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
import time
|
7 |
+
from dataclasses import dataclass
|
8 |
+
from pathlib import Path
|
9 |
+
from threading import Thread
|
10 |
+
from urllib.parse import urlparse
|
11 |
+
|
12 |
+
import cv2
|
13 |
+
import numpy as np
|
14 |
+
import requests
|
15 |
+
import torch
|
16 |
+
from PIL import Image
|
17 |
+
|
18 |
+
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
|
19 |
+
from ultralytics.utils import LOGGER, ROOT, is_colab, is_kaggle, ops
|
20 |
+
from ultralytics.utils.checks import check_requirements
|
21 |
+
|
22 |
+
|
23 |
+
@dataclass
|
24 |
+
class SourceTypes:
|
25 |
+
webcam: bool = False
|
26 |
+
screenshot: bool = False
|
27 |
+
from_img: bool = False
|
28 |
+
tensor: bool = False
|
29 |
+
|
30 |
+
|
31 |
+
class LoadStreams:
|
32 |
+
"""YOLOv8 streamloader, i.e. `yolo predict source='rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`."""
|
33 |
+
|
34 |
+
def __init__(self, sources='file.streams', imgsz=640, vid_stride=1):
|
35 |
+
"""Initialize instance variables and check for consistent input stream shapes."""
|
36 |
+
torch.backends.cudnn.benchmark = True # faster for fixed-size inference
|
37 |
+
self.mode = 'stream'
|
38 |
+
self.imgsz = imgsz
|
39 |
+
self.vid_stride = vid_stride # video frame-rate stride
|
40 |
+
sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
|
41 |
+
n = len(sources)
|
42 |
+
self.sources = [ops.clean_str(x) for x in sources] # clean source names for later
|
43 |
+
self.imgs, self.fps, self.frames, self.threads, self.shape = [[]] * n, [0] * n, [0] * n, [None] * n, [None] * n
|
44 |
+
for i, s in enumerate(sources): # index, source
|
45 |
+
# Start thread to read frames from video stream
|
46 |
+
st = f'{i + 1}/{n}: {s}... '
|
47 |
+
if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'): # if source is YouTube video
|
48 |
+
# YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
|
49 |
+
s = get_best_youtube_url(s)
|
50 |
+
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
|
51 |
+
if s == 0 and (is_colab() or is_kaggle()):
|
52 |
+
raise NotImplementedError("'source=0' webcam not supported in Colab and Kaggle notebooks. "
|
53 |
+
"Try running 'source=0' in a local environment.")
|
54 |
+
cap = cv2.VideoCapture(s)
|
55 |
+
if not cap.isOpened():
|
56 |
+
raise ConnectionError(f'{st}Failed to open {s}')
|
57 |
+
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
58 |
+
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
59 |
+
fps = cap.get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
|
60 |
+
self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf') # infinite stream fallback
|
61 |
+
self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback
|
62 |
+
|
63 |
+
success, im = cap.read() # guarantee first frame
|
64 |
+
if not success or im is None:
|
65 |
+
raise ConnectionError(f'{st}Failed to read images from {s}')
|
66 |
+
self.imgs[i].append(im)
|
67 |
+
self.shape[i] = im.shape
|
68 |
+
self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
|
69 |
+
LOGGER.info(f'{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)')
|
70 |
+
self.threads[i].start()
|
71 |
+
LOGGER.info('') # newline
|
72 |
+
|
73 |
+
# Check for common shapes
|
74 |
+
self.bs = self.__len__()
|
75 |
+
|
76 |
+
def update(self, i, cap, stream):
|
77 |
+
"""Read stream `i` frames in daemon thread."""
|
78 |
+
n, f = 0, self.frames[i] # frame number, frame array
|
79 |
+
while cap.isOpened() and n < f:
|
80 |
+
# Only read a new frame if the buffer is empty
|
81 |
+
if not self.imgs[i]:
|
82 |
+
n += 1
|
83 |
+
cap.grab() # .read() = .grab() followed by .retrieve()
|
84 |
+
if n % self.vid_stride == 0:
|
85 |
+
success, im = cap.retrieve()
|
86 |
+
if success:
|
87 |
+
self.imgs[i].append(im) # add image to buffer
|
88 |
+
else:
|
89 |
+
LOGGER.warning('WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.')
|
90 |
+
self.imgs[i].append(np.zeros(self.shape[i]))
|
91 |
+
cap.open(stream) # re-open stream if signal was lost
|
92 |
+
else:
|
93 |
+
time.sleep(0.01) # wait until the buffer is empty
|
94 |
+
|
95 |
+
def __iter__(self):
|
96 |
+
"""Iterates through YOLO image feed and re-opens unresponsive streams."""
|
97 |
+
self.count = -1
|
98 |
+
return self
|
99 |
+
|
100 |
+
def __next__(self):
|
101 |
+
"""Returns source paths, transformed and original images for processing."""
|
102 |
+
self.count += 1
|
103 |
+
|
104 |
+
# Wait until a frame is available in each buffer
|
105 |
+
while not all(self.imgs):
|
106 |
+
if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
|
107 |
+
cv2.destroyAllWindows()
|
108 |
+
raise StopIteration
|
109 |
+
time.sleep(1 / min(self.fps))
|
110 |
+
|
111 |
+
# Get and remove the next frame from imgs buffer
|
112 |
+
return self.sources, [x.pop(0) for x in self.imgs], None, ''
|
113 |
+
|
114 |
+
def __len__(self):
|
115 |
+
"""Return the length of the sources object."""
|
116 |
+
return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
|
117 |
+
|
118 |
+
|
119 |
+
class LoadScreenshots:
|
120 |
+
"""YOLOv8 screenshot dataloader, i.e. `yolo predict source=screen`."""
|
121 |
+
|
122 |
+
def __init__(self, source, imgsz=640):
|
123 |
+
"""source = [screen_number left top width height] (pixels)."""
|
124 |
+
check_requirements('mss')
|
125 |
+
import mss # noqa
|
126 |
+
|
127 |
+
source, *params = source.split()
|
128 |
+
self.screen, left, top, width, height = 0, None, None, None, None # default to full screen 0
|
129 |
+
if len(params) == 1:
|
130 |
+
self.screen = int(params[0])
|
131 |
+
elif len(params) == 4:
|
132 |
+
left, top, width, height = (int(x) for x in params)
|
133 |
+
elif len(params) == 5:
|
134 |
+
self.screen, left, top, width, height = (int(x) for x in params)
|
135 |
+
self.imgsz = imgsz
|
136 |
+
self.mode = 'stream'
|
137 |
+
self.frame = 0
|
138 |
+
self.sct = mss.mss()
|
139 |
+
self.bs = 1
|
140 |
+
|
141 |
+
# Parse monitor shape
|
142 |
+
monitor = self.sct.monitors[self.screen]
|
143 |
+
self.top = monitor['top'] if top is None else (monitor['top'] + top)
|
144 |
+
self.left = monitor['left'] if left is None else (monitor['left'] + left)
|
145 |
+
self.width = width or monitor['width']
|
146 |
+
self.height = height or monitor['height']
|
147 |
+
self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height}
|
148 |
+
|
149 |
+
def __iter__(self):
|
150 |
+
"""Returns an iterator of the object."""
|
151 |
+
return self
|
152 |
+
|
153 |
+
def __next__(self):
|
154 |
+
"""mss screen capture: get raw pixels from the screen as np array."""
|
155 |
+
im0 = np.array(self.sct.grab(self.monitor))[:, :, :3] # [:, :, :3] BGRA to BGR
|
156 |
+
s = f'screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: '
|
157 |
+
|
158 |
+
self.frame += 1
|
159 |
+
return [str(self.screen)], [im0], None, s # screen, img, vid_cap, string
|
160 |
+
|
161 |
+
|
162 |
+
class LoadImages:
|
163 |
+
"""YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4`."""
|
164 |
+
|
165 |
+
def __init__(self, path, imgsz=640, vid_stride=1):
|
166 |
+
"""Initialize the Dataloader and raise FileNotFoundError if file not found."""
|
167 |
+
parent = None
|
168 |
+
if isinstance(path, str) and Path(path).suffix == '.txt': # *.txt file with img/vid/dir on each line
|
169 |
+
parent = Path(path).parent
|
170 |
+
path = Path(path).read_text().rsplit()
|
171 |
+
files = []
|
172 |
+
for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
|
173 |
+
a = str(Path(p).absolute()) # do not use .resolve() https://github.com/ultralytics/ultralytics/issues/2912
|
174 |
+
if '*' in a:
|
175 |
+
files.extend(sorted(glob.glob(a, recursive=True))) # glob
|
176 |
+
elif os.path.isdir(a):
|
177 |
+
files.extend(sorted(glob.glob(os.path.join(a, '*.*')))) # dir
|
178 |
+
elif os.path.isfile(a):
|
179 |
+
files.append(a) # files (absolute or relative to CWD)
|
180 |
+
elif parent and (parent / p).is_file():
|
181 |
+
files.append(str((parent / p).absolute())) # files (relative to *.txt file parent)
|
182 |
+
else:
|
183 |
+
raise FileNotFoundError(f'{p} does not exist')
|
184 |
+
|
185 |
+
images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
|
186 |
+
videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
|
187 |
+
ni, nv = len(images), len(videos)
|
188 |
+
|
189 |
+
self.imgsz = imgsz
|
190 |
+
self.files = images + videos
|
191 |
+
self.nf = ni + nv # number of files
|
192 |
+
self.video_flag = [False] * ni + [True] * nv
|
193 |
+
self.mode = 'image'
|
194 |
+
self.vid_stride = vid_stride # video frame-rate stride
|
195 |
+
self.bs = 1
|
196 |
+
if any(videos):
|
197 |
+
self.orientation = None # rotation degrees
|
198 |
+
self._new_video(videos[0]) # new video
|
199 |
+
else:
|
200 |
+
self.cap = None
|
201 |
+
if self.nf == 0:
|
202 |
+
raise FileNotFoundError(f'No images or videos found in {p}. '
|
203 |
+
f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}')
|
204 |
+
|
205 |
+
def __iter__(self):
|
206 |
+
"""Returns an iterator object for VideoStream or ImageFolder."""
|
207 |
+
self.count = 0
|
208 |
+
return self
|
209 |
+
|
210 |
+
def __next__(self):
|
211 |
+
"""Return next image, path and metadata from dataset."""
|
212 |
+
if self.count == self.nf:
|
213 |
+
raise StopIteration
|
214 |
+
path = self.files[self.count]
|
215 |
+
|
216 |
+
if self.video_flag[self.count]:
|
217 |
+
# Read video
|
218 |
+
self.mode = 'video'
|
219 |
+
for _ in range(self.vid_stride):
|
220 |
+
self.cap.grab()
|
221 |
+
success, im0 = self.cap.retrieve()
|
222 |
+
while not success:
|
223 |
+
self.count += 1
|
224 |
+
self.cap.release()
|
225 |
+
if self.count == self.nf: # last video
|
226 |
+
raise StopIteration
|
227 |
+
path = self.files[self.count]
|
228 |
+
self._new_video(path)
|
229 |
+
success, im0 = self.cap.read()
|
230 |
+
|
231 |
+
self.frame += 1
|
232 |
+
# im0 = self._cv2_rotate(im0) # for use if cv2 autorotation is False
|
233 |
+
s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
|
234 |
+
|
235 |
+
else:
|
236 |
+
# Read image
|
237 |
+
self.count += 1
|
238 |
+
im0 = cv2.imread(path) # BGR
|
239 |
+
if im0 is None:
|
240 |
+
raise FileNotFoundError(f'Image Not Found {path}')
|
241 |
+
s = f'image {self.count}/{self.nf} {path}: '
|
242 |
+
|
243 |
+
return [path], [im0], self.cap, s
|
244 |
+
|
245 |
+
def _new_video(self, path):
|
246 |
+
"""Create a new video capture object."""
|
247 |
+
self.frame = 0
|
248 |
+
self.cap = cv2.VideoCapture(path)
|
249 |
+
self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
|
250 |
+
if hasattr(cv2, 'CAP_PROP_ORIENTATION_META'): # cv2<4.6.0 compatibility
|
251 |
+
self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META)) # rotation degrees
|
252 |
+
# Disable auto-orientation due to known issues in https://github.com/ultralytics/yolov5/issues/8493
|
253 |
+
# self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0)
|
254 |
+
|
255 |
+
def _cv2_rotate(self, im):
|
256 |
+
"""Rotate a cv2 video manually."""
|
257 |
+
if self.orientation == 0:
|
258 |
+
return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
|
259 |
+
elif self.orientation == 180:
|
260 |
+
return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
|
261 |
+
elif self.orientation == 90:
|
262 |
+
return cv2.rotate(im, cv2.ROTATE_180)
|
263 |
+
return im
|
264 |
+
|
265 |
+
def __len__(self):
|
266 |
+
"""Returns the number of files in the object."""
|
267 |
+
return self.nf # number of files
|
268 |
+
|
269 |
+
|
270 |
+
class LoadPilAndNumpy:
|
271 |
+
|
272 |
+
def __init__(self, im0, imgsz=640):
|
273 |
+
"""Initialize PIL and Numpy Dataloader."""
|
274 |
+
if not isinstance(im0, list):
|
275 |
+
im0 = [im0]
|
276 |
+
self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)]
|
277 |
+
self.im0 = [self._single_check(im) for im in im0]
|
278 |
+
self.imgsz = imgsz
|
279 |
+
self.mode = 'image'
|
280 |
+
# Generate fake paths
|
281 |
+
self.bs = len(self.im0)
|
282 |
+
|
283 |
+
@staticmethod
|
284 |
+
def _single_check(im):
|
285 |
+
"""Validate and format an image to numpy array."""
|
286 |
+
assert isinstance(im, (Image.Image, np.ndarray)), f'Expected PIL/np.ndarray image type, but got {type(im)}'
|
287 |
+
if isinstance(im, Image.Image):
|
288 |
+
if im.mode != 'RGB':
|
289 |
+
im = im.convert('RGB')
|
290 |
+
im = np.asarray(im)[:, :, ::-1]
|
291 |
+
im = np.ascontiguousarray(im) # contiguous
|
292 |
+
return im
|
293 |
+
|
294 |
+
def __len__(self):
|
295 |
+
"""Returns the length of the 'im0' attribute."""
|
296 |
+
return len(self.im0)
|
297 |
+
|
298 |
+
def __next__(self):
|
299 |
+
"""Returns batch paths, images, processed images, None, ''."""
|
300 |
+
if self.count == 1: # loop only once as it's batch inference
|
301 |
+
raise StopIteration
|
302 |
+
self.count += 1
|
303 |
+
return self.paths, self.im0, None, ''
|
304 |
+
|
305 |
+
def __iter__(self):
|
306 |
+
"""Enables iteration for class LoadPilAndNumpy."""
|
307 |
+
self.count = 0
|
308 |
+
return self
|
309 |
+
|
310 |
+
|
311 |
+
class LoadTensor:
|
312 |
+
|
313 |
+
def __init__(self, im0) -> None:
|
314 |
+
self.im0 = self._single_check(im0)
|
315 |
+
self.bs = self.im0.shape[0]
|
316 |
+
self.mode = 'image'
|
317 |
+
self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)]
|
318 |
+
|
319 |
+
@staticmethod
|
320 |
+
def _single_check(im, stride=32):
|
321 |
+
"""Validate and format an image to torch.Tensor."""
|
322 |
+
s = f'WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) ' \
|
323 |
+
f'divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible.'
|
324 |
+
if len(im.shape) != 4:
|
325 |
+
if len(im.shape) != 3:
|
326 |
+
raise ValueError(s)
|
327 |
+
LOGGER.warning(s)
|
328 |
+
im = im.unsqueeze(0)
|
329 |
+
if im.shape[2] % stride or im.shape[3] % stride:
|
330 |
+
raise ValueError(s)
|
331 |
+
if im.max() > 1.0:
|
332 |
+
LOGGER.warning(f'WARNING ⚠️ torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. '
|
333 |
+
f'Dividing input by 255.')
|
334 |
+
im = im.float() / 255.0
|
335 |
+
|
336 |
+
return im
|
337 |
+
|
338 |
+
def __iter__(self):
|
339 |
+
"""Returns an iterator object."""
|
340 |
+
self.count = 0
|
341 |
+
return self
|
342 |
+
|
343 |
+
def __next__(self):
|
344 |
+
"""Return next item in the iterator."""
|
345 |
+
if self.count == 1:
|
346 |
+
raise StopIteration
|
347 |
+
self.count += 1
|
348 |
+
return self.paths, self.im0, None, ''
|
349 |
+
|
350 |
+
def __len__(self):
|
351 |
+
"""Returns the batch size."""
|
352 |
+
return self.bs
|
353 |
+
|
354 |
+
|
355 |
+
def autocast_list(source):
|
356 |
+
"""
|
357 |
+
Merges a list of source of different types into a list of numpy arrays or PIL images
|
358 |
+
"""
|
359 |
+
files = []
|
360 |
+
for im in source:
|
361 |
+
if isinstance(im, (str, Path)): # filename or uri
|
362 |
+
files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im))
|
363 |
+
elif isinstance(im, (Image.Image, np.ndarray)): # PIL or np Image
|
364 |
+
files.append(im)
|
365 |
+
else:
|
366 |
+
raise TypeError(f'type {type(im).__name__} is not a supported Ultralytics prediction source type. \n'
|
367 |
+
f'See https://docs.ultralytics.com/modes/predict for supported source types.')
|
368 |
+
|
369 |
+
return files
|
370 |
+
|
371 |
+
|
372 |
+
LOADERS = [LoadStreams, LoadPilAndNumpy, LoadImages, LoadScreenshots]
|
373 |
+
|
374 |
+
|
375 |
+
def get_best_youtube_url(url, use_pafy=True):
|
376 |
+
"""
|
377 |
+
Retrieves the URL of the best quality MP4 video stream from a given YouTube video.
|
378 |
+
|
379 |
+
This function uses the pafy or yt_dlp library to extract the video info from YouTube. It then finds the highest
|
380 |
+
quality MP4 format that has video codec but no audio codec, and returns the URL of this video stream.
|
381 |
+
|
382 |
+
Args:
|
383 |
+
url (str): The URL of the YouTube video.
|
384 |
+
use_pafy (bool): Use the pafy package, default=True, otherwise use yt_dlp package.
|
385 |
+
|
386 |
+
Returns:
|
387 |
+
(str): The URL of the best quality MP4 video stream, or None if no suitable stream is found.
|
388 |
+
"""
|
389 |
+
if use_pafy:
|
390 |
+
check_requirements(('pafy', 'youtube_dl==2020.12.2'))
|
391 |
+
import pafy # noqa
|
392 |
+
return pafy.new(url).getbest(preftype='mp4').url
|
393 |
+
else:
|
394 |
+
check_requirements('yt-dlp')
|
395 |
+
import yt_dlp
|
396 |
+
with yt_dlp.YoutubeDL({'quiet': True}) as ydl:
|
397 |
+
info_dict = ydl.extract_info(url, download=False) # extract info
|
398 |
+
for f in info_dict.get('formats', None):
|
399 |
+
if f['vcodec'] != 'none' and f['acodec'] == 'none' and f['ext'] == 'mp4':
|
400 |
+
return f.get('url', None)
|
401 |
+
|
402 |
+
|
403 |
+
if __name__ == '__main__':
|
404 |
+
img = cv2.imread(str(ROOT / 'assets/bus.jpg'))
|
405 |
+
dataset = LoadPilAndNumpy(im0=img)
|
406 |
+
for d in dataset:
|
407 |
+
print(d[0])
|
ultralytics/data/scripts/download_weights.sh
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
3 |
+
# Download latest models from https://github.com/ultralytics/assets/releases
|
4 |
+
# Example usage: bash ultralytics/data/scripts/download_weights.sh
|
5 |
+
# parent
|
6 |
+
# └── weights
|
7 |
+
# ├── yolov8n.pt ← downloads here
|
8 |
+
# ├── yolov8s.pt
|
9 |
+
# └── ...
|
10 |
+
|
11 |
+
python - <<EOF
|
12 |
+
from ultralytics.utils.downloads import attempt_download_asset
|
13 |
+
|
14 |
+
assets = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '-cls', '-seg', '-pose')]
|
15 |
+
for x in assets:
|
16 |
+
attempt_download_asset(f'weights/{x}')
|
17 |
+
|
18 |
+
EOF
|
ultralytics/data/scripts/get_coco.sh
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
3 |
+
# Download COCO 2017 dataset http://cocodataset.org
|
4 |
+
# Example usage: bash data/scripts/get_coco.sh
|
5 |
+
# parent
|
6 |
+
# ├── ultralytics
|
7 |
+
# └── datasets
|
8 |
+
# └── coco ← downloads here
|
9 |
+
|
10 |
+
# Arguments (optional) Usage: bash data/scripts/get_coco.sh --train --val --test --segments
|
11 |
+
if [ "$#" -gt 0 ]; then
|
12 |
+
for opt in "$@"; do
|
13 |
+
case "${opt}" in
|
14 |
+
--train) train=true ;;
|
15 |
+
--val) val=true ;;
|
16 |
+
--test) test=true ;;
|
17 |
+
--segments) segments=true ;;
|
18 |
+
--sama) sama=true ;;
|
19 |
+
esac
|
20 |
+
done
|
21 |
+
else
|
22 |
+
train=true
|
23 |
+
val=true
|
24 |
+
test=false
|
25 |
+
segments=false
|
26 |
+
sama=false
|
27 |
+
fi
|
28 |
+
|
29 |
+
# Download/unzip labels
|
30 |
+
d='../datasets' # unzip directory
|
31 |
+
url=https://github.com/ultralytics/yolov5/releases/download/v1.0/
|
32 |
+
if [ "$segments" == "true" ]; then
|
33 |
+
f='coco2017labels-segments.zip' # 169 MB
|
34 |
+
elif [ "$sama" == "true" ]; then
|
35 |
+
f='coco2017labels-segments-sama.zip' # 199 MB https://www.sama.com/sama-coco-dataset/
|
36 |
+
else
|
37 |
+
f='coco2017labels.zip' # 46 MB
|
38 |
+
fi
|
39 |
+
echo 'Downloading' $url$f ' ...'
|
40 |
+
curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
|
41 |
+
|
42 |
+
# Download/unzip images
|
43 |
+
d='../datasets/coco/images' # unzip directory
|
44 |
+
url=http://images.cocodataset.org/zips/
|
45 |
+
if [ "$train" == "true" ]; then
|
46 |
+
f='train2017.zip' # 19G, 118k images
|
47 |
+
echo 'Downloading' $url$f '...'
|
48 |
+
curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
|
49 |
+
fi
|
50 |
+
if [ "$val" == "true" ]; then
|
51 |
+
f='val2017.zip' # 1G, 5k images
|
52 |
+
echo 'Downloading' $url$f '...'
|
53 |
+
curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
|
54 |
+
fi
|
55 |
+
if [ "$test" == "true" ]; then
|
56 |
+
f='test2017.zip' # 7G, 41k images (optional)
|
57 |
+
echo 'Downloading' $url$f '...'
|
58 |
+
curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
|
59 |
+
fi
|
60 |
+
wait # finish background tasks
|
ultralytics/data/scripts/get_coco128.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
3 |
+
# Download COCO128 dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017)
|
4 |
+
# Example usage: bash data/scripts/get_coco128.sh
|
5 |
+
# parent
|
6 |
+
# ├── ultralytics
|
7 |
+
# └── datasets
|
8 |
+
# └── coco128 ← downloads here
|
9 |
+
|
10 |
+
# Download/unzip images and labels
|
11 |
+
d='../datasets' # unzip directory
|
12 |
+
url=https://github.com/ultralytics/yolov5/releases/download/v1.0/
|
13 |
+
f='coco128.zip' # or 'coco128-segments.zip', 68 MB
|
14 |
+
echo 'Downloading' $url$f ' ...'
|
15 |
+
curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
|
16 |
+
|
17 |
+
wait # finish background tasks
|
ultralytics/data/scripts/get_imagenet.sh
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
3 |
+
# Download ILSVRC2012 ImageNet dataset https://image-net.org
|
4 |
+
# Example usage: bash data/scripts/get_imagenet.sh
|
5 |
+
# parent
|
6 |
+
# ├── ultralytics
|
7 |
+
# └── datasets
|
8 |
+
# └── imagenet ← downloads here
|
9 |
+
|
10 |
+
# Arguments (optional) Usage: bash data/scripts/get_imagenet.sh --train --val
|
11 |
+
if [ "$#" -gt 0 ]; then
|
12 |
+
for opt in "$@"; do
|
13 |
+
case "${opt}" in
|
14 |
+
--train) train=true ;;
|
15 |
+
--val) val=true ;;
|
16 |
+
esac
|
17 |
+
done
|
18 |
+
else
|
19 |
+
train=true
|
20 |
+
val=true
|
21 |
+
fi
|
22 |
+
|
23 |
+
# Make dir
|
24 |
+
d='../datasets/imagenet' # unzip directory
|
25 |
+
mkdir -p $d && cd $d
|
26 |
+
|
27 |
+
# Download/unzip train
|
28 |
+
if [ "$train" == "true" ]; then
|
29 |
+
wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar # download 138G, 1281167 images
|
30 |
+
mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train
|
31 |
+
tar -xf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar
|
32 |
+
find . -name "*.tar" | while read NAME; do
|
33 |
+
mkdir -p "${NAME%.tar}"
|
34 |
+
tar -xf "${NAME}" -C "${NAME%.tar}"
|
35 |
+
rm -f "${NAME}"
|
36 |
+
done
|
37 |
+
cd ..
|
38 |
+
fi
|
39 |
+
|
40 |
+
# Download/unzip val
|
41 |
+
if [ "$val" == "true" ]; then
|
42 |
+
wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar # download 6.3G, 50000 images
|
43 |
+
mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xf ILSVRC2012_img_val.tar
|
44 |
+
wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash # move into subdirs
|
45 |
+
fi
|
46 |
+
|
47 |
+
# Delete corrupted image (optional: PNG under JPEG name that may cause dataloaders to fail)
|
48 |
+
# rm train/n04266014/n04266014_10835.JPEG
|
49 |
+
|
50 |
+
# TFRecords (optional)
|
51 |
+
# wget https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/imagenet_lsvrc_2015_synsets.txt
|
ultralytics/data/utils.py
ADDED
@@ -0,0 +1,557 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
import contextlib
|
4 |
+
import hashlib
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
import random
|
8 |
+
import subprocess
|
9 |
+
import time
|
10 |
+
import zipfile
|
11 |
+
from multiprocessing.pool import ThreadPool
|
12 |
+
from pathlib import Path
|
13 |
+
from tarfile import is_tarfile
|
14 |
+
|
15 |
+
import cv2
|
16 |
+
import numpy as np
|
17 |
+
from PIL import ExifTags, Image, ImageOps
|
18 |
+
from tqdm import tqdm
|
19 |
+
|
20 |
+
from ultralytics.nn.autobackend import check_class_names
|
21 |
+
from ultralytics.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, SETTINGS_YAML, clean_url, colorstr, emojis,
|
22 |
+
yaml_load)
|
23 |
+
from ultralytics.utils.checks import check_file, check_font, is_ascii
|
24 |
+
from ultralytics.utils.downloads import download, safe_download, unzip_file
|
25 |
+
from ultralytics.utils.ops import segments2boxes
|
26 |
+
|
27 |
+
HELP_URL = 'See https://docs.ultralytics.com/yolov5/tutorials/train_custom_data'
|
28 |
+
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # image suffixes
|
29 |
+
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv', 'webm' # video suffixes
|
30 |
+
PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders
|
31 |
+
IMAGENET_MEAN = 0.485, 0.456, 0.406 # RGB mean
|
32 |
+
IMAGENET_STD = 0.229, 0.224, 0.225 # RGB standard deviation
|
33 |
+
|
34 |
+
# Get orientation exif tag
|
35 |
+
for orientation in ExifTags.TAGS.keys():
|
36 |
+
if ExifTags.TAGS[orientation] == 'Orientation':
|
37 |
+
break
|
38 |
+
|
39 |
+
|
40 |
+
def img2label_paths(img_paths):
|
41 |
+
"""Define label paths as a function of image paths."""
|
42 |
+
sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}' # /images/, /labels/ substrings
|
43 |
+
return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
|
44 |
+
|
45 |
+
|
46 |
+
def get_hash(paths):
|
47 |
+
"""Returns a single hash value of a list of paths (files or dirs)."""
|
48 |
+
size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
|
49 |
+
h = hashlib.sha256(str(size).encode()) # hash sizes
|
50 |
+
h.update(''.join(paths).encode()) # hash paths
|
51 |
+
return h.hexdigest() # return hash
|
52 |
+
|
53 |
+
|
54 |
+
def exif_size(img):
|
55 |
+
"""Returns exif-corrected PIL size."""
|
56 |
+
s = img.size # (width, height)
|
57 |
+
with contextlib.suppress(Exception):
|
58 |
+
rotation = dict(img._getexif().items())[orientation]
|
59 |
+
if rotation in [6, 8]: # rotation 270 or 90
|
60 |
+
s = (s[1], s[0])
|
61 |
+
return s
|
62 |
+
|
63 |
+
|
64 |
+
def verify_image_label(args):
|
65 |
+
"""Verify one image-label pair."""
|
66 |
+
im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
|
67 |
+
# Number (missing, found, empty, corrupt), message, segments, keypoints
|
68 |
+
nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, '', [], None
|
69 |
+
try:
|
70 |
+
# Verify images
|
71 |
+
im = Image.open(im_file)
|
72 |
+
im.verify() # PIL verify
|
73 |
+
shape = exif_size(im) # image size
|
74 |
+
shape = (shape[1], shape[0]) # hw
|
75 |
+
assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
|
76 |
+
assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
|
77 |
+
if im.format.lower() in ('jpg', 'jpeg'):
|
78 |
+
with open(im_file, 'rb') as f:
|
79 |
+
f.seek(-2, 2)
|
80 |
+
if f.read() != b'\xff\xd9': # corrupt JPEG
|
81 |
+
ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
|
82 |
+
msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
|
83 |
+
|
84 |
+
# Verify labels
|
85 |
+
if os.path.isfile(lb_file):
|
86 |
+
nf = 1 # label found
|
87 |
+
with open(lb_file) as f:
|
88 |
+
lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
|
89 |
+
if any(len(x) > 6 for x in lb) and (not keypoint): # is segment
|
90 |
+
classes = np.array([x[0] for x in lb], dtype=np.float32)
|
91 |
+
segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...)
|
92 |
+
lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
|
93 |
+
lb = np.array(lb, dtype=np.float32)
|
94 |
+
nl = len(lb)
|
95 |
+
if nl:
|
96 |
+
if keypoint:
|
97 |
+
assert lb.shape[1] == (5 + nkpt * ndim), f'labels require {(5 + nkpt * ndim)} columns each'
|
98 |
+
assert (lb[:, 5::ndim] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
|
99 |
+
assert (lb[:, 6::ndim] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
|
100 |
+
else:
|
101 |
+
assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
|
102 |
+
assert (lb[:, 1:] <= 1).all(), \
|
103 |
+
f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
|
104 |
+
assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
|
105 |
+
# All labels
|
106 |
+
max_cls = int(lb[:, 0].max()) # max label count
|
107 |
+
assert max_cls <= num_cls, \
|
108 |
+
f'Label class {max_cls} exceeds dataset class count {num_cls}. ' \
|
109 |
+
f'Possible class labels are 0-{num_cls - 1}'
|
110 |
+
_, i = np.unique(lb, axis=0, return_index=True)
|
111 |
+
if len(i) < nl: # duplicate row check
|
112 |
+
lb = lb[i] # remove duplicates
|
113 |
+
if segments:
|
114 |
+
segments = [segments[x] for x in i]
|
115 |
+
msg = f'{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed'
|
116 |
+
else:
|
117 |
+
ne = 1 # label empty
|
118 |
+
lb = np.zeros((0, (5 + nkpt * ndim)), dtype=np.float32) if keypoint else np.zeros(
|
119 |
+
(0, 5), dtype=np.float32)
|
120 |
+
else:
|
121 |
+
nm = 1 # label missing
|
122 |
+
lb = np.zeros((0, (5 + nkpt * ndim)), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
|
123 |
+
if keypoint:
|
124 |
+
keypoints = lb[:, 5:].reshape(-1, nkpt, ndim)
|
125 |
+
if ndim == 2:
|
126 |
+
kpt_mask = np.ones(keypoints.shape[:2], dtype=np.float32)
|
127 |
+
kpt_mask = np.where(keypoints[..., 0] < 0, 0.0, kpt_mask)
|
128 |
+
kpt_mask = np.where(keypoints[..., 1] < 0, 0.0, kpt_mask)
|
129 |
+
keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1) # (nl, nkpt, 3)
|
130 |
+
lb = lb[:, :5]
|
131 |
+
return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
|
132 |
+
except Exception as e:
|
133 |
+
nc = 1
|
134 |
+
msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
|
135 |
+
return [None, None, None, None, None, nm, nf, ne, nc, msg]
|
136 |
+
|
137 |
+
|
138 |
+
def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
|
139 |
+
"""
|
140 |
+
Args:
|
141 |
+
imgsz (tuple): The image size.
|
142 |
+
polygons (list[np.ndarray]): [N, M], N is the number of polygons, M is the number of points(Be divided by 2).
|
143 |
+
color (int): color
|
144 |
+
downsample_ratio (int): downsample ratio
|
145 |
+
"""
|
146 |
+
mask = np.zeros(imgsz, dtype=np.uint8)
|
147 |
+
polygons = np.asarray(polygons)
|
148 |
+
polygons = polygons.astype(np.int32)
|
149 |
+
shape = polygons.shape
|
150 |
+
polygons = polygons.reshape(shape[0], -1, 2)
|
151 |
+
cv2.fillPoly(mask, polygons, color=color)
|
152 |
+
nh, nw = (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio)
|
153 |
+
# NOTE: fillPoly firstly then resize is trying the keep the same way
|
154 |
+
# of loss calculation when mask-ratio=1.
|
155 |
+
mask = cv2.resize(mask, (nw, nh))
|
156 |
+
return mask
|
157 |
+
|
158 |
+
|
159 |
+
def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
|
160 |
+
"""
|
161 |
+
Args:
|
162 |
+
imgsz (tuple): The image size.
|
163 |
+
polygons (list[np.ndarray]): each polygon is [N, M], N is number of polygons, M is number of points (M % 2 = 0)
|
164 |
+
color (int): color
|
165 |
+
downsample_ratio (int): downsample ratio
|
166 |
+
"""
|
167 |
+
masks = []
|
168 |
+
for si in range(len(polygons)):
|
169 |
+
mask = polygon2mask(imgsz, [polygons[si].reshape(-1)], color, downsample_ratio)
|
170 |
+
masks.append(mask)
|
171 |
+
return np.array(masks)
|
172 |
+
|
173 |
+
|
174 |
+
def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
|
175 |
+
"""Return a (640, 640) overlap mask."""
|
176 |
+
masks = np.zeros((imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
|
177 |
+
dtype=np.int32 if len(segments) > 255 else np.uint8)
|
178 |
+
areas = []
|
179 |
+
ms = []
|
180 |
+
for si in range(len(segments)):
|
181 |
+
mask = polygon2mask(imgsz, [segments[si].reshape(-1)], downsample_ratio=downsample_ratio, color=1)
|
182 |
+
ms.append(mask)
|
183 |
+
areas.append(mask.sum())
|
184 |
+
areas = np.asarray(areas)
|
185 |
+
index = np.argsort(-areas)
|
186 |
+
ms = np.array(ms)[index]
|
187 |
+
for i in range(len(segments)):
|
188 |
+
mask = ms[i] * (i + 1)
|
189 |
+
masks = masks + mask
|
190 |
+
masks = np.clip(masks, a_min=0, a_max=i + 1)
|
191 |
+
return masks, index
|
192 |
+
|
193 |
+
|
194 |
+
def check_det_dataset(dataset, autodownload=True):
|
195 |
+
"""Download, check and/or unzip dataset if not found locally."""
|
196 |
+
data = check_file(dataset)
|
197 |
+
|
198 |
+
# Download (optional)
|
199 |
+
extract_dir = ''
|
200 |
+
if isinstance(data, (str, Path)) and (zipfile.is_zipfile(data) or is_tarfile(data)):
|
201 |
+
new_dir = safe_download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False)
|
202 |
+
data = next((DATASETS_DIR / new_dir).rglob('*.yaml'))
|
203 |
+
extract_dir, autodownload = data.parent, False
|
204 |
+
|
205 |
+
# Read yaml (optional)
|
206 |
+
if isinstance(data, (str, Path)):
|
207 |
+
data = yaml_load(data, append_filename=True) # dictionary
|
208 |
+
|
209 |
+
# Checks
|
210 |
+
for k in 'train', 'val':
|
211 |
+
if k not in data:
|
212 |
+
raise SyntaxError(
|
213 |
+
emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs."))
|
214 |
+
if 'names' not in data and 'nc' not in data:
|
215 |
+
raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs."))
|
216 |
+
if 'names' in data and 'nc' in data and len(data['names']) != data['nc']:
|
217 |
+
raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match."))
|
218 |
+
if 'names' not in data:
|
219 |
+
data['names'] = [f'class_{i}' for i in range(data['nc'])]
|
220 |
+
else:
|
221 |
+
data['nc'] = len(data['names'])
|
222 |
+
|
223 |
+
data['names'] = check_class_names(data['names'])
|
224 |
+
|
225 |
+
# Resolve paths
|
226 |
+
path = Path(extract_dir or data.get('path') or Path(data.get('yaml_file', '')).parent) # dataset root
|
227 |
+
|
228 |
+
if not path.is_absolute():
|
229 |
+
path = (DATASETS_DIR / path).resolve()
|
230 |
+
data['path'] = path # download scripts
|
231 |
+
for k in 'train', 'val', 'test':
|
232 |
+
if data.get(k): # prepend path
|
233 |
+
if isinstance(data[k], str):
|
234 |
+
x = (path / data[k]).resolve()
|
235 |
+
if not x.exists() and data[k].startswith('../'):
|
236 |
+
x = (path / data[k][3:]).resolve()
|
237 |
+
data[k] = str(x)
|
238 |
+
else:
|
239 |
+
data[k] = [str((path / x).resolve()) for x in data[k]]
|
240 |
+
|
241 |
+
# Parse yaml
|
242 |
+
train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
|
243 |
+
if val:
|
244 |
+
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
|
245 |
+
if not all(x.exists() for x in val):
|
246 |
+
name = clean_url(dataset) # dataset name with URL auth stripped
|
247 |
+
m = f"\nDataset '{name}' images not found ⚠️, missing path '{[x for x in val if not x.exists()][0]}'"
|
248 |
+
if s and autodownload:
|
249 |
+
LOGGER.warning(m)
|
250 |
+
else:
|
251 |
+
m += f"\nNote dataset download directory is '{DATASETS_DIR}'. You can update this in '{SETTINGS_YAML}'"
|
252 |
+
raise FileNotFoundError(m)
|
253 |
+
t = time.time()
|
254 |
+
if s.startswith('http') and s.endswith('.zip'): # URL
|
255 |
+
safe_download(url=s, dir=DATASETS_DIR, delete=True)
|
256 |
+
r = None # success
|
257 |
+
elif s.startswith('bash '): # bash script
|
258 |
+
LOGGER.info(f'Running {s} ...')
|
259 |
+
r = os.system(s)
|
260 |
+
else: # python script
|
261 |
+
r = exec(s, {'yaml': data}) # return None
|
262 |
+
dt = f'({round(time.time() - t, 1)}s)'
|
263 |
+
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f'failure {dt} ❌'
|
264 |
+
LOGGER.info(f'Dataset download {s}\n')
|
265 |
+
check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf') # download fonts
|
266 |
+
|
267 |
+
return data # dictionary
|
268 |
+
|
269 |
+
|
270 |
+
def check_cls_dataset(dataset: str, split=''):
|
271 |
+
"""
|
272 |
+
Checks a classification dataset such as Imagenet.
|
273 |
+
|
274 |
+
This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information.
|
275 |
+
If the dataset is not found locally, it attempts to download the dataset from the internet and save it locally.
|
276 |
+
|
277 |
+
Args:
|
278 |
+
dataset (str): The name of the dataset.
|
279 |
+
split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''.
|
280 |
+
|
281 |
+
Returns:
|
282 |
+
(dict): A dictionary containing the following keys:
|
283 |
+
- 'train' (Path): The directory path containing the training set of the dataset.
|
284 |
+
- 'val' (Path): The directory path containing the validation set of the dataset.
|
285 |
+
- 'test' (Path): The directory path containing the test set of the dataset.
|
286 |
+
- 'nc' (int): The number of classes in the dataset.
|
287 |
+
- 'names' (dict): A dictionary of class names in the dataset.
|
288 |
+
|
289 |
+
Raises:
|
290 |
+
FileNotFoundError: If the specified dataset is not found and cannot be downloaded.
|
291 |
+
"""
|
292 |
+
|
293 |
+
dataset = Path(dataset)
|
294 |
+
data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
|
295 |
+
if not data_dir.is_dir():
|
296 |
+
LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
|
297 |
+
t = time.time()
|
298 |
+
if str(dataset) == 'imagenet':
|
299 |
+
subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
|
300 |
+
else:
|
301 |
+
url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
|
302 |
+
download(url, dir=data_dir.parent)
|
303 |
+
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
|
304 |
+
LOGGER.info(s)
|
305 |
+
train_set = data_dir / 'train'
|
306 |
+
val_set = data_dir / 'val' if (data_dir / 'val').exists() else None # data/test or data/val
|
307 |
+
test_set = data_dir / 'test' if (data_dir / 'test').exists() else None # data/val or data/test
|
308 |
+
if split == 'val' and not val_set:
|
309 |
+
LOGGER.info("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
|
310 |
+
elif split == 'test' and not test_set:
|
311 |
+
LOGGER.info("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")
|
312 |
+
|
313 |
+
nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
|
314 |
+
names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list
|
315 |
+
names = dict(enumerate(sorted(names)))
|
316 |
+
return {'train': train_set, 'val': val_set or test_set, 'test': test_set or val_set, 'nc': nc, 'names': names}
|
317 |
+
|
318 |
+
|
319 |
+
class HUBDatasetStats():
|
320 |
+
"""
|
321 |
+
A class for generating HUB dataset JSON and `-hub` dataset directory.
|
322 |
+
|
323 |
+
Args:
|
324 |
+
path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip). Default is 'coco128.yaml'.
|
325 |
+
task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Default is 'detect'.
|
326 |
+
autodownload (bool): Attempt to download dataset if not found locally. Default is False.
|
327 |
+
|
328 |
+
Usage
|
329 |
+
from ultralytics.data.utils import HUBDatasetStats
|
330 |
+
stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8.zip', task='detect') # detect dataset
|
331 |
+
stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8-seg.zip', task='segment') # segment dataset
|
332 |
+
stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8-pose.zip', task='pose') # pose dataset
|
333 |
+
stats.get_json(save=False)
|
334 |
+
stats.process_images()
|
335 |
+
"""
|
336 |
+
|
337 |
+
def __init__(self, path='coco128.yaml', task='detect', autodownload=False):
|
338 |
+
"""Initialize class."""
|
339 |
+
LOGGER.info(f'Starting HUB dataset checks for {path}....')
|
340 |
+
zipped, data_dir, yaml_path = self._unzip(Path(path))
|
341 |
+
try:
|
342 |
+
# data = yaml_load(check_yaml(yaml_path)) # data dict
|
343 |
+
data = check_det_dataset(yaml_path, autodownload) # data dict
|
344 |
+
if zipped:
|
345 |
+
data['path'] = data_dir
|
346 |
+
except Exception as e:
|
347 |
+
raise Exception('error/HUB/dataset_stats/yaml_load') from e
|
348 |
+
|
349 |
+
self.hub_dir = Path(str(data['path']) + '-hub')
|
350 |
+
self.im_dir = self.hub_dir / 'images'
|
351 |
+
self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images
|
352 |
+
self.stats = {'nc': len(data['names']), 'names': list(data['names'].values())} # statistics dictionary
|
353 |
+
self.data = data
|
354 |
+
self.task = task # detect, segment, pose, classify
|
355 |
+
|
356 |
+
@staticmethod
|
357 |
+
def _find_yaml(dir):
|
358 |
+
"""Return data.yaml file."""
|
359 |
+
files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml')) # try root level first and then recursive
|
360 |
+
assert files, f'No *.yaml file found in {dir}'
|
361 |
+
if len(files) > 1:
|
362 |
+
files = [f for f in files if f.stem == dir.stem] # prefer *.yaml files that match dir name
|
363 |
+
assert files, f'Multiple *.yaml files found in {dir}, only 1 *.yaml file allowed'
|
364 |
+
assert len(files) == 1, f'Multiple *.yaml files found: {files}, only 1 *.yaml file allowed in {dir}'
|
365 |
+
return files[0]
|
366 |
+
|
367 |
+
def _unzip(self, path):
|
368 |
+
"""Unzip data.zip."""
|
369 |
+
if not str(path).endswith('.zip'): # path is data.yaml
|
370 |
+
return False, None, path
|
371 |
+
unzip_dir = unzip_file(path, path=path.parent)
|
372 |
+
assert unzip_dir.is_dir(), f'Error unzipping {path}, {unzip_dir} not found. ' \
|
373 |
+
f'path/to/abc.zip MUST unzip to path/to/abc/'
|
374 |
+
return True, str(unzip_dir), self._find_yaml(unzip_dir) # zipped, data_dir, yaml_path
|
375 |
+
|
376 |
+
def _hub_ops(self, f):
|
377 |
+
"""Saves a compressed image for HUB previews."""
|
378 |
+
compress_one_image(f, self.im_dir / Path(f).name) # save to dataset-hub
|
379 |
+
|
380 |
+
def get_json(self, save=False, verbose=False):
|
381 |
+
"""Return dataset JSON for Ultralytics HUB."""
|
382 |
+
from ultralytics.data import YOLODataset # ClassificationDataset
|
383 |
+
|
384 |
+
def _round(labels):
|
385 |
+
"""Update labels to integer class and 4 decimal place floats."""
|
386 |
+
if self.task == 'detect':
|
387 |
+
coordinates = labels['bboxes']
|
388 |
+
elif self.task == 'segment':
|
389 |
+
coordinates = [x.flatten() for x in labels['segments']]
|
390 |
+
elif self.task == 'pose':
|
391 |
+
n = labels['keypoints'].shape[0]
|
392 |
+
coordinates = np.concatenate((labels['bboxes'], labels['keypoints'].reshape(n, -1)), 1)
|
393 |
+
else:
|
394 |
+
raise ValueError('Undefined dataset task.')
|
395 |
+
zipped = zip(labels['cls'], coordinates)
|
396 |
+
return [[int(c), *(round(float(x), 4) for x in points)] for c, points in zipped]
|
397 |
+
|
398 |
+
for split in 'train', 'val', 'test':
|
399 |
+
if self.data.get(split) is None:
|
400 |
+
self.stats[split] = None # i.e. no test set
|
401 |
+
continue
|
402 |
+
|
403 |
+
dataset = YOLODataset(img_path=self.data[split],
|
404 |
+
data=self.data,
|
405 |
+
use_segments=self.task == 'segment',
|
406 |
+
use_keypoints=self.task == 'pose')
|
407 |
+
x = np.array([
|
408 |
+
np.bincount(label['cls'].astype(int).flatten(), minlength=self.data['nc'])
|
409 |
+
for label in tqdm(dataset.labels, total=len(dataset), desc='Statistics')]) # shape(128x80)
|
410 |
+
self.stats[split] = {
|
411 |
+
'instance_stats': {
|
412 |
+
'total': int(x.sum()),
|
413 |
+
'per_class': x.sum(0).tolist()},
|
414 |
+
'image_stats': {
|
415 |
+
'total': len(dataset),
|
416 |
+
'unlabelled': int(np.all(x == 0, 1).sum()),
|
417 |
+
'per_class': (x > 0).sum(0).tolist()},
|
418 |
+
'labels': [{
|
419 |
+
Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)]}
|
420 |
+
|
421 |
+
# Save, print and return
|
422 |
+
if save:
|
423 |
+
stats_path = self.hub_dir / 'stats.json'
|
424 |
+
LOGGER.info(f'Saving {stats_path.resolve()}...')
|
425 |
+
with open(stats_path, 'w') as f:
|
426 |
+
json.dump(self.stats, f) # save stats.json
|
427 |
+
if verbose:
|
428 |
+
LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
|
429 |
+
return self.stats
|
430 |
+
|
431 |
+
def process_images(self):
|
432 |
+
"""Compress images for Ultralytics HUB."""
|
433 |
+
from ultralytics.data import YOLODataset # ClassificationDataset
|
434 |
+
|
435 |
+
for split in 'train', 'val', 'test':
|
436 |
+
if self.data.get(split) is None:
|
437 |
+
continue
|
438 |
+
dataset = YOLODataset(img_path=self.data[split], data=self.data)
|
439 |
+
with ThreadPool(NUM_THREADS) as pool:
|
440 |
+
for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f'{split} images'):
|
441 |
+
pass
|
442 |
+
LOGGER.info(f'Done. All images saved to {self.im_dir}')
|
443 |
+
return self.im_dir
|
444 |
+
|
445 |
+
|
446 |
+
def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
|
447 |
+
"""
|
448 |
+
Compresses a single image file to reduced size while preserving its aspect ratio and quality using either the
|
449 |
+
Python Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it will
|
450 |
+
not be resized.
|
451 |
+
|
452 |
+
Args:
|
453 |
+
f (str): The path to the input image file.
|
454 |
+
f_new (str, optional): The path to the output image file. If not specified, the input file will be overwritten.
|
455 |
+
max_dim (int, optional): The maximum dimension (width or height) of the output image. Default is 1920 pixels.
|
456 |
+
quality (int, optional): The image compression quality as a percentage. Default is 50%.
|
457 |
+
|
458 |
+
Usage:
|
459 |
+
from pathlib import Path
|
460 |
+
from ultralytics.data.utils import compress_one_image
|
461 |
+
for f in Path('/Users/glennjocher/Downloads/dataset').rglob('*.jpg'):
|
462 |
+
compress_one_image(f)
|
463 |
+
"""
|
464 |
+
try: # use PIL
|
465 |
+
im = Image.open(f)
|
466 |
+
r = max_dim / max(im.height, im.width) # ratio
|
467 |
+
if r < 1.0: # image too large
|
468 |
+
im = im.resize((int(im.width * r), int(im.height * r)))
|
469 |
+
im.save(f_new or f, 'JPEG', quality=quality, optimize=True) # save
|
470 |
+
except Exception as e: # use OpenCV
|
471 |
+
LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}')
|
472 |
+
im = cv2.imread(f)
|
473 |
+
im_height, im_width = im.shape[:2]
|
474 |
+
r = max_dim / max(im_height, im_width) # ratio
|
475 |
+
if r < 1.0: # image too large
|
476 |
+
im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
|
477 |
+
cv2.imwrite(str(f_new or f), im)
|
478 |
+
|
479 |
+
|
480 |
+
def delete_dsstore(path):
|
481 |
+
"""
|
482 |
+
Deletes all ".DS_store" files under a specified directory.
|
483 |
+
|
484 |
+
Args:
|
485 |
+
path (str, optional): The directory path where the ".DS_store" files should be deleted.
|
486 |
+
|
487 |
+
Usage:
|
488 |
+
from ultralytics.data.utils import delete_dsstore
|
489 |
+
delete_dsstore('/Users/glennjocher/Downloads/dataset')
|
490 |
+
|
491 |
+
Note:
|
492 |
+
".DS_store" files are created by the Apple operating system and contain metadata about folders and files. They
|
493 |
+
are hidden system files and can cause issues when transferring files between different operating systems.
|
494 |
+
"""
|
495 |
+
# Delete Apple .DS_store files
|
496 |
+
files = list(Path(path).rglob('.DS_store'))
|
497 |
+
LOGGER.info(f'Deleting *.DS_store files: {files}')
|
498 |
+
for f in files:
|
499 |
+
f.unlink()
|
500 |
+
|
501 |
+
|
502 |
+
def zip_directory(dir, use_zipfile_library=True):
|
503 |
+
"""
|
504 |
+
Zips a directory and saves the archive to the specified output path.
|
505 |
+
|
506 |
+
Args:
|
507 |
+
dir (str): The path to the directory to be zipped.
|
508 |
+
use_zipfile_library (bool): Whether to use zipfile library or shutil for zipping.
|
509 |
+
|
510 |
+
Usage:
|
511 |
+
from ultralytics.data.utils import zip_directory
|
512 |
+
zip_directory('/Users/glennjocher/Downloads/playground')
|
513 |
+
|
514 |
+
zip -r coco8-pose.zip coco8-pose
|
515 |
+
"""
|
516 |
+
delete_dsstore(dir)
|
517 |
+
if use_zipfile_library:
|
518 |
+
dir = Path(dir)
|
519 |
+
with zipfile.ZipFile(dir.with_suffix('.zip'), 'w', zipfile.ZIP_DEFLATED) as zip_file:
|
520 |
+
for file_path in dir.glob('**/*'):
|
521 |
+
if file_path.is_file():
|
522 |
+
zip_file.write(file_path, file_path.relative_to(dir))
|
523 |
+
else:
|
524 |
+
import shutil
|
525 |
+
shutil.make_archive(dir, 'zip', dir)
|
526 |
+
|
527 |
+
|
528 |
+
def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
|
529 |
+
"""
|
530 |
+
Autosplit a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.
|
531 |
+
|
532 |
+
Args:
|
533 |
+
path (Path, optional): Path to images directory. Defaults to DATASETS_DIR / 'coco128/images'.
|
534 |
+
weights (list | tuple, optional): Train, validation, and test split fractions. Defaults to (0.9, 0.1, 0.0).
|
535 |
+
annotated_only (bool, optional): If True, only images with an associated txt file are used. Defaults to False.
|
536 |
+
|
537 |
+
Usage:
|
538 |
+
from utils.dataloaders import autosplit
|
539 |
+
autosplit()
|
540 |
+
"""
|
541 |
+
|
542 |
+
path = Path(path) # images dir
|
543 |
+
files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS) # image files only
|
544 |
+
n = len(files) # number of files
|
545 |
+
random.seed(0) # for reproducibility
|
546 |
+
indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
|
547 |
+
|
548 |
+
txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
|
549 |
+
for x in txt:
|
550 |
+
if (path.parent / x).exists():
|
551 |
+
(path.parent / x).unlink() # remove existing
|
552 |
+
|
553 |
+
LOGGER.info(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
|
554 |
+
for i, img in tqdm(zip(indices, files), total=n):
|
555 |
+
if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
|
556 |
+
with open(path.parent / txt[i], 'a') as f:
|
557 |
+
f.write(f'./{img.relative_to(path.parent).as_posix()}' + '\n') # add image to txt file
|
ultralytics/engine/__init__.py
ADDED
File without changes
|
ultralytics/engine/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (175 Bytes). View file
|
|
ultralytics/engine/__pycache__/exporter.cpython-39.pyc
ADDED
Binary file (34.4 kB). View file
|
|
ultralytics/engine/__pycache__/model.cpython-39.pyc
ADDED
Binary file (17.8 kB). View file
|
|
ultralytics/engine/__pycache__/predictor.cpython-39.pyc
ADDED
Binary file (14.1 kB). View file
|
|
ultralytics/engine/__pycache__/results.cpython-39.pyc
ADDED
Binary file (24.8 kB). View file
|
|
ultralytics/engine/__pycache__/trainer.cpython-39.pyc
ADDED
Binary file (24 kB). View file
|
|
ultralytics/engine/__pycache__/validator.cpython-39.pyc
ADDED
Binary file (11.1 kB). View file
|
|
ultralytics/engine/exporter.py
ADDED
@@ -0,0 +1,969 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
"""
|
3 |
+
Export a YOLOv8 PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit
|
4 |
+
|
5 |
+
Format | `format=argument` | Model
|
6 |
+
--- | --- | ---
|
7 |
+
PyTorch | - | yolov8n.pt
|
8 |
+
TorchScript | `torchscript` | yolov8n.torchscript
|
9 |
+
ONNX | `onnx` | yolov8n.onnx
|
10 |
+
OpenVINO | `openvino` | yolov8n_openvino_model/
|
11 |
+
TensorRT | `engine` | yolov8n.engine
|
12 |
+
CoreML | `coreml` | yolov8n.mlmodel
|
13 |
+
TensorFlow SavedModel | `saved_model` | yolov8n_saved_model/
|
14 |
+
TensorFlow GraphDef | `pb` | yolov8n.pb
|
15 |
+
TensorFlow Lite | `tflite` | yolov8n.tflite
|
16 |
+
TensorFlow Edge TPU | `edgetpu` | yolov8n_edgetpu.tflite
|
17 |
+
TensorFlow.js | `tfjs` | yolov8n_web_model/
|
18 |
+
PaddlePaddle | `paddle` | yolov8n_paddle_model/
|
19 |
+
ncnn | `ncnn` | yolov8n_ncnn_model/
|
20 |
+
|
21 |
+
Requirements:
|
22 |
+
$ pip install "ultralytics[export]"
|
23 |
+
|
24 |
+
Python:
|
25 |
+
from ultralytics import YOLO
|
26 |
+
model = YOLO('yolov8n.pt')
|
27 |
+
results = model.export(format='onnx')
|
28 |
+
|
29 |
+
CLI:
|
30 |
+
$ yolo mode=export model=yolov8n.pt format=onnx
|
31 |
+
|
32 |
+
Inference:
|
33 |
+
$ yolo predict model=yolov8n.pt # PyTorch
|
34 |
+
yolov8n.torchscript # TorchScript
|
35 |
+
yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
|
36 |
+
yolov8n_openvino_model # OpenVINO
|
37 |
+
yolov8n.engine # TensorRT
|
38 |
+
yolov8n.mlmodel # CoreML (macOS-only)
|
39 |
+
yolov8n_saved_model # TensorFlow SavedModel
|
40 |
+
yolov8n.pb # TensorFlow GraphDef
|
41 |
+
yolov8n.tflite # TensorFlow Lite
|
42 |
+
yolov8n_edgetpu.tflite # TensorFlow Edge TPU
|
43 |
+
yolov8n_paddle_model # PaddlePaddle
|
44 |
+
|
45 |
+
TensorFlow.js:
|
46 |
+
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
|
47 |
+
$ npm install
|
48 |
+
$ ln -s ../../yolov5/yolov8n_web_model public/yolov8n_web_model
|
49 |
+
$ npm start
|
50 |
+
"""
|
51 |
+
import json
|
52 |
+
import os
|
53 |
+
import shutil
|
54 |
+
import subprocess
|
55 |
+
import time
|
56 |
+
import warnings
|
57 |
+
from copy import deepcopy
|
58 |
+
from datetime import datetime
|
59 |
+
from pathlib import Path
|
60 |
+
|
61 |
+
import torch
|
62 |
+
|
63 |
+
from ultralytics.cfg import get_cfg
|
64 |
+
from ultralytics.nn.autobackend import check_class_names
|
65 |
+
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
|
66 |
+
from ultralytics.nn.tasks import DetectionModel, SegmentationModel
|
67 |
+
from ultralytics.utils import (ARM64, DEFAULT_CFG, LINUX, LOGGER, MACOS, ROOT, WINDOWS, __version__, callbacks,
|
68 |
+
colorstr, get_default_args, yaml_save)
|
69 |
+
from ultralytics.utils.checks import check_imgsz, check_requirements, check_version
|
70 |
+
from ultralytics.utils.downloads import attempt_download_asset, get_github_assets
|
71 |
+
from ultralytics.utils.files import file_size, spaces_in_path
|
72 |
+
from ultralytics.utils.ops import Profile
|
73 |
+
from ultralytics.utils.torch_utils import get_latest_opset, select_device, smart_inference_mode
|
74 |
+
|
75 |
+
|
76 |
+
def export_formats():
|
77 |
+
"""YOLOv8 export formats."""
|
78 |
+
import pandas
|
79 |
+
x = [
|
80 |
+
['PyTorch', '-', '.pt', True, True],
|
81 |
+
['TorchScript', 'torchscript', '.torchscript', True, True],
|
82 |
+
['ONNX', 'onnx', '.onnx', True, True],
|
83 |
+
['OpenVINO', 'openvino', '_openvino_model', True, False],
|
84 |
+
['TensorRT', 'engine', '.engine', False, True],
|
85 |
+
['CoreML', 'coreml', '.mlmodel', True, False],
|
86 |
+
['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True],
|
87 |
+
['TensorFlow GraphDef', 'pb', '.pb', True, True],
|
88 |
+
['TensorFlow Lite', 'tflite', '.tflite', True, False],
|
89 |
+
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', True, False],
|
90 |
+
['TensorFlow.js', 'tfjs', '_web_model', True, False],
|
91 |
+
['PaddlePaddle', 'paddle', '_paddle_model', True, True],
|
92 |
+
['ncnn', 'ncnn', '_ncnn_model', True, True], ]
|
93 |
+
return pandas.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
|
94 |
+
|
95 |
+
|
96 |
+
def gd_outputs(gd):
|
97 |
+
"""TensorFlow GraphDef model output node names."""
|
98 |
+
name_list, input_list = [], []
|
99 |
+
for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef
|
100 |
+
name_list.append(node.name)
|
101 |
+
input_list.extend(node.input)
|
102 |
+
return sorted(f'{x}:0' for x in list(set(name_list) - set(input_list)) if not x.startswith('NoOp'))
|
103 |
+
|
104 |
+
|
105 |
+
def try_export(inner_func):
|
106 |
+
"""YOLOv8 export decorator, i..e @try_export."""
|
107 |
+
inner_args = get_default_args(inner_func)
|
108 |
+
|
109 |
+
def outer_func(*args, **kwargs):
|
110 |
+
"""Export a model."""
|
111 |
+
prefix = inner_args['prefix']
|
112 |
+
try:
|
113 |
+
with Profile() as dt:
|
114 |
+
f, model = inner_func(*args, **kwargs)
|
115 |
+
LOGGER.info(f"{prefix} export success ✅ {dt.t:.1f}s, saved as '{f}' ({file_size(f):.1f} MB)")
|
116 |
+
return f, model
|
117 |
+
except Exception as e:
|
118 |
+
LOGGER.info(f'{prefix} export failure ❌ {dt.t:.1f}s: {e}')
|
119 |
+
raise e
|
120 |
+
|
121 |
+
return outer_func
|
122 |
+
|
123 |
+
|
124 |
+
class Exporter:
|
125 |
+
"""
|
126 |
+
A class for exporting a model.
|
127 |
+
|
128 |
+
Attributes:
|
129 |
+
args (SimpleNamespace): Configuration for the exporter.
|
130 |
+
save_dir (Path): Directory to save results.
|
131 |
+
"""
|
132 |
+
|
133 |
+
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
134 |
+
"""
|
135 |
+
Initializes the Exporter class.
|
136 |
+
|
137 |
+
Args:
|
138 |
+
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
|
139 |
+
overrides (dict, optional): Configuration overrides. Defaults to None.
|
140 |
+
_callbacks (list, optional): List of callback functions. Defaults to None.
|
141 |
+
"""
|
142 |
+
self.args = get_cfg(cfg, overrides)
|
143 |
+
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
144 |
+
callbacks.add_integration_callbacks(self)
|
145 |
+
|
146 |
+
@smart_inference_mode()
|
147 |
+
def __call__(self, model=None):
|
148 |
+
"""Returns list of exported files/dirs after running callbacks."""
|
149 |
+
self.run_callbacks('on_export_start')
|
150 |
+
t = time.time()
|
151 |
+
format = self.args.format.lower() # to lowercase
|
152 |
+
if format in ('tensorrt', 'trt'): # engine aliases
|
153 |
+
format = 'engine'
|
154 |
+
fmts = tuple(export_formats()['Argument'][1:]) # available export formats
|
155 |
+
flags = [x == format for x in fmts]
|
156 |
+
if sum(flags) != 1:
|
157 |
+
raise ValueError(f"Invalid export format='{format}'. Valid formats are {fmts}")
|
158 |
+
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn = flags # export booleans
|
159 |
+
|
160 |
+
# Load PyTorch model
|
161 |
+
self.device = select_device('cpu' if self.args.device is None else self.args.device)
|
162 |
+
|
163 |
+
# Checks
|
164 |
+
model.names = check_class_names(model.names)
|
165 |
+
if self.args.half and onnx and self.device.type == 'cpu':
|
166 |
+
LOGGER.warning('WARNING ⚠️ half=True only compatible with GPU export, i.e. use device=0')
|
167 |
+
self.args.half = False
|
168 |
+
assert not self.args.dynamic, 'half=True not compatible with dynamic=True, i.e. use only one.'
|
169 |
+
self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size
|
170 |
+
if self.args.optimize:
|
171 |
+
assert not ncnn, "optimize=True not compatible with format='ncnn', i.e. use optimize=False"
|
172 |
+
assert self.device.type == 'cpu', "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
|
173 |
+
if edgetpu and not LINUX:
|
174 |
+
raise SystemError('Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/')
|
175 |
+
|
176 |
+
# Input
|
177 |
+
im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
|
178 |
+
file = Path(
|
179 |
+
getattr(model, 'pt_path', None) or getattr(model, 'yaml_file', None) or model.yaml.get('yaml_file', ''))
|
180 |
+
if file.suffix in ('.yaml', '.yml'):
|
181 |
+
file = Path(file.name)
|
182 |
+
|
183 |
+
# Update model
|
184 |
+
model = deepcopy(model).to(self.device)
|
185 |
+
for p in model.parameters():
|
186 |
+
p.requires_grad = False
|
187 |
+
model.eval()
|
188 |
+
model.float()
|
189 |
+
model = model.fuse()
|
190 |
+
for k, m in model.named_modules():
|
191 |
+
if isinstance(m, (Detect, RTDETRDecoder)): # Segment and Pose use Detect base class
|
192 |
+
m.dynamic = self.args.dynamic
|
193 |
+
m.export = True
|
194 |
+
m.format = self.args.format
|
195 |
+
elif isinstance(m, C2f) and not any((saved_model, pb, tflite, edgetpu, tfjs)):
|
196 |
+
# EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
|
197 |
+
m.forward = m.forward_split
|
198 |
+
|
199 |
+
y = None
|
200 |
+
for _ in range(2):
|
201 |
+
y = model(im) # dry runs
|
202 |
+
if self.args.half and (engine or onnx) and self.device.type != 'cpu':
|
203 |
+
im, model = im.half(), model.half() # to FP16
|
204 |
+
|
205 |
+
# Filter warnings
|
206 |
+
warnings.filterwarnings('ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
|
207 |
+
warnings.filterwarnings('ignore', category=UserWarning) # suppress shape prim::Constant missing ONNX warning
|
208 |
+
warnings.filterwarnings('ignore', category=DeprecationWarning) # suppress CoreML np.bool deprecation warning
|
209 |
+
|
210 |
+
# Assign
|
211 |
+
self.im = im
|
212 |
+
self.model = model
|
213 |
+
self.file = file
|
214 |
+
self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else \
|
215 |
+
tuple(tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y)
|
216 |
+
self.pretty_name = Path(self.model.yaml.get('yaml_file', self.file)).stem.replace('yolo', 'YOLO')
|
217 |
+
trained_on = f'trained on {Path(self.args.data).name}' if self.args.data else '(untrained)'
|
218 |
+
description = f'Ultralytics {self.pretty_name} model {trained_on}'
|
219 |
+
self.metadata = {
|
220 |
+
'description': description,
|
221 |
+
'author': 'Ultralytics',
|
222 |
+
'license': 'AGPL-3.0 https://ultralytics.com/license',
|
223 |
+
'date': datetime.now().isoformat(),
|
224 |
+
'version': __version__,
|
225 |
+
'stride': int(max(model.stride)),
|
226 |
+
'task': model.task,
|
227 |
+
'batch': self.args.batch,
|
228 |
+
'imgsz': self.imgsz,
|
229 |
+
'names': model.names} # model metadata
|
230 |
+
if model.task == 'pose':
|
231 |
+
self.metadata['kpt_shape'] = model.model[-1].kpt_shape
|
232 |
+
|
233 |
+
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from '{file}' with input shape {tuple(im.shape)} BCHW and "
|
234 |
+
f'output shape(s) {self.output_shape} ({file_size(file):.1f} MB)')
|
235 |
+
|
236 |
+
# Exports
|
237 |
+
f = [''] * len(fmts) # exported filenames
|
238 |
+
if jit or ncnn: # TorchScript
|
239 |
+
f[0], _ = self.export_torchscript()
|
240 |
+
if engine: # TensorRT required before ONNX
|
241 |
+
f[1], _ = self.export_engine()
|
242 |
+
if onnx or xml: # OpenVINO requires ONNX
|
243 |
+
f[2], _ = self.export_onnx()
|
244 |
+
if xml: # OpenVINO
|
245 |
+
f[3], _ = self.export_openvino()
|
246 |
+
if coreml: # CoreML
|
247 |
+
f[4], _ = self.export_coreml()
|
248 |
+
if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
|
249 |
+
self.args.int8 |= edgetpu
|
250 |
+
f[5], s_model = self.export_saved_model()
|
251 |
+
if pb or tfjs: # pb prerequisite to tfjs
|
252 |
+
f[6], _ = self.export_pb(s_model)
|
253 |
+
if tflite:
|
254 |
+
f[7], _ = self.export_tflite(s_model, nms=False, agnostic_nms=self.args.agnostic_nms)
|
255 |
+
if edgetpu:
|
256 |
+
f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f'{self.file.stem}_full_integer_quant.tflite')
|
257 |
+
if tfjs:
|
258 |
+
f[9], _ = self.export_tfjs()
|
259 |
+
if paddle: # PaddlePaddle
|
260 |
+
f[10], _ = self.export_paddle()
|
261 |
+
if ncnn: # ncnn
|
262 |
+
f[11], _ = self.export_ncnn()
|
263 |
+
|
264 |
+
# Finish
|
265 |
+
f = [str(x) for x in f if x] # filter out '' and None
|
266 |
+
if any(f):
|
267 |
+
f = str(Path(f[-1]))
|
268 |
+
square = self.imgsz[0] == self.imgsz[1]
|
269 |
+
s = '' if square else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not " \
|
270 |
+
f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
|
271 |
+
imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(' ', '')
|
272 |
+
data = f'data={self.args.data}' if model.task == 'segment' and format == 'pb' else ''
|
273 |
+
LOGGER.info(
|
274 |
+
f'\nExport complete ({time.time() - t:.1f}s)'
|
275 |
+
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
|
276 |
+
f'\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {data}'
|
277 |
+
f'\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={self.args.data} {s}'
|
278 |
+
f'\nVisualize: https://netron.app')
|
279 |
+
|
280 |
+
self.run_callbacks('on_export_end')
|
281 |
+
return f # return list of exported files/dirs
|
282 |
+
|
283 |
+
@try_export
|
284 |
+
def export_torchscript(self, prefix=colorstr('TorchScript:')):
|
285 |
+
"""YOLOv8 TorchScript model export."""
|
286 |
+
LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
|
287 |
+
f = self.file.with_suffix('.torchscript')
|
288 |
+
|
289 |
+
ts = torch.jit.trace(self.model, self.im, strict=False)
|
290 |
+
extra_files = {'config.txt': json.dumps(self.metadata)} # torch._C.ExtraFilesMap()
|
291 |
+
if self.args.optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
|
292 |
+
LOGGER.info(f'{prefix} optimizing for mobile...')
|
293 |
+
from torch.utils.mobile_optimizer import optimize_for_mobile
|
294 |
+
optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
|
295 |
+
else:
|
296 |
+
ts.save(str(f), _extra_files=extra_files)
|
297 |
+
return f, None
|
298 |
+
|
299 |
+
@try_export
|
300 |
+
def export_onnx(self, prefix=colorstr('ONNX:')):
|
301 |
+
"""YOLOv8 ONNX export."""
|
302 |
+
requirements = ['onnx>=1.12.0']
|
303 |
+
if self.args.simplify:
|
304 |
+
requirements += ['onnxsim>=0.4.17', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime']
|
305 |
+
check_requirements(requirements)
|
306 |
+
import onnx # noqa
|
307 |
+
|
308 |
+
opset_version = self.args.opset or get_latest_opset()
|
309 |
+
LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...')
|
310 |
+
f = str(self.file.with_suffix('.onnx'))
|
311 |
+
|
312 |
+
output_names = ['output0', 'output1'] if isinstance(self.model, SegmentationModel) else ['output0']
|
313 |
+
dynamic = self.args.dynamic
|
314 |
+
if dynamic:
|
315 |
+
dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}} # shape(1,3,640,640)
|
316 |
+
if isinstance(self.model, SegmentationModel):
|
317 |
+
dynamic['output0'] = {0: 'batch', 2: 'anchors'} # shape(1, 116, 8400)
|
318 |
+
dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'} # shape(1,32,160,160)
|
319 |
+
elif isinstance(self.model, DetectionModel):
|
320 |
+
dynamic['output0'] = {0: 'batch', 2: 'anchors'} # shape(1, 84, 8400)
|
321 |
+
|
322 |
+
torch.onnx.export(
|
323 |
+
self.model.cpu() if dynamic else self.model, # --dynamic only compatible with cpu
|
324 |
+
self.im.cpu() if dynamic else self.im,
|
325 |
+
f,
|
326 |
+
verbose=False,
|
327 |
+
opset_version=opset_version,
|
328 |
+
do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
|
329 |
+
input_names=['images'],
|
330 |
+
output_names=output_names,
|
331 |
+
dynamic_axes=dynamic or None)
|
332 |
+
|
333 |
+
# Checks
|
334 |
+
model_onnx = onnx.load(f) # load onnx model
|
335 |
+
# onnx.checker.check_model(model_onnx) # check onnx model
|
336 |
+
|
337 |
+
# Simplify
|
338 |
+
if self.args.simplify:
|
339 |
+
try:
|
340 |
+
import onnxsim
|
341 |
+
|
342 |
+
LOGGER.info(f'{prefix} simplifying with onnxsim {onnxsim.__version__}...')
|
343 |
+
# subprocess.run(f'onnxsim "{f}" "{f}"', shell=True)
|
344 |
+
model_onnx, check = onnxsim.simplify(model_onnx)
|
345 |
+
assert check, 'Simplified ONNX model could not be validated'
|
346 |
+
except Exception as e:
|
347 |
+
LOGGER.info(f'{prefix} simplifier failure: {e}')
|
348 |
+
|
349 |
+
# Metadata
|
350 |
+
for k, v in self.metadata.items():
|
351 |
+
meta = model_onnx.metadata_props.add()
|
352 |
+
meta.key, meta.value = k, str(v)
|
353 |
+
|
354 |
+
onnx.save(model_onnx, f)
|
355 |
+
return f, model_onnx
|
356 |
+
|
357 |
+
@try_export
|
358 |
+
def export_openvino(self, prefix=colorstr('OpenVINO:')):
|
359 |
+
"""YOLOv8 OpenVINO export."""
|
360 |
+
check_requirements('openvino-dev>=2023.0') # requires openvino-dev: https://pypi.org/project/openvino-dev/
|
361 |
+
import openvino.runtime as ov # noqa
|
362 |
+
from openvino.tools import mo # noqa
|
363 |
+
|
364 |
+
LOGGER.info(f'\n{prefix} starting export with openvino {ov.__version__}...')
|
365 |
+
f = str(self.file).replace(self.file.suffix, f'_openvino_model{os.sep}')
|
366 |
+
f_onnx = self.file.with_suffix('.onnx')
|
367 |
+
f_ov = str(Path(f) / self.file.with_suffix('.xml').name)
|
368 |
+
|
369 |
+
ov_model = mo.convert_model(f_onnx,
|
370 |
+
model_name=self.pretty_name,
|
371 |
+
framework='onnx',
|
372 |
+
compress_to_fp16=self.args.half) # export
|
373 |
+
|
374 |
+
# Set RT info
|
375 |
+
ov_model.set_rt_info('YOLOv8', ['model_info', 'model_type'])
|
376 |
+
ov_model.set_rt_info(True, ['model_info', 'reverse_input_channels'])
|
377 |
+
ov_model.set_rt_info(114, ['model_info', 'pad_value'])
|
378 |
+
ov_model.set_rt_info([255.0], ['model_info', 'scale_values'])
|
379 |
+
ov_model.set_rt_info(self.args.iou, ['model_info', 'iou_threshold'])
|
380 |
+
ov_model.set_rt_info([v.replace(' ', '_') for k, v in sorted(self.model.names.items())],
|
381 |
+
['model_info', 'labels'])
|
382 |
+
if self.model.task != 'classify':
|
383 |
+
ov_model.set_rt_info('fit_to_window_letterbox', ['model_info', 'resize_type'])
|
384 |
+
|
385 |
+
ov.serialize(ov_model, f_ov) # save
|
386 |
+
yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
|
387 |
+
return f, None
|
388 |
+
|
389 |
+
@try_export
|
390 |
+
def export_paddle(self, prefix=colorstr('PaddlePaddle:')):
|
391 |
+
"""YOLOv8 Paddle export."""
|
392 |
+
check_requirements(('paddlepaddle', 'x2paddle'))
|
393 |
+
import x2paddle # noqa
|
394 |
+
from x2paddle.convert import pytorch2paddle # noqa
|
395 |
+
|
396 |
+
LOGGER.info(f'\n{prefix} starting export with X2Paddle {x2paddle.__version__}...')
|
397 |
+
f = str(self.file).replace(self.file.suffix, f'_paddle_model{os.sep}')
|
398 |
+
|
399 |
+
pytorch2paddle(module=self.model, save_dir=f, jit_type='trace', input_examples=[self.im]) # export
|
400 |
+
yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
|
401 |
+
return f, None
|
402 |
+
|
403 |
+
@try_export
|
404 |
+
def export_ncnn(self, prefix=colorstr('ncnn:')):
|
405 |
+
"""
|
406 |
+
YOLOv8 ncnn export using PNNX https://github.com/pnnx/pnnx.
|
407 |
+
"""
|
408 |
+
check_requirements('git+https://github.com/Tencent/ncnn.git' if ARM64 else 'ncnn') # requires ncnn
|
409 |
+
import ncnn # noqa
|
410 |
+
|
411 |
+
LOGGER.info(f'\n{prefix} starting export with ncnn {ncnn.__version__}...')
|
412 |
+
f = Path(str(self.file).replace(self.file.suffix, f'_ncnn_model{os.sep}'))
|
413 |
+
f_ts = self.file.with_suffix('.torchscript')
|
414 |
+
|
415 |
+
pnnx_filename = 'pnnx.exe' if WINDOWS else 'pnnx'
|
416 |
+
if Path(pnnx_filename).is_file():
|
417 |
+
pnnx = pnnx_filename
|
418 |
+
elif (ROOT / pnnx_filename).is_file():
|
419 |
+
pnnx = ROOT / pnnx_filename
|
420 |
+
else:
|
421 |
+
LOGGER.warning(
|
422 |
+
f'{prefix} WARNING ⚠️ PNNX not found. Attempting to download binary file from '
|
423 |
+
'https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory '
|
424 |
+
f'or in {ROOT}. See PNNX repo for full installation instructions.')
|
425 |
+
_, assets = get_github_assets(repo='pnnx/pnnx', retry=True)
|
426 |
+
asset = [x for x in assets if ('macos' if MACOS else 'ubuntu' if LINUX else 'windows') in x][0]
|
427 |
+
attempt_download_asset(asset, repo='pnnx/pnnx', release='latest')
|
428 |
+
unzip_dir = Path(asset).with_suffix('')
|
429 |
+
pnnx = ROOT / pnnx_filename # new location
|
430 |
+
(unzip_dir / pnnx_filename).rename(pnnx) # move binary to ROOT
|
431 |
+
shutil.rmtree(unzip_dir) # delete unzip dir
|
432 |
+
Path(asset).unlink() # delete zip
|
433 |
+
pnnx.chmod(0o777) # set read, write, and execute permissions for everyone
|
434 |
+
|
435 |
+
use_ncnn = True
|
436 |
+
ncnn_args = [
|
437 |
+
f'ncnnparam={f / "model.ncnn.param"}',
|
438 |
+
f'ncnnbin={f / "model.ncnn.bin"}',
|
439 |
+
f'ncnnpy={f / "model_ncnn.py"}', ] if use_ncnn else []
|
440 |
+
|
441 |
+
use_pnnx = False
|
442 |
+
pnnx_args = [
|
443 |
+
f'pnnxparam={f / "model.pnnx.param"}',
|
444 |
+
f'pnnxbin={f / "model.pnnx.bin"}',
|
445 |
+
f'pnnxpy={f / "model_pnnx.py"}',
|
446 |
+
f'pnnxonnx={f / "model.pnnx.onnx"}', ] if use_pnnx else []
|
447 |
+
|
448 |
+
cmd = [
|
449 |
+
str(pnnx),
|
450 |
+
str(f_ts),
|
451 |
+
*ncnn_args,
|
452 |
+
*pnnx_args,
|
453 |
+
f'fp16={int(self.args.half)}',
|
454 |
+
f'device={self.device.type}',
|
455 |
+
f'inputshape="{[self.args.batch, 3, *self.imgsz]}"', ]
|
456 |
+
f.mkdir(exist_ok=True) # make ncnn_model directory
|
457 |
+
LOGGER.info(f"{prefix} running '{' '.join(cmd)}'")
|
458 |
+
subprocess.run(cmd, check=True)
|
459 |
+
for f_debug in 'debug.bin', 'debug.param', 'debug2.bin', 'debug2.param': # remove debug files
|
460 |
+
Path(f_debug).unlink(missing_ok=True)
|
461 |
+
|
462 |
+
yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml
|
463 |
+
return str(f), None
|
464 |
+
|
465 |
+
@try_export
|
466 |
+
def export_coreml(self, prefix=colorstr('CoreML:')):
|
467 |
+
"""YOLOv8 CoreML export."""
|
468 |
+
check_requirements('coremltools>=6.0,<=6.2')
|
469 |
+
import coremltools as ct # noqa
|
470 |
+
|
471 |
+
LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
|
472 |
+
f = self.file.with_suffix('.mlmodel')
|
473 |
+
|
474 |
+
bias = [0.0, 0.0, 0.0]
|
475 |
+
scale = 1 / 255
|
476 |
+
classifier_config = None
|
477 |
+
if self.model.task == 'classify':
|
478 |
+
classifier_config = ct.ClassifierConfig(list(self.model.names.values())) if self.args.nms else None
|
479 |
+
model = self.model
|
480 |
+
elif self.model.task == 'detect':
|
481 |
+
model = iOSDetectModel(self.model, self.im) if self.args.nms else self.model
|
482 |
+
else:
|
483 |
+
# TODO CoreML Segment and Pose model pipelining
|
484 |
+
model = self.model
|
485 |
+
|
486 |
+
ts = torch.jit.trace(model.eval(), self.im, strict=False) # TorchScript model
|
487 |
+
ct_model = ct.convert(ts,
|
488 |
+
inputs=[ct.ImageType('image', shape=self.im.shape, scale=scale, bias=bias)],
|
489 |
+
classifier_config=classifier_config)
|
490 |
+
bits, mode = (8, 'kmeans_lut') if self.args.int8 else (16, 'linear') if self.args.half else (32, None)
|
491 |
+
if bits < 32:
|
492 |
+
if 'kmeans' in mode:
|
493 |
+
check_requirements('scikit-learn') # scikit-learn package required for k-means quantization
|
494 |
+
ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
|
495 |
+
if self.args.nms and self.model.task == 'detect':
|
496 |
+
ct_model = self._pipeline_coreml(ct_model)
|
497 |
+
|
498 |
+
m = self.metadata # metadata dict
|
499 |
+
ct_model.short_description = m.pop('description')
|
500 |
+
ct_model.author = m.pop('author')
|
501 |
+
ct_model.license = m.pop('license')
|
502 |
+
ct_model.version = m.pop('version')
|
503 |
+
ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items()})
|
504 |
+
ct_model.save(str(f))
|
505 |
+
return f, ct_model
|
506 |
+
|
507 |
+
@try_export
|
508 |
+
def export_engine(self, prefix=colorstr('TensorRT:')):
|
509 |
+
"""YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt."""
|
510 |
+
assert self.im.device.type != 'cpu', "export running on CPU but must be on GPU, i.e. use 'device=0'"
|
511 |
+
try:
|
512 |
+
import tensorrt as trt # noqa
|
513 |
+
except ImportError:
|
514 |
+
if LINUX:
|
515 |
+
check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
|
516 |
+
import tensorrt as trt # noqa
|
517 |
+
|
518 |
+
check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
|
519 |
+
self.args.simplify = True
|
520 |
+
f_onnx, _ = self.export_onnx()
|
521 |
+
|
522 |
+
LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
|
523 |
+
assert Path(f_onnx).exists(), f'failed to export ONNX file: {f_onnx}'
|
524 |
+
f = self.file.with_suffix('.engine') # TensorRT engine file
|
525 |
+
logger = trt.Logger(trt.Logger.INFO)
|
526 |
+
if self.args.verbose:
|
527 |
+
logger.min_severity = trt.Logger.Severity.VERBOSE
|
528 |
+
|
529 |
+
builder = trt.Builder(logger)
|
530 |
+
config = builder.create_builder_config()
|
531 |
+
config.max_workspace_size = self.args.workspace * 1 << 30
|
532 |
+
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice
|
533 |
+
|
534 |
+
flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
|
535 |
+
network = builder.create_network(flag)
|
536 |
+
parser = trt.OnnxParser(network, logger)
|
537 |
+
if not parser.parse_from_file(f_onnx):
|
538 |
+
raise RuntimeError(f'failed to load ONNX file: {f_onnx}')
|
539 |
+
|
540 |
+
inputs = [network.get_input(i) for i in range(network.num_inputs)]
|
541 |
+
outputs = [network.get_output(i) for i in range(network.num_outputs)]
|
542 |
+
for inp in inputs:
|
543 |
+
LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
|
544 |
+
for out in outputs:
|
545 |
+
LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')
|
546 |
+
|
547 |
+
if self.args.dynamic:
|
548 |
+
shape = self.im.shape
|
549 |
+
if shape[0] <= 1:
|
550 |
+
LOGGER.warning(f'{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument')
|
551 |
+
profile = builder.create_optimization_profile()
|
552 |
+
for inp in inputs:
|
553 |
+
profile.set_shape(inp.name, (1, *shape[1:]), (max(1, shape[0] // 2), *shape[1:]), shape)
|
554 |
+
config.add_optimization_profile(profile)
|
555 |
+
|
556 |
+
LOGGER.info(
|
557 |
+
f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and self.args.half else 32} engine as {f}')
|
558 |
+
if builder.platform_has_fast_fp16 and self.args.half:
|
559 |
+
config.set_flag(trt.BuilderFlag.FP16)
|
560 |
+
|
561 |
+
# Write file
|
562 |
+
with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
|
563 |
+
# Metadata
|
564 |
+
meta = json.dumps(self.metadata)
|
565 |
+
t.write(len(meta).to_bytes(4, byteorder='little', signed=True))
|
566 |
+
t.write(meta.encode())
|
567 |
+
# Model
|
568 |
+
t.write(engine.serialize())
|
569 |
+
|
570 |
+
return f, None
|
571 |
+
|
572 |
+
@try_export
|
573 |
+
def export_saved_model(self, prefix=colorstr('TensorFlow SavedModel:')):
|
574 |
+
"""YOLOv8 TensorFlow SavedModel export."""
|
575 |
+
try:
|
576 |
+
import tensorflow as tf # noqa
|
577 |
+
except ImportError:
|
578 |
+
cuda = torch.cuda.is_available()
|
579 |
+
check_requirements(f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if cuda else '-cpu'}")
|
580 |
+
import tensorflow as tf # noqa
|
581 |
+
check_requirements(('onnx', 'onnx2tf>=1.9.1', 'sng4onnx>=1.0.1', 'onnxsim>=0.4.17', 'onnx_graphsurgeon>=0.3.26',
|
582 |
+
'tflite_support', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'),
|
583 |
+
cmds='--extra-index-url https://pypi.ngc.nvidia.com')
|
584 |
+
|
585 |
+
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
|
586 |
+
f = Path(str(self.file).replace(self.file.suffix, '_saved_model'))
|
587 |
+
if f.is_dir():
|
588 |
+
import shutil
|
589 |
+
shutil.rmtree(f) # delete output folder
|
590 |
+
|
591 |
+
# Export to ONNX
|
592 |
+
self.args.simplify = True
|
593 |
+
f_onnx, _ = self.export_onnx()
|
594 |
+
|
595 |
+
# Export to TF
|
596 |
+
tmp_file = f / 'tmp_tflite_int8_calibration_images.npy' # int8 calibration images file
|
597 |
+
if self.args.int8:
|
598 |
+
if self.args.data:
|
599 |
+
import numpy as np
|
600 |
+
|
601 |
+
from ultralytics.data.dataset import YOLODataset
|
602 |
+
from ultralytics.data.utils import check_det_dataset
|
603 |
+
|
604 |
+
# Generate calibration data for integer quantization
|
605 |
+
LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
|
606 |
+
dataset = YOLODataset(check_det_dataset(self.args.data)['val'], imgsz=self.imgsz[0], augment=False)
|
607 |
+
images = []
|
608 |
+
n_images = 100 # maximum number of images
|
609 |
+
for n, batch in enumerate(dataset):
|
610 |
+
if n >= n_images:
|
611 |
+
break
|
612 |
+
im = batch['img'].permute(1, 2, 0)[None] # list to nparray, CHW to BHWC,
|
613 |
+
images.append(im)
|
614 |
+
f.mkdir()
|
615 |
+
images = torch.cat(images, 0).float()
|
616 |
+
# mean = images.view(-1, 3).mean(0) # imagenet mean [123.675, 116.28, 103.53]
|
617 |
+
# std = images.view(-1, 3).std(0) # imagenet std [58.395, 57.12, 57.375]
|
618 |
+
np.save(str(tmp_file), images.numpy()) # BHWC
|
619 |
+
int8 = f'-oiqt -qt per-tensor -cind images "{tmp_file}" "[[[[0, 0, 0]]]]" "[[[[255, 255, 255]]]]"'
|
620 |
+
else:
|
621 |
+
int8 = '-oiqt -qt per-tensor'
|
622 |
+
else:
|
623 |
+
int8 = ''
|
624 |
+
|
625 |
+
cmd = f'onnx2tf -i "{f_onnx}" -o "{f}" -nuo --non_verbose {int8}'.strip()
|
626 |
+
LOGGER.info(f"{prefix} running '{cmd}'")
|
627 |
+
subprocess.run(cmd, shell=True)
|
628 |
+
yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml
|
629 |
+
|
630 |
+
# Remove/rename TFLite models
|
631 |
+
if self.args.int8:
|
632 |
+
tmp_file.unlink(missing_ok=True)
|
633 |
+
for file in f.rglob('*_dynamic_range_quant.tflite'):
|
634 |
+
file.rename(file.with_name(file.stem.replace('_dynamic_range_quant', '_int8') + file.suffix))
|
635 |
+
for file in f.rglob('*_integer_quant_with_int16_act.tflite'):
|
636 |
+
file.unlink() # delete extra fp16 activation TFLite files
|
637 |
+
|
638 |
+
# Add TFLite metadata
|
639 |
+
for file in f.rglob('*.tflite'):
|
640 |
+
f.unlink() if 'quant_with_int16_act.tflite' in str(f) else self._add_tflite_metadata(file)
|
641 |
+
|
642 |
+
# Load saved_model
|
643 |
+
keras_model = tf.saved_model.load(f, tags=None, options=None)
|
644 |
+
|
645 |
+
return str(f), keras_model
|
646 |
+
|
647 |
+
@try_export
|
648 |
+
def export_pb(self, keras_model, prefix=colorstr('TensorFlow GraphDef:')):
|
649 |
+
"""YOLOv8 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow."""
|
650 |
+
import tensorflow as tf # noqa
|
651 |
+
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa
|
652 |
+
|
653 |
+
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
|
654 |
+
f = self.file.with_suffix('.pb')
|
655 |
+
|
656 |
+
m = tf.function(lambda x: keras_model(x)) # full model
|
657 |
+
m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
|
658 |
+
frozen_func = convert_variables_to_constants_v2(m)
|
659 |
+
frozen_func.graph.as_graph_def()
|
660 |
+
tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
|
661 |
+
return f, None
|
662 |
+
|
663 |
+
@try_export
|
664 |
+
def export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
|
665 |
+
"""YOLOv8 TensorFlow Lite export."""
|
666 |
+
import tensorflow as tf # noqa
|
667 |
+
|
668 |
+
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
|
669 |
+
saved_model = Path(str(self.file).replace(self.file.suffix, '_saved_model'))
|
670 |
+
if self.args.int8:
|
671 |
+
f = saved_model / f'{self.file.stem}_int8.tflite' # fp32 in/out
|
672 |
+
elif self.args.half:
|
673 |
+
f = saved_model / f'{self.file.stem}_float16.tflite' # fp32 in/out
|
674 |
+
else:
|
675 |
+
f = saved_model / f'{self.file.stem}_float32.tflite'
|
676 |
+
return str(f), None
|
677 |
+
|
678 |
+
@try_export
|
679 |
+
def export_edgetpu(self, tflite_model='', prefix=colorstr('Edge TPU:')):
|
680 |
+
"""YOLOv8 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/."""
|
681 |
+
LOGGER.warning(f'{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185')
|
682 |
+
|
683 |
+
cmd = 'edgetpu_compiler --version'
|
684 |
+
help_url = 'https://coral.ai/docs/edgetpu/compiler/'
|
685 |
+
assert LINUX, f'export only supported on Linux. See {help_url}'
|
686 |
+
if subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True).returncode != 0:
|
687 |
+
LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
|
688 |
+
sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system
|
689 |
+
for c in (
|
690 |
+
'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
|
691 |
+
'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
|
692 |
+
'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'):
|
693 |
+
subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
|
694 |
+
ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]
|
695 |
+
|
696 |
+
LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
|
697 |
+
f = str(tflite_model).replace('.tflite', '_edgetpu.tflite') # Edge TPU model
|
698 |
+
|
699 |
+
cmd = f'edgetpu_compiler -s -d -k 10 --out_dir "{Path(f).parent}" "{tflite_model}"'
|
700 |
+
LOGGER.info(f"{prefix} running '{cmd}'")
|
701 |
+
subprocess.run(cmd, shell=True)
|
702 |
+
self._add_tflite_metadata(f)
|
703 |
+
return f, None
|
704 |
+
|
705 |
+
@try_export
|
706 |
+
def export_tfjs(self, prefix=colorstr('TensorFlow.js:')):
|
707 |
+
"""YOLOv8 TensorFlow.js export."""
|
708 |
+
check_requirements('tensorflowjs')
|
709 |
+
import tensorflow as tf
|
710 |
+
import tensorflowjs as tfjs # noqa
|
711 |
+
|
712 |
+
LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
|
713 |
+
f = str(self.file).replace(self.file.suffix, '_web_model') # js dir
|
714 |
+
f_pb = str(self.file.with_suffix('.pb')) # *.pb path
|
715 |
+
|
716 |
+
gd = tf.Graph().as_graph_def() # TF GraphDef
|
717 |
+
with open(f_pb, 'rb') as file:
|
718 |
+
gd.ParseFromString(file.read())
|
719 |
+
outputs = ','.join(gd_outputs(gd))
|
720 |
+
LOGGER.info(f'\n{prefix} output node names: {outputs}')
|
721 |
+
|
722 |
+
with spaces_in_path(f_pb) as fpb_, spaces_in_path(f) as f_: # exporter can not handle spaces in path
|
723 |
+
cmd = f'tensorflowjs_converter --input_format=tf_frozen_model --output_node_names={outputs} "{fpb_}" "{f_}"'
|
724 |
+
LOGGER.info(f"{prefix} running '{cmd}'")
|
725 |
+
subprocess.run(cmd, shell=True)
|
726 |
+
|
727 |
+
if ' ' in str(f):
|
728 |
+
LOGGER.warning(f"{prefix} WARNING ⚠️ your model may not work correctly with spaces in path '{f}'.")
|
729 |
+
|
730 |
+
# f_json = Path(f) / 'model.json' # *.json path
|
731 |
+
# with open(f_json, 'w') as j: # sort JSON Identity_* in ascending order
|
732 |
+
# subst = re.sub(
|
733 |
+
# r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
|
734 |
+
# r'"Identity.?.?": {"name": "Identity.?.?"}, '
|
735 |
+
# r'"Identity.?.?": {"name": "Identity.?.?"}, '
|
736 |
+
# r'"Identity.?.?": {"name": "Identity.?.?"}}}',
|
737 |
+
# r'{"outputs": {"Identity": {"name": "Identity"}, '
|
738 |
+
# r'"Identity_1": {"name": "Identity_1"}, '
|
739 |
+
# r'"Identity_2": {"name": "Identity_2"}, '
|
740 |
+
# r'"Identity_3": {"name": "Identity_3"}}}',
|
741 |
+
# f_json.read_text(),
|
742 |
+
# )
|
743 |
+
# j.write(subst)
|
744 |
+
yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
|
745 |
+
return f, None
|
746 |
+
|
747 |
+
def _add_tflite_metadata(self, file):
|
748 |
+
"""Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata."""
|
749 |
+
from tflite_support import flatbuffers # noqa
|
750 |
+
from tflite_support import metadata as _metadata # noqa
|
751 |
+
from tflite_support import metadata_schema_py_generated as _metadata_fb # noqa
|
752 |
+
|
753 |
+
# Create model info
|
754 |
+
model_meta = _metadata_fb.ModelMetadataT()
|
755 |
+
model_meta.name = self.metadata['description']
|
756 |
+
model_meta.version = self.metadata['version']
|
757 |
+
model_meta.author = self.metadata['author']
|
758 |
+
model_meta.license = self.metadata['license']
|
759 |
+
|
760 |
+
# Label file
|
761 |
+
tmp_file = Path(file).parent / 'temp_meta.txt'
|
762 |
+
with open(tmp_file, 'w') as f:
|
763 |
+
f.write(str(self.metadata))
|
764 |
+
|
765 |
+
label_file = _metadata_fb.AssociatedFileT()
|
766 |
+
label_file.name = tmp_file.name
|
767 |
+
label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS
|
768 |
+
|
769 |
+
# Create input info
|
770 |
+
input_meta = _metadata_fb.TensorMetadataT()
|
771 |
+
input_meta.name = 'image'
|
772 |
+
input_meta.description = 'Input image to be detected.'
|
773 |
+
input_meta.content = _metadata_fb.ContentT()
|
774 |
+
input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
|
775 |
+
input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB
|
776 |
+
input_meta.content.contentPropertiesType = _metadata_fb.ContentProperties.ImageProperties
|
777 |
+
|
778 |
+
# Create output info
|
779 |
+
output1 = _metadata_fb.TensorMetadataT()
|
780 |
+
output1.name = 'output'
|
781 |
+
output1.description = 'Coordinates of detected objects, class labels, and confidence score'
|
782 |
+
output1.associatedFiles = [label_file]
|
783 |
+
if self.model.task == 'segment':
|
784 |
+
output2 = _metadata_fb.TensorMetadataT()
|
785 |
+
output2.name = 'output'
|
786 |
+
output2.description = 'Mask protos'
|
787 |
+
output2.associatedFiles = [label_file]
|
788 |
+
|
789 |
+
# Create subgraph info
|
790 |
+
subgraph = _metadata_fb.SubGraphMetadataT()
|
791 |
+
subgraph.inputTensorMetadata = [input_meta]
|
792 |
+
subgraph.outputTensorMetadata = [output1, output2] if self.model.task == 'segment' else [output1]
|
793 |
+
model_meta.subgraphMetadata = [subgraph]
|
794 |
+
|
795 |
+
b = flatbuffers.Builder(0)
|
796 |
+
b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
|
797 |
+
metadata_buf = b.Output()
|
798 |
+
|
799 |
+
populator = _metadata.MetadataPopulator.with_model_file(str(file))
|
800 |
+
populator.load_metadata_buffer(metadata_buf)
|
801 |
+
populator.load_associated_files([str(tmp_file)])
|
802 |
+
populator.populate()
|
803 |
+
tmp_file.unlink()
|
804 |
+
|
805 |
+
def _pipeline_coreml(self, model, prefix=colorstr('CoreML Pipeline:')):
|
806 |
+
"""YOLOv8 CoreML pipeline."""
|
807 |
+
import coremltools as ct # noqa
|
808 |
+
|
809 |
+
LOGGER.info(f'{prefix} starting pipeline with coremltools {ct.__version__}...')
|
810 |
+
batch_size, ch, h, w = list(self.im.shape) # BCHW
|
811 |
+
|
812 |
+
# Output shapes
|
813 |
+
spec = model.get_spec()
|
814 |
+
out0, out1 = iter(spec.description.output)
|
815 |
+
if MACOS:
|
816 |
+
from PIL import Image
|
817 |
+
img = Image.new('RGB', (w, h)) # img(192 width, 320 height)
|
818 |
+
# img = torch.zeros((*opt.img_size, 3)).numpy() # img size(320,192,3) iDetection
|
819 |
+
out = model.predict({'image': img})
|
820 |
+
out0_shape = out[out0.name].shape
|
821 |
+
out1_shape = out[out1.name].shape
|
822 |
+
else: # linux and windows can not run model.predict(), get sizes from pytorch output y
|
823 |
+
out0_shape = self.output_shape[2], self.output_shape[1] - 4 # (3780, 80)
|
824 |
+
out1_shape = self.output_shape[2], 4 # (3780, 4)
|
825 |
+
|
826 |
+
# Checks
|
827 |
+
names = self.metadata['names']
|
828 |
+
nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height
|
829 |
+
na, nc = out0_shape
|
830 |
+
# na, nc = out0.type.multiArrayType.shape # number anchors, classes
|
831 |
+
assert len(names) == nc, f'{len(names)} names found for nc={nc}' # check
|
832 |
+
|
833 |
+
# Define output shapes (missing)
|
834 |
+
out0.type.multiArrayType.shape[:] = out0_shape # (3780, 80)
|
835 |
+
out1.type.multiArrayType.shape[:] = out1_shape # (3780, 4)
|
836 |
+
# spec.neuralNetwork.preprocessing[0].featureName = '0'
|
837 |
+
|
838 |
+
# Flexible input shapes
|
839 |
+
# from coremltools.models.neural_network import flexible_shape_utils
|
840 |
+
# s = [] # shapes
|
841 |
+
# s.append(flexible_shape_utils.NeuralNetworkImageSize(320, 192))
|
842 |
+
# s.append(flexible_shape_utils.NeuralNetworkImageSize(640, 384)) # (height, width)
|
843 |
+
# flexible_shape_utils.add_enumerated_image_sizes(spec, feature_name='image', sizes=s)
|
844 |
+
# r = flexible_shape_utils.NeuralNetworkImageSizeRange() # shape ranges
|
845 |
+
# r.add_height_range((192, 640))
|
846 |
+
# r.add_width_range((192, 640))
|
847 |
+
# flexible_shape_utils.update_image_size_range(spec, feature_name='image', size_range=r)
|
848 |
+
|
849 |
+
# Print
|
850 |
+
# print(spec.description)
|
851 |
+
|
852 |
+
# Model from spec
|
853 |
+
model = ct.models.MLModel(spec)
|
854 |
+
|
855 |
+
# 3. Create NMS protobuf
|
856 |
+
nms_spec = ct.proto.Model_pb2.Model()
|
857 |
+
nms_spec.specificationVersion = 5
|
858 |
+
for i in range(2):
|
859 |
+
decoder_output = model._spec.description.output[i].SerializeToString()
|
860 |
+
nms_spec.description.input.add()
|
861 |
+
nms_spec.description.input[i].ParseFromString(decoder_output)
|
862 |
+
nms_spec.description.output.add()
|
863 |
+
nms_spec.description.output[i].ParseFromString(decoder_output)
|
864 |
+
|
865 |
+
nms_spec.description.output[0].name = 'confidence'
|
866 |
+
nms_spec.description.output[1].name = 'coordinates'
|
867 |
+
|
868 |
+
output_sizes = [nc, 4]
|
869 |
+
for i in range(2):
|
870 |
+
ma_type = nms_spec.description.output[i].type.multiArrayType
|
871 |
+
ma_type.shapeRange.sizeRanges.add()
|
872 |
+
ma_type.shapeRange.sizeRanges[0].lowerBound = 0
|
873 |
+
ma_type.shapeRange.sizeRanges[0].upperBound = -1
|
874 |
+
ma_type.shapeRange.sizeRanges.add()
|
875 |
+
ma_type.shapeRange.sizeRanges[1].lowerBound = output_sizes[i]
|
876 |
+
ma_type.shapeRange.sizeRanges[1].upperBound = output_sizes[i]
|
877 |
+
del ma_type.shape[:]
|
878 |
+
|
879 |
+
nms = nms_spec.nonMaximumSuppression
|
880 |
+
nms.confidenceInputFeatureName = out0.name # 1x507x80
|
881 |
+
nms.coordinatesInputFeatureName = out1.name # 1x507x4
|
882 |
+
nms.confidenceOutputFeatureName = 'confidence'
|
883 |
+
nms.coordinatesOutputFeatureName = 'coordinates'
|
884 |
+
nms.iouThresholdInputFeatureName = 'iouThreshold'
|
885 |
+
nms.confidenceThresholdInputFeatureName = 'confidenceThreshold'
|
886 |
+
nms.iouThreshold = 0.45
|
887 |
+
nms.confidenceThreshold = 0.25
|
888 |
+
nms.pickTop.perClass = True
|
889 |
+
nms.stringClassLabels.vector.extend(names.values())
|
890 |
+
nms_model = ct.models.MLModel(nms_spec)
|
891 |
+
|
892 |
+
# 4. Pipeline models together
|
893 |
+
pipeline = ct.models.pipeline.Pipeline(input_features=[('image', ct.models.datatypes.Array(3, ny, nx)),
|
894 |
+
('iouThreshold', ct.models.datatypes.Double()),
|
895 |
+
('confidenceThreshold', ct.models.datatypes.Double())],
|
896 |
+
output_features=['confidence', 'coordinates'])
|
897 |
+
pipeline.add_model(model)
|
898 |
+
pipeline.add_model(nms_model)
|
899 |
+
|
900 |
+
# Correct datatypes
|
901 |
+
pipeline.spec.description.input[0].ParseFromString(model._spec.description.input[0].SerializeToString())
|
902 |
+
pipeline.spec.description.output[0].ParseFromString(nms_model._spec.description.output[0].SerializeToString())
|
903 |
+
pipeline.spec.description.output[1].ParseFromString(nms_model._spec.description.output[1].SerializeToString())
|
904 |
+
|
905 |
+
# Update metadata
|
906 |
+
pipeline.spec.specificationVersion = 5
|
907 |
+
pipeline.spec.description.metadata.userDefined.update({
|
908 |
+
'IoU threshold': str(nms.iouThreshold),
|
909 |
+
'Confidence threshold': str(nms.confidenceThreshold)})
|
910 |
+
|
911 |
+
# Save the model
|
912 |
+
model = ct.models.MLModel(pipeline.spec)
|
913 |
+
model.input_description['image'] = 'Input image'
|
914 |
+
model.input_description['iouThreshold'] = f'(optional) IOU threshold override (default: {nms.iouThreshold})'
|
915 |
+
model.input_description['confidenceThreshold'] = \
|
916 |
+
f'(optional) Confidence threshold override (default: {nms.confidenceThreshold})'
|
917 |
+
model.output_description['confidence'] = 'Boxes × Class confidence (see user-defined metadata "classes")'
|
918 |
+
model.output_description['coordinates'] = 'Boxes × [x, y, width, height] (relative to image size)'
|
919 |
+
LOGGER.info(f'{prefix} pipeline success')
|
920 |
+
return model
|
921 |
+
|
922 |
+
def add_callback(self, event: str, callback):
|
923 |
+
"""
|
924 |
+
Appends the given callback.
|
925 |
+
"""
|
926 |
+
self.callbacks[event].append(callback)
|
927 |
+
|
928 |
+
def run_callbacks(self, event: str):
|
929 |
+
"""Execute all callbacks for a given event."""
|
930 |
+
for callback in self.callbacks.get(event, []):
|
931 |
+
callback(self)
|
932 |
+
|
933 |
+
|
934 |
+
class iOSDetectModel(torch.nn.Module):
|
935 |
+
"""Wrap an Ultralytics YOLO model for iOS export."""
|
936 |
+
|
937 |
+
def __init__(self, model, im):
|
938 |
+
"""Initialize the iOSDetectModel class with a YOLO model and example image."""
|
939 |
+
super().__init__()
|
940 |
+
b, c, h, w = im.shape # batch, channel, height, width
|
941 |
+
self.model = model
|
942 |
+
self.nc = len(model.names) # number of classes
|
943 |
+
if w == h:
|
944 |
+
self.normalize = 1.0 / w # scalar
|
945 |
+
else:
|
946 |
+
self.normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h]) # broadcast (slower, smaller)
|
947 |
+
|
948 |
+
def forward(self, x):
|
949 |
+
"""Normalize predictions of object detection model with input size-dependent factors."""
|
950 |
+
xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
|
951 |
+
return cls, xywh * self.normalize # confidence (3780, 80), coordinates (3780, 4)
|
952 |
+
|
953 |
+
|
954 |
+
def export(cfg=DEFAULT_CFG):
|
955 |
+
"""Export a YOLOv model to a specific format."""
|
956 |
+
cfg.model = cfg.model or 'yolov8n.yaml'
|
957 |
+
cfg.format = cfg.format or 'torchscript'
|
958 |
+
|
959 |
+
from ultralytics import YOLO
|
960 |
+
model = YOLO(cfg.model)
|
961 |
+
model.export(**vars(cfg))
|
962 |
+
|
963 |
+
|
964 |
+
if __name__ == '__main__':
|
965 |
+
"""
|
966 |
+
CLI:
|
967 |
+
yolo mode=export model=yolov8n.yaml format=onnx
|
968 |
+
"""
|
969 |
+
export()
|
ultralytics/engine/model.py
ADDED
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
import inspect
|
4 |
+
import sys
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import Union
|
7 |
+
|
8 |
+
from ultralytics.cfg import get_cfg
|
9 |
+
from ultralytics.engine.exporter import Exporter
|
10 |
+
from ultralytics.hub.utils import HUB_WEB_ROOT
|
11 |
+
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
|
12 |
+
from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
|
13 |
+
is_git_dir, yaml_load)
|
14 |
+
from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
|
15 |
+
from ultralytics.utils.downloads import GITHUB_ASSET_STEMS
|
16 |
+
from ultralytics.utils.torch_utils import smart_inference_mode
|
17 |
+
|
18 |
+
|
19 |
+
class Model:
|
20 |
+
"""
|
21 |
+
A base model class to unify apis for all the models.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
model (str, Path): Path to the model file to load or create.
|
25 |
+
task (Any, optional): Task type for the YOLO model. Defaults to None.
|
26 |
+
|
27 |
+
Attributes:
|
28 |
+
predictor (Any): The predictor object.
|
29 |
+
model (Any): The model object.
|
30 |
+
trainer (Any): The trainer object.
|
31 |
+
task (str): The type of model task.
|
32 |
+
ckpt (Any): The checkpoint object if the model loaded from *.pt file.
|
33 |
+
cfg (str): The model configuration if loaded from *.yaml file.
|
34 |
+
ckpt_path (str): The checkpoint file path.
|
35 |
+
overrides (dict): Overrides for the trainer object.
|
36 |
+
metrics (Any): The data for metrics.
|
37 |
+
|
38 |
+
Methods:
|
39 |
+
__call__(source=None, stream=False, **kwargs):
|
40 |
+
Alias for the predict method.
|
41 |
+
_new(cfg:str, verbose:bool=True) -> None:
|
42 |
+
Initializes a new model and infers the task type from the model definitions.
|
43 |
+
_load(weights:str, task:str='') -> None:
|
44 |
+
Initializes a new model and infers the task type from the model head.
|
45 |
+
_check_is_pytorch_model() -> None:
|
46 |
+
Raises TypeError if the model is not a PyTorch model.
|
47 |
+
reset() -> None:
|
48 |
+
Resets the model modules.
|
49 |
+
info(verbose:bool=False) -> None:
|
50 |
+
Logs the model info.
|
51 |
+
fuse() -> None:
|
52 |
+
Fuses the model for faster inference.
|
53 |
+
predict(source=None, stream=False, **kwargs) -> List[ultralytics.engine.results.Results]:
|
54 |
+
Performs prediction using the YOLO model.
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
list(ultralytics.engine.results.Results): The prediction results.
|
58 |
+
"""
|
59 |
+
|
60 |
+
def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None) -> None:
|
61 |
+
"""
|
62 |
+
Initializes the YOLO model.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
model (Union[str, Path], optional): Path or name of the model to load or create. Defaults to 'yolov8n.pt'.
|
66 |
+
task (Any, optional): Task type for the YOLO model. Defaults to None.
|
67 |
+
"""
|
68 |
+
self.callbacks = callbacks.get_default_callbacks()
|
69 |
+
self.predictor = None # reuse predictor
|
70 |
+
self.model = None # model object
|
71 |
+
self.trainer = None # trainer object
|
72 |
+
self.ckpt = None # if loaded from *.pt
|
73 |
+
self.cfg = None # if loaded from *.yaml
|
74 |
+
self.ckpt_path = None
|
75 |
+
self.overrides = {} # overrides for trainer object
|
76 |
+
self.metrics = None # validation/training metrics
|
77 |
+
self.session = None # HUB session
|
78 |
+
self.task = task # task type
|
79 |
+
model = str(model).strip() # strip spaces
|
80 |
+
|
81 |
+
# Check if Ultralytics HUB model from https://hub.ultralytics.com
|
82 |
+
if self.is_hub_model(model):
|
83 |
+
from ultralytics.hub.session import HUBTrainingSession
|
84 |
+
self.session = HUBTrainingSession(model)
|
85 |
+
model = self.session.model_file
|
86 |
+
|
87 |
+
# Load or create new YOLO model
|
88 |
+
suffix = Path(model).suffix
|
89 |
+
if not suffix and Path(model).stem in GITHUB_ASSET_STEMS:
|
90 |
+
model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt
|
91 |
+
if suffix in ('.yaml', '.yml'):
|
92 |
+
self._new(model, task)
|
93 |
+
else:
|
94 |
+
self._load(model, task)
|
95 |
+
|
96 |
+
def __call__(self, source=None, stream=False, **kwargs):
|
97 |
+
"""Calls the 'predict' function with given arguments to perform object detection."""
|
98 |
+
return self.predict(source, stream, **kwargs)
|
99 |
+
|
100 |
+
@staticmethod
|
101 |
+
def is_hub_model(model):
|
102 |
+
"""Check if the provided model is a HUB model."""
|
103 |
+
return any((
|
104 |
+
model.startswith(f'{HUB_WEB_ROOT}/models/'), # i.e. https://hub.ultralytics.com/models/MODEL_ID
|
105 |
+
[len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID
|
106 |
+
len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\'))) # MODELID
|
107 |
+
|
108 |
+
def _new(self, cfg: str, task=None, model=None, verbose=True):
|
109 |
+
"""
|
110 |
+
Initializes a new model and infers the task type from the model definitions.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
cfg (str): model configuration file
|
114 |
+
task (str | None): model task
|
115 |
+
model (BaseModel): Customized model.
|
116 |
+
verbose (bool): display model info on load
|
117 |
+
"""
|
118 |
+
cfg_dict = yaml_model_load(cfg)
|
119 |
+
self.cfg = cfg
|
120 |
+
self.task = task or guess_model_task(cfg_dict)
|
121 |
+
model = model or self.smart_load('model')
|
122 |
+
self.model = model(cfg_dict, verbose=verbose and RANK == -1) # build model
|
123 |
+
self.overrides['model'] = self.cfg
|
124 |
+
|
125 |
+
# Below added to allow export from yamls
|
126 |
+
args = {**DEFAULT_CFG_DICT, **self.overrides} # combine model and default args, preferring model args
|
127 |
+
self.model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
|
128 |
+
self.model.task = self.task
|
129 |
+
|
130 |
+
def _load(self, weights: str, task=None):
|
131 |
+
"""
|
132 |
+
Initializes a new model and infers the task type from the model head.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
weights (str): model checkpoint to be loaded
|
136 |
+
task (str | None): model task
|
137 |
+
"""
|
138 |
+
suffix = Path(weights).suffix
|
139 |
+
if suffix == '.pt':
|
140 |
+
self.model, self.ckpt = attempt_load_one_weight(weights)
|
141 |
+
self.task = self.model.args['task']
|
142 |
+
self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
|
143 |
+
self.ckpt_path = self.model.pt_path
|
144 |
+
else:
|
145 |
+
weights = check_file(weights)
|
146 |
+
self.model, self.ckpt = weights, None
|
147 |
+
self.task = task or guess_model_task(weights)
|
148 |
+
self.ckpt_path = weights
|
149 |
+
self.overrides['model'] = weights
|
150 |
+
self.overrides['task'] = self.task
|
151 |
+
|
152 |
+
def _check_is_pytorch_model(self):
|
153 |
+
"""
|
154 |
+
Raises TypeError is model is not a PyTorch model
|
155 |
+
"""
|
156 |
+
pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt'
|
157 |
+
pt_module = isinstance(self.model, nn.Module)
|
158 |
+
if not (pt_module or pt_str):
|
159 |
+
raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
|
160 |
+
f'PyTorch models can be used to train, val, predict and export, i.e. '
|
161 |
+
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
|
162 |
+
f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
|
163 |
+
|
164 |
+
@smart_inference_mode()
|
165 |
+
def reset_weights(self):
|
166 |
+
"""
|
167 |
+
Resets the model modules parameters to randomly initialized values, losing all training information.
|
168 |
+
"""
|
169 |
+
self._check_is_pytorch_model()
|
170 |
+
for m in self.model.modules():
|
171 |
+
if hasattr(m, 'reset_parameters'):
|
172 |
+
m.reset_parameters()
|
173 |
+
for p in self.model.parameters():
|
174 |
+
p.requires_grad = True
|
175 |
+
return self
|
176 |
+
|
177 |
+
@smart_inference_mode()
|
178 |
+
def load(self, weights='yolov8n.pt'):
|
179 |
+
"""
|
180 |
+
Transfers parameters with matching names and shapes from 'weights' to model.
|
181 |
+
"""
|
182 |
+
self._check_is_pytorch_model()
|
183 |
+
if isinstance(weights, (str, Path)):
|
184 |
+
weights, self.ckpt = attempt_load_one_weight(weights)
|
185 |
+
self.model.load(weights)
|
186 |
+
return self
|
187 |
+
|
188 |
+
def info(self, detailed=False, verbose=True):
|
189 |
+
"""
|
190 |
+
Logs model info.
|
191 |
+
|
192 |
+
Args:
|
193 |
+
detailed (bool): Show detailed information about model.
|
194 |
+
verbose (bool): Controls verbosity.
|
195 |
+
"""
|
196 |
+
self._check_is_pytorch_model()
|
197 |
+
return self.model.info(detailed=detailed, verbose=verbose)
|
198 |
+
|
199 |
+
def fuse(self):
|
200 |
+
"""Fuse PyTorch Conv2d and BatchNorm2d layers."""
|
201 |
+
self._check_is_pytorch_model()
|
202 |
+
self.model.fuse()
|
203 |
+
|
204 |
+
@smart_inference_mode()
|
205 |
+
def predict(self, source=None, stream=False, predictor=None, **kwargs):
|
206 |
+
"""
|
207 |
+
Perform prediction using the YOLO model.
|
208 |
+
|
209 |
+
Args:
|
210 |
+
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
|
211 |
+
Accepts all source types accepted by the YOLO model.
|
212 |
+
stream (bool): Whether to stream the predictions or not. Defaults to False.
|
213 |
+
predictor (BasePredictor): Customized predictor.
|
214 |
+
**kwargs : Additional keyword arguments passed to the predictor.
|
215 |
+
Check the 'configuration' section in the documentation for all available options.
|
216 |
+
|
217 |
+
Returns:
|
218 |
+
(List[ultralytics.engine.results.Results]): The prediction results.
|
219 |
+
"""
|
220 |
+
if source is None:
|
221 |
+
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
|
222 |
+
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
223 |
+
is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any(
|
224 |
+
x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track'))
|
225 |
+
# Check prompts for SAM/FastSAM
|
226 |
+
prompts = kwargs.pop('prompts', None)
|
227 |
+
overrides = self.overrides.copy()
|
228 |
+
overrides['conf'] = 0.25
|
229 |
+
overrides.update(kwargs) # prefer kwargs
|
230 |
+
overrides['mode'] = kwargs.get('mode', 'predict')
|
231 |
+
assert overrides['mode'] in ['track', 'predict']
|
232 |
+
if not is_cli:
|
233 |
+
overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
|
234 |
+
if not self.predictor:
|
235 |
+
self.task = overrides.get('task') or self.task
|
236 |
+
predictor = predictor or self.smart_load('predictor')
|
237 |
+
self.predictor = predictor(overrides=overrides, _callbacks=self.callbacks)
|
238 |
+
self.predictor.setup_model(model=self.model, verbose=is_cli)
|
239 |
+
else: # only update args if predictor is already setup
|
240 |
+
self.predictor.args = get_cfg(self.predictor.args, overrides)
|
241 |
+
if 'project' in overrides or 'name' in overrides:
|
242 |
+
self.predictor.save_dir = self.predictor.get_save_dir()
|
243 |
+
# Set prompts for SAM/FastSAM
|
244 |
+
if len and hasattr(self.predictor, 'set_prompts'):
|
245 |
+
self.predictor.set_prompts(prompts)
|
246 |
+
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
|
247 |
+
|
248 |
+
def track(self, source=None, stream=False, persist=False, **kwargs):
|
249 |
+
"""
|
250 |
+
Perform object tracking on the input source using the registered trackers.
|
251 |
+
|
252 |
+
Args:
|
253 |
+
source (str, optional): The input source for object tracking. Can be a file path or a video stream.
|
254 |
+
stream (bool, optional): Whether the input source is a video stream. Defaults to False.
|
255 |
+
persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
|
256 |
+
**kwargs (optional): Additional keyword arguments for the tracking process.
|
257 |
+
|
258 |
+
Returns:
|
259 |
+
(List[ultralytics.engine.results.Results]): The tracking results.
|
260 |
+
|
261 |
+
"""
|
262 |
+
if not hasattr(self.predictor, 'trackers'):
|
263 |
+
from ultralytics.trackers import register_tracker
|
264 |
+
register_tracker(self, persist)
|
265 |
+
# ByteTrack-based method needs low confidence predictions as input
|
266 |
+
conf = kwargs.get('conf') or 0.1
|
267 |
+
kwargs['conf'] = conf
|
268 |
+
kwargs['mode'] = 'track'
|
269 |
+
return self.predict(source=source, stream=stream, **kwargs)
|
270 |
+
|
271 |
+
@smart_inference_mode()
|
272 |
+
def val(self, data=None, validator=None, **kwargs):
|
273 |
+
"""
|
274 |
+
Validate a model on a given dataset.
|
275 |
+
|
276 |
+
Args:
|
277 |
+
data (str): The dataset to validate on. Accepts all formats accepted by yolo
|
278 |
+
validator (BaseValidator): Customized validator.
|
279 |
+
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
|
280 |
+
"""
|
281 |
+
overrides = self.overrides.copy()
|
282 |
+
overrides['rect'] = True # rect batches as default
|
283 |
+
overrides.update(kwargs)
|
284 |
+
overrides['mode'] = 'val'
|
285 |
+
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
286 |
+
args.data = data or args.data
|
287 |
+
if 'task' in overrides:
|
288 |
+
self.task = args.task
|
289 |
+
else:
|
290 |
+
args.task = self.task
|
291 |
+
validator = validator or self.smart_load('validator')
|
292 |
+
if args.imgsz == DEFAULT_CFG.imgsz and not isinstance(self.model, (str, Path)):
|
293 |
+
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
294 |
+
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
|
295 |
+
|
296 |
+
validator = validator(args=args, _callbacks=self.callbacks)
|
297 |
+
validator(model=self.model)
|
298 |
+
self.metrics = validator.metrics
|
299 |
+
|
300 |
+
return validator.metrics
|
301 |
+
|
302 |
+
@smart_inference_mode()
|
303 |
+
def benchmark(self, **kwargs):
|
304 |
+
"""
|
305 |
+
Benchmark a model on all export formats.
|
306 |
+
|
307 |
+
Args:
|
308 |
+
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
|
309 |
+
"""
|
310 |
+
self._check_is_pytorch_model()
|
311 |
+
from ultralytics.utils.benchmarks import benchmark
|
312 |
+
overrides = self.model.args.copy()
|
313 |
+
overrides.update(kwargs)
|
314 |
+
overrides['mode'] = 'benchmark'
|
315 |
+
overrides = {**DEFAULT_CFG_DICT, **overrides} # fill in missing overrides keys with defaults
|
316 |
+
return benchmark(
|
317 |
+
model=self,
|
318 |
+
data=kwargs.get('data'), # if no 'data' argument passed set data=None for default datasets
|
319 |
+
imgsz=overrides['imgsz'],
|
320 |
+
half=overrides['half'],
|
321 |
+
int8=overrides['int8'],
|
322 |
+
device=overrides['device'],
|
323 |
+
verbose=overrides['verbose'])
|
324 |
+
|
325 |
+
def export(self, **kwargs):
|
326 |
+
"""
|
327 |
+
Export model.
|
328 |
+
|
329 |
+
Args:
|
330 |
+
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
|
331 |
+
"""
|
332 |
+
self._check_is_pytorch_model()
|
333 |
+
overrides = self.overrides.copy()
|
334 |
+
overrides.update(kwargs)
|
335 |
+
overrides['mode'] = 'export'
|
336 |
+
if overrides.get('imgsz') is None:
|
337 |
+
overrides['imgsz'] = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
338 |
+
if 'batch' not in kwargs:
|
339 |
+
overrides['batch'] = 1 # default to 1 if not modified
|
340 |
+
if 'data' not in kwargs:
|
341 |
+
overrides['data'] = None # default to None if not modified (avoid int8 calibration with coco.yaml)
|
342 |
+
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
343 |
+
args.task = self.task
|
344 |
+
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
|
345 |
+
|
346 |
+
def train(self, trainer=None, **kwargs):
|
347 |
+
"""
|
348 |
+
Trains the model on a given dataset.
|
349 |
+
|
350 |
+
Args:
|
351 |
+
trainer (BaseTrainer, optional): Customized trainer.
|
352 |
+
**kwargs (Any): Any number of arguments representing the training configuration.
|
353 |
+
"""
|
354 |
+
self._check_is_pytorch_model()
|
355 |
+
if self.session: # Ultralytics HUB session
|
356 |
+
if any(kwargs):
|
357 |
+
LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
|
358 |
+
kwargs = self.session.train_args
|
359 |
+
check_pip_update_available()
|
360 |
+
overrides = self.overrides.copy()
|
361 |
+
if kwargs.get('cfg'):
|
362 |
+
LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
|
363 |
+
overrides = yaml_load(check_yaml(kwargs['cfg']))
|
364 |
+
overrides.update(kwargs)
|
365 |
+
overrides['mode'] = 'train'
|
366 |
+
if not overrides.get('data'):
|
367 |
+
raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
|
368 |
+
if overrides.get('resume'):
|
369 |
+
overrides['resume'] = self.ckpt_path
|
370 |
+
self.task = overrides.get('task') or self.task
|
371 |
+
trainer = trainer or self.smart_load('trainer')
|
372 |
+
self.trainer = trainer(overrides=overrides, _callbacks=self.callbacks)
|
373 |
+
if not overrides.get('resume'): # manually set model only if not resuming
|
374 |
+
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
|
375 |
+
self.model = self.trainer.model
|
376 |
+
self.trainer.hub_session = self.session # attach optional HUB session
|
377 |
+
self.trainer.train()
|
378 |
+
# Update model and cfg after training
|
379 |
+
if RANK in (-1, 0):
|
380 |
+
self.model, _ = attempt_load_one_weight(str(self.trainer.best))
|
381 |
+
self.overrides = self.model.args
|
382 |
+
self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
|
383 |
+
|
384 |
+
def to(self, device):
|
385 |
+
"""
|
386 |
+
Sends the model to the given device.
|
387 |
+
|
388 |
+
Args:
|
389 |
+
device (str): device
|
390 |
+
"""
|
391 |
+
self._check_is_pytorch_model()
|
392 |
+
self.model.to(device)
|
393 |
+
|
394 |
+
def tune(self, *args, **kwargs):
|
395 |
+
"""
|
396 |
+
Runs hyperparameter tuning using Ray Tune. See ultralytics.utils.tuner.run_ray_tune for Args.
|
397 |
+
|
398 |
+
Returns:
|
399 |
+
(dict): A dictionary containing the results of the hyperparameter search.
|
400 |
+
|
401 |
+
Raises:
|
402 |
+
ModuleNotFoundError: If Ray Tune is not installed.
|
403 |
+
"""
|
404 |
+
self._check_is_pytorch_model()
|
405 |
+
from ultralytics.utils.tuner import run_ray_tune
|
406 |
+
return run_ray_tune(self, *args, **kwargs)
|
407 |
+
|
408 |
+
@property
|
409 |
+
def names(self):
|
410 |
+
"""Returns class names of the loaded model."""
|
411 |
+
return self.model.names if hasattr(self.model, 'names') else None
|
412 |
+
|
413 |
+
@property
|
414 |
+
def device(self):
|
415 |
+
"""Returns device if PyTorch model."""
|
416 |
+
return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None
|
417 |
+
|
418 |
+
@property
|
419 |
+
def transforms(self):
|
420 |
+
"""Returns transform of the loaded model."""
|
421 |
+
return self.model.transforms if hasattr(self.model, 'transforms') else None
|
422 |
+
|
423 |
+
def add_callback(self, event: str, func):
|
424 |
+
"""Add a callback."""
|
425 |
+
self.callbacks[event].append(func)
|
426 |
+
|
427 |
+
def clear_callback(self, event: str):
|
428 |
+
"""Clear all event callbacks."""
|
429 |
+
self.callbacks[event] = []
|
430 |
+
|
431 |
+
@staticmethod
|
432 |
+
def _reset_ckpt_args(args):
|
433 |
+
"""Reset arguments when loading a PyTorch model."""
|
434 |
+
include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
|
435 |
+
return {k: v for k, v in args.items() if k in include}
|
436 |
+
|
437 |
+
def _reset_callbacks(self):
|
438 |
+
"""Reset all registered callbacks."""
|
439 |
+
for event in callbacks.default_callbacks.keys():
|
440 |
+
self.callbacks[event] = [callbacks.default_callbacks[event][0]]
|
441 |
+
|
442 |
+
def __getattr__(self, attr):
|
443 |
+
"""Raises error if object has no requested attribute."""
|
444 |
+
name = self.__class__.__name__
|
445 |
+
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
446 |
+
|
447 |
+
def smart_load(self, key):
|
448 |
+
"""Load model/trainer/validator/predictor."""
|
449 |
+
try:
|
450 |
+
return self.task_map[self.task][key]
|
451 |
+
except Exception:
|
452 |
+
name = self.__class__.__name__
|
453 |
+
mode = inspect.stack()[1][3] # get the function name.
|
454 |
+
raise NotImplementedError(
|
455 |
+
f'WARNING ⚠️ `{name}` model does not support `{mode}` mode for `{self.task}` task yet.')
|
456 |
+
|
457 |
+
@property
|
458 |
+
def task_map(self):
|
459 |
+
"""
|
460 |
+
Map head to model, trainer, validator, and predictor classes.
|
461 |
+
|
462 |
+
Returns:
|
463 |
+
task_map (dict): The map of model task to mode classes.
|
464 |
+
"""
|
465 |
+
raise NotImplementedError('Please provide task map for your model!')
|
ultralytics/engine/predictor.py
ADDED
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
"""
|
3 |
+
Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc.
|
4 |
+
|
5 |
+
Usage - sources:
|
6 |
+
$ yolo mode=predict model=yolov8n.pt source=0 # webcam
|
7 |
+
img.jpg # image
|
8 |
+
vid.mp4 # video
|
9 |
+
screen # screenshot
|
10 |
+
path/ # directory
|
11 |
+
list.txt # list of images
|
12 |
+
list.streams # list of streams
|
13 |
+
'path/*.jpg' # glob
|
14 |
+
'https://youtu.be/Zgi9g1ksQHc' # YouTube
|
15 |
+
'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream
|
16 |
+
|
17 |
+
Usage - formats:
|
18 |
+
$ yolo mode=predict model=yolov8n.pt # PyTorch
|
19 |
+
yolov8n.torchscript # TorchScript
|
20 |
+
yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
|
21 |
+
yolov8n_openvino_model # OpenVINO
|
22 |
+
yolov8n.engine # TensorRT
|
23 |
+
yolov8n.mlmodel # CoreML (macOS-only)
|
24 |
+
yolov8n_saved_model # TensorFlow SavedModel
|
25 |
+
yolov8n.pb # TensorFlow GraphDef
|
26 |
+
yolov8n.tflite # TensorFlow Lite
|
27 |
+
yolov8n_edgetpu.tflite # TensorFlow Edge TPU
|
28 |
+
yolov8n_paddle_model # PaddlePaddle
|
29 |
+
"""
|
30 |
+
import platform
|
31 |
+
from pathlib import Path
|
32 |
+
|
33 |
+
import cv2
|
34 |
+
import numpy as np
|
35 |
+
import torch
|
36 |
+
|
37 |
+
from ultralytics.cfg import get_cfg
|
38 |
+
from ultralytics.data import load_inference_source
|
39 |
+
from ultralytics.data.augment import LetterBox, classify_transforms
|
40 |
+
from ultralytics.nn.autobackend import AutoBackend
|
41 |
+
from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, SETTINGS, WINDOWS, callbacks, colorstr, ops
|
42 |
+
from ultralytics.utils.checks import check_imgsz, check_imshow
|
43 |
+
from ultralytics.utils.files import increment_path
|
44 |
+
from ultralytics.utils.torch_utils import select_device, smart_inference_mode
|
45 |
+
|
46 |
+
STREAM_WARNING = """
|
47 |
+
WARNING ⚠️ stream/video/webcam/dir predict source will accumulate results in RAM unless `stream=True` is passed,
|
48 |
+
causing potential out-of-memory errors for large sources or long-running streams/videos.
|
49 |
+
|
50 |
+
Usage:
|
51 |
+
results = model(source=..., stream=True) # generator of Results objects
|
52 |
+
for r in results:
|
53 |
+
boxes = r.boxes # Boxes object for bbox outputs
|
54 |
+
masks = r.masks # Masks object for segment masks outputs
|
55 |
+
probs = r.probs # Class probabilities for classification outputs
|
56 |
+
"""
|
57 |
+
|
58 |
+
inference_Time=0
|
59 |
+
class BasePredictor:
|
60 |
+
"""
|
61 |
+
BasePredictor
|
62 |
+
|
63 |
+
A base class for creating predictors.
|
64 |
+
|
65 |
+
Attributes:
|
66 |
+
args (SimpleNamespace): Configuration for the predictor.
|
67 |
+
save_dir (Path): Directory to save results.
|
68 |
+
done_warmup (bool): Whether the predictor has finished setup.
|
69 |
+
model (nn.Module): Model used for prediction.
|
70 |
+
data (dict): Data configuration.
|
71 |
+
device (torch.device): Device used for prediction.
|
72 |
+
dataset (Dataset): Dataset used for prediction.
|
73 |
+
vid_path (str): Path to video file.
|
74 |
+
vid_writer (cv2.VideoWriter): Video writer for saving video output.
|
75 |
+
data_path (str): Path to data.
|
76 |
+
"""
|
77 |
+
|
78 |
+
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
79 |
+
"""
|
80 |
+
Initializes the BasePredictor class.
|
81 |
+
|
82 |
+
Args:
|
83 |
+
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
|
84 |
+
overrides (dict, optional): Configuration overrides. Defaults to None.
|
85 |
+
"""
|
86 |
+
self.args = get_cfg(cfg, overrides)
|
87 |
+
self.save_dir = self.get_save_dir()
|
88 |
+
if self.args.conf is None:
|
89 |
+
self.args.conf = 0.25 # default conf=0.25
|
90 |
+
self.done_warmup = False
|
91 |
+
if self.args.show:
|
92 |
+
self.args.show = check_imshow(warn=True)
|
93 |
+
|
94 |
+
# Usable if setup is done
|
95 |
+
self.model = None
|
96 |
+
self.data = self.args.data # data_dict
|
97 |
+
self.imgsz = None
|
98 |
+
self.device = None
|
99 |
+
self.dataset = None
|
100 |
+
self.vid_path, self.vid_writer = None, None
|
101 |
+
self.plotted_img = None
|
102 |
+
self.data_path = None
|
103 |
+
self.source_type = None
|
104 |
+
self.batch = None
|
105 |
+
self.results = None
|
106 |
+
self.transforms = None
|
107 |
+
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
108 |
+
self.txt_path = None
|
109 |
+
callbacks.add_integration_callbacks(self)
|
110 |
+
|
111 |
+
def get_save_dir(self):
|
112 |
+
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
113 |
+
name = self.args.name or f'{self.args.mode}'
|
114 |
+
return increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
|
115 |
+
|
116 |
+
def preprocess(self, im):
|
117 |
+
"""Prepares input image before inference.
|
118 |
+
|
119 |
+
Args:
|
120 |
+
im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
|
121 |
+
"""
|
122 |
+
not_tensor = not isinstance(im, torch.Tensor)
|
123 |
+
if not_tensor:
|
124 |
+
im = np.stack(self.pre_transform(im))
|
125 |
+
im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
|
126 |
+
im = np.ascontiguousarray(im) # contiguous
|
127 |
+
im = torch.from_numpy(im)
|
128 |
+
|
129 |
+
img = im.to(self.device)
|
130 |
+
img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
|
131 |
+
if not_tensor:
|
132 |
+
img /= 255 # 0 - 255 to 0.0 - 1.0
|
133 |
+
return img
|
134 |
+
|
135 |
+
def inference(self, im, *args, **kwargs):
|
136 |
+
visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
|
137 |
+
mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
|
138 |
+
return self.model(im, augment=self.args.augment, visualize=visualize)
|
139 |
+
|
140 |
+
def pre_transform(self, im):
|
141 |
+
"""Pre-transform input image before inference.
|
142 |
+
|
143 |
+
Args:
|
144 |
+
im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
|
145 |
+
|
146 |
+
Return: A list of transformed imgs.
|
147 |
+
"""
|
148 |
+
same_shapes = all(x.shape == im[0].shape for x in im)
|
149 |
+
auto = same_shapes and self.model.pt
|
150 |
+
return [LetterBox(self.imgsz, auto=auto, stride=self.model.stride)(image=x) for x in im]
|
151 |
+
|
152 |
+
def write_results(self, idx, results, batch):
|
153 |
+
"""Write inference results to a file or directory."""
|
154 |
+
p, im, _ = batch
|
155 |
+
log_string = ''
|
156 |
+
if len(im.shape) == 3:
|
157 |
+
im = im[None] # expand for batch dim
|
158 |
+
if self.source_type.webcam or self.source_type.from_img or self.source_type.tensor: # batch_size >= 1
|
159 |
+
log_string += f'{idx}: '
|
160 |
+
frame = self.dataset.count
|
161 |
+
else:
|
162 |
+
frame = getattr(self.dataset, 'frame', 0)
|
163 |
+
self.data_path = p
|
164 |
+
self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
|
165 |
+
log_string += '%gx%g ' % im.shape[2:] # print string
|
166 |
+
result = results[idx]
|
167 |
+
log_string += result.verbose()
|
168 |
+
|
169 |
+
if self.args.save or self.args.show: # Add bbox to image
|
170 |
+
plot_args = {
|
171 |
+
'line_width': self.args.line_width,
|
172 |
+
'boxes': self.args.boxes,
|
173 |
+
'conf': self.args.show_conf,
|
174 |
+
'labels': self.args.show_labels}
|
175 |
+
if not self.args.retina_masks:
|
176 |
+
plot_args['im_gpu'] = im[idx]
|
177 |
+
self.plotted_img = result.plot(**plot_args)
|
178 |
+
# Write
|
179 |
+
if self.args.save_txt:
|
180 |
+
result.save_txt(f'{self.txt_path}.txt', save_conf=self.args.save_conf)
|
181 |
+
if self.args.save_crop:
|
182 |
+
result.save_crop(save_dir=self.save_dir / 'crops',
|
183 |
+
file_name=self.data_path.stem + ('' if self.dataset.mode == 'image' else f'_{frame}'))
|
184 |
+
|
185 |
+
return log_string
|
186 |
+
|
187 |
+
def postprocess(self, preds, img, orig_imgs):
|
188 |
+
"""Post-processes predictions for an image and returns them."""
|
189 |
+
return preds
|
190 |
+
|
191 |
+
def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
|
192 |
+
"""Performs inference on an image or stream."""
|
193 |
+
self.stream = stream
|
194 |
+
if stream:
|
195 |
+
return self.stream_inference(source, model, *args, **kwargs)
|
196 |
+
else:
|
197 |
+
return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one
|
198 |
+
|
199 |
+
def predict_cli(self, source=None, model=None):
|
200 |
+
"""Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode."""
|
201 |
+
gen = self.stream_inference(source, model)
|
202 |
+
for _ in gen: # running CLI inference without accumulating any outputs (do not modify)
|
203 |
+
pass
|
204 |
+
|
205 |
+
def setup_source(self, source):
|
206 |
+
"""Sets up source and inference mode."""
|
207 |
+
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
|
208 |
+
self.transforms = getattr(self.model.model, 'transforms', classify_transforms(
|
209 |
+
self.imgsz[0])) if self.args.task == 'classify' else None
|
210 |
+
self.dataset = load_inference_source(source=source, imgsz=self.imgsz, vid_stride=self.args.vid_stride)
|
211 |
+
self.source_type = self.dataset.source_type
|
212 |
+
if not getattr(self, 'stream', True) and (self.dataset.mode == 'stream' or # streams
|
213 |
+
len(self.dataset) > 1000 or # images
|
214 |
+
any(getattr(self.dataset, 'video_flag', [False]))): # videos
|
215 |
+
LOGGER.warning(STREAM_WARNING)
|
216 |
+
self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs
|
217 |
+
|
218 |
+
@smart_inference_mode()
|
219 |
+
def stream_inference(self, source=None, model=None, *args, **kwargs):
|
220 |
+
"""Streams real-time inference on camera feed and saves results to file."""
|
221 |
+
if self.args.verbose:
|
222 |
+
LOGGER.info('')
|
223 |
+
|
224 |
+
# Setup model
|
225 |
+
if not self.model:
|
226 |
+
self.setup_model(model)
|
227 |
+
|
228 |
+
# Setup source every time predict is called
|
229 |
+
self.setup_source(source if source is not None else self.args.source)
|
230 |
+
|
231 |
+
# Check if save_dir/ label file exists
|
232 |
+
if self.args.save or self.args.save_txt:
|
233 |
+
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
234 |
+
|
235 |
+
# Warmup model
|
236 |
+
if not self.done_warmup:
|
237 |
+
self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
|
238 |
+
self.done_warmup = True
|
239 |
+
|
240 |
+
self.seen, self.windows, self.batch, profilers = 0, [], None, (ops.Profile(), ops.Profile(), ops.Profile())
|
241 |
+
self.run_callbacks('on_predict_start')
|
242 |
+
for batch in self.dataset:
|
243 |
+
self.run_callbacks('on_predict_batch_start')
|
244 |
+
self.batch = batch
|
245 |
+
path, im0s, vid_cap, s = batch
|
246 |
+
|
247 |
+
# Preprocess
|
248 |
+
with profilers[0]:
|
249 |
+
im = self.preprocess(im0s)
|
250 |
+
|
251 |
+
# Inference
|
252 |
+
with profilers[1]:
|
253 |
+
preds = self.inference(im, *args, **kwargs)
|
254 |
+
|
255 |
+
# Postprocess
|
256 |
+
with profilers[2]:
|
257 |
+
self.results = self.postprocess(preds, im, im0s)
|
258 |
+
self.run_callbacks('on_predict_postprocess_end')
|
259 |
+
|
260 |
+
# Visualize, save, write results
|
261 |
+
n = len(im0s)
|
262 |
+
for i in range(n):
|
263 |
+
self.seen += 1
|
264 |
+
self.results[i].speed = {
|
265 |
+
'preprocess': profilers[0].dt * 1E3 / n,
|
266 |
+
'inference': profilers[1].dt * 1E3 / n,
|
267 |
+
'postprocess': profilers[2].dt * 1E3 / n}
|
268 |
+
p, im0 = path[i], None if self.source_type.tensor else im0s[i].copy()
|
269 |
+
p = Path(p)
|
270 |
+
|
271 |
+
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
|
272 |
+
s += self.write_results(i, self.results, (p, im, im0))
|
273 |
+
if self.args.save or self.args.save_txt:
|
274 |
+
self.results[i].save_dir = self.save_dir.__str__()
|
275 |
+
if self.args.show and self.plotted_img is not None:
|
276 |
+
self.show(p)
|
277 |
+
if self.args.save and self.plotted_img is not None:
|
278 |
+
self.save_preds(vid_cap, i, str(self.save_dir / p.name))
|
279 |
+
|
280 |
+
self.run_callbacks('on_predict_batch_end')
|
281 |
+
yield from self.results
|
282 |
+
|
283 |
+
# Print time (inference-only)
|
284 |
+
if self.args.verbose:
|
285 |
+
LOGGER.info(f'{s}{profilers[1].dt * 1E3:.1f}ms')
|
286 |
+
|
287 |
+
# Release assets
|
288 |
+
if isinstance(self.vid_writer[-1], cv2.VideoWriter):
|
289 |
+
self.vid_writer[-1].release() # release final video writer
|
290 |
+
|
291 |
+
# Print results
|
292 |
+
if self.args.verbose and self.seen:
|
293 |
+
t = tuple(x.t / self.seen * 1E3 for x in profilers) # speeds per image
|
294 |
+
LOGGER.info(f'Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape '
|
295 |
+
f'{(1, 3, *im.shape[2:])}' % t)
|
296 |
+
if self.args.save or self.args.save_txt or self.args.save_crop:
|
297 |
+
nl = len(list(self.save_dir.glob('labels/*.txt'))) # number of labels
|
298 |
+
s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
|
299 |
+
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
|
300 |
+
|
301 |
+
self.run_callbacks('on_predict_end')
|
302 |
+
|
303 |
+
def setup_model(self, model, verbose=True):
|
304 |
+
"""Initialize YOLO model with given parameters and set it to evaluation mode."""
|
305 |
+
self.model = AutoBackend(model or self.args.model,
|
306 |
+
device=select_device(self.args.device, verbose=verbose),
|
307 |
+
dnn=self.args.dnn,
|
308 |
+
data=self.args.data,
|
309 |
+
fp16=self.args.half,
|
310 |
+
fuse=True,
|
311 |
+
verbose=verbose)
|
312 |
+
|
313 |
+
self.device = self.model.device # update device
|
314 |
+
self.args.half = self.model.fp16 # update half
|
315 |
+
self.model.eval()
|
316 |
+
|
317 |
+
def show(self, p):
|
318 |
+
"""Display an image in a window using OpenCV imshow()."""
|
319 |
+
im0 = self.plotted_img
|
320 |
+
if platform.system() == 'Linux' and p not in self.windows:
|
321 |
+
self.windows.append(p)
|
322 |
+
cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
|
323 |
+
cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
|
324 |
+
cv2.imshow(str(p), im0)
|
325 |
+
cv2.waitKey(500 if self.batch[3].startswith('image') else 1) # 1 millisecond
|
326 |
+
|
327 |
+
def save_preds(self, vid_cap, idx, save_path):
|
328 |
+
"""Save video predictions as mp4 at specified path."""
|
329 |
+
im0 = self.plotted_img
|
330 |
+
# Save imgs
|
331 |
+
if self.dataset.mode == 'image':
|
332 |
+
cv2.imwrite(save_path, im0)
|
333 |
+
else: # 'video' or 'stream'
|
334 |
+
if self.vid_path[idx] != save_path: # new video
|
335 |
+
self.vid_path[idx] = save_path
|
336 |
+
if isinstance(self.vid_writer[idx], cv2.VideoWriter):
|
337 |
+
self.vid_writer[idx].release() # release previous video writer
|
338 |
+
if vid_cap: # video
|
339 |
+
fps = int(vid_cap.get(cv2.CAP_PROP_FPS)) # integer required, floats produce error in MP4 codec
|
340 |
+
w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
341 |
+
h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
342 |
+
else: # stream
|
343 |
+
fps, w, h = 30, im0.shape[1], im0.shape[0]
|
344 |
+
suffix = '.mp4' if MACOS else '.avi' if WINDOWS else '.avi'
|
345 |
+
fourcc = 'avc1' if MACOS else 'WMV2' if WINDOWS else 'MJPG'
|
346 |
+
save_path = str(Path(save_path).with_suffix(suffix))
|
347 |
+
self.vid_writer[idx] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, h))
|
348 |
+
self.vid_writer[idx].write(im0)
|
349 |
+
|
350 |
+
def run_callbacks(self, event: str):
|
351 |
+
"""Runs all registered callbacks for a specific event."""
|
352 |
+
for callback in self.callbacks.get(event, []):
|
353 |
+
callback(self)
|
354 |
+
|
355 |
+
def add_callback(self, event: str, func):
|
356 |
+
"""
|
357 |
+
Add callback
|
358 |
+
"""
|
359 |
+
self.callbacks[event].append(func)
|
ultralytics/engine/results.py
ADDED
@@ -0,0 +1,604 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
"""
|
3 |
+
Ultralytics Results, Boxes and Masks classes for handling inference results
|
4 |
+
|
5 |
+
Usage: See https://docs.ultralytics.com/modes/predict/
|
6 |
+
"""
|
7 |
+
|
8 |
+
from copy import deepcopy
|
9 |
+
from functools import lru_cache
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
|
15 |
+
from ultralytics.data.augment import LetterBox
|
16 |
+
from ultralytics.utils import LOGGER, SimpleClass, deprecation_warn, ops
|
17 |
+
from ultralytics.utils.plotting import Annotator, colors, save_one_box
|
18 |
+
|
19 |
+
|
20 |
+
class BaseTensor(SimpleClass):
|
21 |
+
"""
|
22 |
+
Base tensor class with additional methods for easy manipulation and device handling.
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(self, data, orig_shape) -> None:
|
26 |
+
"""Initialize BaseTensor with data and original shape.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
data (torch.Tensor | np.ndarray): Predictions, such as bboxes, masks and keypoints.
|
30 |
+
orig_shape (tuple): Original shape of image.
|
31 |
+
"""
|
32 |
+
assert isinstance(data, (torch.Tensor, np.ndarray))
|
33 |
+
self.data = data
|
34 |
+
self.orig_shape = orig_shape
|
35 |
+
|
36 |
+
@property
|
37 |
+
def shape(self):
|
38 |
+
"""Return the shape of the data tensor."""
|
39 |
+
return self.data.shape
|
40 |
+
|
41 |
+
def cpu(self):
|
42 |
+
"""Return a copy of the tensor on CPU memory."""
|
43 |
+
return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.cpu(), self.orig_shape)
|
44 |
+
|
45 |
+
def numpy(self):
|
46 |
+
"""Return a copy of the tensor as a numpy array."""
|
47 |
+
return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.numpy(), self.orig_shape)
|
48 |
+
|
49 |
+
def cuda(self):
|
50 |
+
"""Return a copy of the tensor on GPU memory."""
|
51 |
+
return self.__class__(torch.as_tensor(self.data).cuda(), self.orig_shape)
|
52 |
+
|
53 |
+
def to(self, *args, **kwargs):
|
54 |
+
"""Return a copy of the tensor with the specified device and dtype."""
|
55 |
+
return self.__class__(torch.as_tensor(self.data).to(*args, **kwargs), self.orig_shape)
|
56 |
+
|
57 |
+
def __len__(self): # override len(results)
|
58 |
+
"""Return the length of the data tensor."""
|
59 |
+
return len(self.data)
|
60 |
+
|
61 |
+
def __getitem__(self, idx):
|
62 |
+
"""Return a BaseTensor with the specified index of the data tensor."""
|
63 |
+
return self.__class__(self.data[idx], self.orig_shape)
|
64 |
+
|
65 |
+
|
66 |
+
class Results(SimpleClass):
|
67 |
+
"""
|
68 |
+
A class for storing and manipulating inference results.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
orig_img (numpy.ndarray): The original image as a numpy array.
|
72 |
+
path (str): The path to the image file.
|
73 |
+
names (dict): A dictionary of class names.
|
74 |
+
boxes (torch.tensor, optional): A 2D tensor of bounding box coordinates for each detection.
|
75 |
+
masks (torch.tensor, optional): A 3D tensor of detection masks, where each mask is a binary image.
|
76 |
+
probs (torch.tensor, optional): A 1D tensor of probabilities of each class for classification task.
|
77 |
+
keypoints (List[List[float]], optional): A list of detected keypoints for each object.
|
78 |
+
|
79 |
+
Attributes:
|
80 |
+
orig_img (numpy.ndarray): The original image as a numpy array.
|
81 |
+
orig_shape (tuple): The original image shape in (height, width) format.
|
82 |
+
boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
|
83 |
+
masks (Masks, optional): A Masks object containing the detection masks.
|
84 |
+
probs (Probs, optional): A Probs object containing probabilities of each class for classification task.
|
85 |
+
keypoints (Keypoints, optional): A Keypoints object containing detected keypoints for each object.
|
86 |
+
speed (dict): A dictionary of preprocess, inference, and postprocess speeds in milliseconds per image.
|
87 |
+
names (dict): A dictionary of class names.
|
88 |
+
path (str): The path to the image file.
|
89 |
+
_keys (tuple): A tuple of attribute names for non-empty attributes.
|
90 |
+
"""
|
91 |
+
|
92 |
+
def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None, keypoints=None) -> None:
|
93 |
+
"""Initialize the Results class."""
|
94 |
+
self.orig_img = orig_img
|
95 |
+
self.orig_shape = orig_img.shape[:2]
|
96 |
+
self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None # native size boxes
|
97 |
+
self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks
|
98 |
+
self.probs = Probs(probs) if probs is not None else None
|
99 |
+
self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None
|
100 |
+
self.speed = {'preprocess': None, 'inference': None, 'postprocess': None} # milliseconds per image
|
101 |
+
self.names = names
|
102 |
+
self.path = path
|
103 |
+
self.save_dir = None
|
104 |
+
self._keys = ('boxes', 'masks', 'probs', 'keypoints')
|
105 |
+
|
106 |
+
def __getitem__(self, idx):
|
107 |
+
"""Return a Results object for the specified index."""
|
108 |
+
r = self.new()
|
109 |
+
for k in self.keys:
|
110 |
+
setattr(r, k, getattr(self, k)[idx])
|
111 |
+
return r
|
112 |
+
|
113 |
+
def __len__(self):
|
114 |
+
"""Return the number of detections in the Results object."""
|
115 |
+
for k in self.keys:
|
116 |
+
return len(getattr(self, k))
|
117 |
+
|
118 |
+
def update(self, boxes=None, masks=None, probs=None):
|
119 |
+
"""Update the boxes, masks, and probs attributes of the Results object."""
|
120 |
+
if boxes is not None:
|
121 |
+
ops.clip_boxes(boxes, self.orig_shape) # clip boxes
|
122 |
+
self.boxes = Boxes(boxes, self.orig_shape)
|
123 |
+
if masks is not None:
|
124 |
+
self.masks = Masks(masks, self.orig_shape)
|
125 |
+
if probs is not None:
|
126 |
+
self.probs = probs
|
127 |
+
|
128 |
+
def cpu(self):
|
129 |
+
"""Return a copy of the Results object with all tensors on CPU memory."""
|
130 |
+
r = self.new()
|
131 |
+
for k in self.keys:
|
132 |
+
setattr(r, k, getattr(self, k).cpu())
|
133 |
+
return r
|
134 |
+
|
135 |
+
def numpy(self):
|
136 |
+
"""Return a copy of the Results object with all tensors as numpy arrays."""
|
137 |
+
r = self.new()
|
138 |
+
for k in self.keys:
|
139 |
+
setattr(r, k, getattr(self, k).numpy())
|
140 |
+
return r
|
141 |
+
|
142 |
+
def cuda(self):
|
143 |
+
"""Return a copy of the Results object with all tensors on GPU memory."""
|
144 |
+
r = self.new()
|
145 |
+
for k in self.keys:
|
146 |
+
setattr(r, k, getattr(self, k).cuda())
|
147 |
+
return r
|
148 |
+
|
149 |
+
def to(self, *args, **kwargs):
|
150 |
+
"""Return a copy of the Results object with tensors on the specified device and dtype."""
|
151 |
+
r = self.new()
|
152 |
+
for k in self.keys:
|
153 |
+
setattr(r, k, getattr(self, k).to(*args, **kwargs))
|
154 |
+
return r
|
155 |
+
|
156 |
+
def new(self):
|
157 |
+
"""Return a new Results object with the same image, path, and names."""
|
158 |
+
return Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
159 |
+
|
160 |
+
@property
|
161 |
+
def keys(self):
|
162 |
+
"""Return a list of non-empty attribute names."""
|
163 |
+
return [k for k in self._keys if getattr(self, k) is not None]
|
164 |
+
|
165 |
+
def plot(
|
166 |
+
self,
|
167 |
+
conf=True,
|
168 |
+
line_width=None,
|
169 |
+
font_size=None,
|
170 |
+
font='Arial.ttf',
|
171 |
+
pil=False,
|
172 |
+
img=None,
|
173 |
+
im_gpu=None,
|
174 |
+
kpt_radius=5,
|
175 |
+
kpt_line=True,
|
176 |
+
labels=True,
|
177 |
+
boxes=True,
|
178 |
+
masks=True,
|
179 |
+
probs=True,
|
180 |
+
**kwargs # deprecated args TODO: remove support in 8.2
|
181 |
+
):
|
182 |
+
"""
|
183 |
+
Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image.
|
184 |
+
|
185 |
+
Args:
|
186 |
+
conf (bool): Whether to plot the detection confidence score.
|
187 |
+
line_width (float, optional): The line width of the bounding boxes. If None, it is scaled to the image size.
|
188 |
+
font_size (float, optional): The font size of the text. If None, it is scaled to the image size.
|
189 |
+
font (str): The font to use for the text.
|
190 |
+
pil (bool): Whether to return the image as a PIL Image.
|
191 |
+
img (numpy.ndarray): Plot to another image. if not, plot to original image.
|
192 |
+
im_gpu (torch.Tensor): Normalized image in gpu with shape (1, 3, 640, 640), for faster mask plotting.
|
193 |
+
kpt_radius (int, optional): Radius of the drawn keypoints. Default is 5.
|
194 |
+
kpt_line (bool): Whether to draw lines connecting keypoints.
|
195 |
+
labels (bool): Whether to plot the label of bounding boxes.
|
196 |
+
boxes (bool): Whether to plot the bounding boxes.
|
197 |
+
masks (bool): Whether to plot the masks.
|
198 |
+
probs (bool): Whether to plot classification probability
|
199 |
+
|
200 |
+
Returns:
|
201 |
+
(numpy.ndarray): A numpy array of the annotated image.
|
202 |
+
|
203 |
+
Example:
|
204 |
+
```python
|
205 |
+
from PIL import Image
|
206 |
+
from ultralytics import YOLO
|
207 |
+
|
208 |
+
model = YOLO('yolov8n.pt')
|
209 |
+
results = model('bus.jpg') # results list
|
210 |
+
for r in results:
|
211 |
+
im_array = r.plot() # plot a BGR numpy array of predictions
|
212 |
+
im = Image.fromarray(im[..., ::-1]) # RGB PIL image
|
213 |
+
im.show() # show image
|
214 |
+
im.save('results.jpg') # save image
|
215 |
+
```
|
216 |
+
"""
|
217 |
+
if img is None and isinstance(self.orig_img, torch.Tensor):
|
218 |
+
img = np.ascontiguousarray(self.orig_img[0].permute(1, 2, 0).cpu().detach().numpy()) * 255
|
219 |
+
|
220 |
+
# Deprecation warn TODO: remove in 8.2
|
221 |
+
if 'show_conf' in kwargs:
|
222 |
+
deprecation_warn('show_conf', 'conf')
|
223 |
+
conf = kwargs['show_conf']
|
224 |
+
assert type(conf) == bool, '`show_conf` should be of boolean type, i.e, show_conf=True/False'
|
225 |
+
|
226 |
+
if 'line_thickness' in kwargs:
|
227 |
+
deprecation_warn('line_thickness', 'line_width')
|
228 |
+
line_width = kwargs['line_thickness']
|
229 |
+
assert type(line_width) == int, '`line_width` should be of int type, i.e, line_width=3'
|
230 |
+
|
231 |
+
names = self.names
|
232 |
+
pred_boxes, show_boxes = self.boxes, boxes
|
233 |
+
pred_masks, show_masks = self.masks, masks
|
234 |
+
pred_probs, show_probs = self.probs, probs
|
235 |
+
annotator = Annotator(
|
236 |
+
deepcopy(self.orig_img if img is None else img),
|
237 |
+
line_width,
|
238 |
+
font_size,
|
239 |
+
font,
|
240 |
+
pil or (pred_probs is not None and show_probs), # Classify tasks default to pil=True
|
241 |
+
example=names)
|
242 |
+
|
243 |
+
# Plot Segment results
|
244 |
+
if pred_masks and show_masks:
|
245 |
+
if im_gpu is None:
|
246 |
+
img = LetterBox(pred_masks.shape[1:])(image=annotator.result())
|
247 |
+
im_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device).permute(
|
248 |
+
2, 0, 1).flip(0).contiguous() / 255
|
249 |
+
idx = pred_boxes.cls if pred_boxes else range(len(pred_masks))
|
250 |
+
annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu)
|
251 |
+
|
252 |
+
# Plot Detect results
|
253 |
+
if pred_boxes and show_boxes:
|
254 |
+
for d in reversed(pred_boxes):
|
255 |
+
c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
|
256 |
+
name = ('' if id is None else f'id:{id} ') + names[c]
|
257 |
+
label = (f'{name} {conf:.2f}' if conf else name) if labels else None
|
258 |
+
annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
|
259 |
+
|
260 |
+
# Plot Classify results
|
261 |
+
if pred_probs is not None and show_probs:
|
262 |
+
text = ',\n'.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)
|
263 |
+
x = round(self.orig_shape[0] * 0.03)
|
264 |
+
annotator.text([x, x], text, txt_color=(255, 255, 255)) # TODO: allow setting colors
|
265 |
+
|
266 |
+
# Plot Pose results
|
267 |
+
if self.keypoints is not None:
|
268 |
+
for k in reversed(self.keypoints.data):
|
269 |
+
annotator.kpts(k, self.orig_shape, radius=kpt_radius, kpt_line=kpt_line)
|
270 |
+
|
271 |
+
return annotator.result()
|
272 |
+
|
273 |
+
def verbose(self):
|
274 |
+
"""
|
275 |
+
Return log string for each task.
|
276 |
+
"""
|
277 |
+
log_string = ''
|
278 |
+
probs = self.probs
|
279 |
+
boxes = self.boxes
|
280 |
+
if len(self) == 0:
|
281 |
+
return log_string if probs is not None else f'{log_string}(no detections), '
|
282 |
+
if probs is not None:
|
283 |
+
log_string += f"{', '.join(f'{self.names[j]} {probs.data[j]:.2f}' for j in probs.top5)}, "
|
284 |
+
if boxes:
|
285 |
+
for c in boxes.cls.unique():
|
286 |
+
n = (boxes.cls == c).sum() # detections per class
|
287 |
+
log_string += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "
|
288 |
+
return log_string
|
289 |
+
|
290 |
+
def save_txt(self, txt_file, save_conf=False):
|
291 |
+
"""
|
292 |
+
Save predictions into txt file.
|
293 |
+
|
294 |
+
Args:
|
295 |
+
txt_file (str): txt file path.
|
296 |
+
save_conf (bool): save confidence score or not.
|
297 |
+
"""
|
298 |
+
boxes = self.boxes
|
299 |
+
masks = self.masks
|
300 |
+
probs = self.probs
|
301 |
+
kpts = self.keypoints
|
302 |
+
texts = []
|
303 |
+
if probs is not None:
|
304 |
+
# Classify
|
305 |
+
[texts.append(f'{probs.data[j]:.2f} {self.names[j]}') for j in probs.top5]
|
306 |
+
elif boxes:
|
307 |
+
# Detect/segment/pose
|
308 |
+
for j, d in enumerate(boxes):
|
309 |
+
c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item())
|
310 |
+
line = (c, *d.xywhn.view(-1))
|
311 |
+
if masks:
|
312 |
+
seg = masks[j].xyn[0].copy().reshape(-1) # reversed mask.xyn, (n,2) to (n*2)
|
313 |
+
line = (c, *seg)
|
314 |
+
if kpts is not None:
|
315 |
+
kpt = torch.cat((kpts[j].xyn, kpts[j].conf[..., None]), 2) if kpts[j].has_visible else kpts[j].xyn
|
316 |
+
line += (*kpt.reshape(-1).tolist(), )
|
317 |
+
line += (conf, ) * save_conf + (() if id is None else (id, ))
|
318 |
+
texts.append(('%g ' * len(line)).rstrip() % line)
|
319 |
+
|
320 |
+
if texts:
|
321 |
+
with open(txt_file, 'a') as f:
|
322 |
+
f.writelines(text + '\n' for text in texts)
|
323 |
+
|
324 |
+
def save_crop(self, save_dir, file_name=Path('im.jpg')):
|
325 |
+
"""
|
326 |
+
Save cropped predictions to `save_dir/cls/file_name.jpg`.
|
327 |
+
|
328 |
+
Args:
|
329 |
+
save_dir (str | pathlib.Path): Save path.
|
330 |
+
file_name (str | pathlib.Path): File name.
|
331 |
+
"""
|
332 |
+
if self.probs is not None:
|
333 |
+
LOGGER.warning('WARNING ⚠️ Classify task do not support `save_crop`.')
|
334 |
+
return
|
335 |
+
if isinstance(save_dir, str):
|
336 |
+
save_dir = Path(save_dir)
|
337 |
+
if isinstance(file_name, str):
|
338 |
+
file_name = Path(file_name)
|
339 |
+
for d in self.boxes:
|
340 |
+
save_one_box(d.xyxy,
|
341 |
+
self.orig_img.copy(),
|
342 |
+
file=save_dir / self.names[int(d.cls)] / f'{file_name.stem}.jpg',
|
343 |
+
BGR=True)
|
344 |
+
|
345 |
+
def tojson(self, normalize=False):
|
346 |
+
"""Convert the object to JSON format."""
|
347 |
+
if self.probs is not None:
|
348 |
+
LOGGER.warning('Warning: Classify task do not support `tojson` yet.')
|
349 |
+
return
|
350 |
+
|
351 |
+
import json
|
352 |
+
|
353 |
+
# Create list of detection dictionaries
|
354 |
+
results = []
|
355 |
+
data = self.boxes.data.cpu().tolist()
|
356 |
+
h, w = self.orig_shape if normalize else (1, 1)
|
357 |
+
for i, row in enumerate(data):
|
358 |
+
box = {'x1': row[0] / w, 'y1': row[1] / h, 'x2': row[2] / w, 'y2': row[3] / h}
|
359 |
+
conf = row[4]
|
360 |
+
id = int(row[5])
|
361 |
+
name = self.names[id]
|
362 |
+
result = {'name': name, 'class': id, 'confidence': conf, 'box': box}
|
363 |
+
if self.masks:
|
364 |
+
x, y = self.masks.xy[i][:, 0], self.masks.xy[i][:, 1] # numpy array
|
365 |
+
result['segments'] = {'x': (x / w).tolist(), 'y': (y / h).tolist()}
|
366 |
+
if self.keypoints is not None:
|
367 |
+
x, y, visible = self.keypoints[i].data[0].cpu().unbind(dim=1) # torch Tensor
|
368 |
+
result['keypoints'] = {'x': (x / w).tolist(), 'y': (y / h).tolist(), 'visible': visible.tolist()}
|
369 |
+
results.append(result)
|
370 |
+
|
371 |
+
# Convert detections to JSON
|
372 |
+
return json.dumps(results, indent=2)
|
373 |
+
|
374 |
+
|
375 |
+
class Boxes(BaseTensor):
|
376 |
+
"""
|
377 |
+
A class for storing and manipulating detection boxes.
|
378 |
+
|
379 |
+
Args:
|
380 |
+
boxes (torch.Tensor | numpy.ndarray): A tensor or numpy array containing the detection boxes,
|
381 |
+
with shape (num_boxes, 6) or (num_boxes, 7). The last two columns contain confidence and class values.
|
382 |
+
If present, the third last column contains track IDs.
|
383 |
+
orig_shape (tuple): Original image size, in the format (height, width).
|
384 |
+
|
385 |
+
Attributes:
|
386 |
+
xyxy (torch.Tensor | numpy.ndarray): The boxes in xyxy format.
|
387 |
+
conf (torch.Tensor | numpy.ndarray): The confidence values of the boxes.
|
388 |
+
cls (torch.Tensor | numpy.ndarray): The class values of the boxes.
|
389 |
+
id (torch.Tensor | numpy.ndarray): The track IDs of the boxes (if available).
|
390 |
+
xywh (torch.Tensor | numpy.ndarray): The boxes in xywh format.
|
391 |
+
xyxyn (torch.Tensor | numpy.ndarray): The boxes in xyxy format normalized by original image size.
|
392 |
+
xywhn (torch.Tensor | numpy.ndarray): The boxes in xywh format normalized by original image size.
|
393 |
+
data (torch.Tensor): The raw bboxes tensor (alias for `boxes`).
|
394 |
+
|
395 |
+
Methods:
|
396 |
+
cpu(): Move the object to CPU memory.
|
397 |
+
numpy(): Convert the object to a numpy array.
|
398 |
+
cuda(): Move the object to CUDA memory.
|
399 |
+
to(*args, **kwargs): Move the object to the specified device.
|
400 |
+
"""
|
401 |
+
|
402 |
+
def __init__(self, boxes, orig_shape) -> None:
|
403 |
+
"""Initialize the Boxes class."""
|
404 |
+
if boxes.ndim == 1:
|
405 |
+
boxes = boxes[None, :]
|
406 |
+
n = boxes.shape[-1]
|
407 |
+
assert n in (6, 7), f'expected `n` in [6, 7], but got {n}' # xyxy, (track_id), conf, cls
|
408 |
+
super().__init__(boxes, orig_shape)
|
409 |
+
self.is_track = n == 7
|
410 |
+
self.orig_shape = orig_shape
|
411 |
+
|
412 |
+
@property
|
413 |
+
def xyxy(self):
|
414 |
+
"""Return the boxes in xyxy format."""
|
415 |
+
return self.data[:, :4]
|
416 |
+
|
417 |
+
@property
|
418 |
+
def conf(self):
|
419 |
+
"""Return the confidence values of the boxes."""
|
420 |
+
return self.data[:, -2]
|
421 |
+
|
422 |
+
@property
|
423 |
+
def cls(self):
|
424 |
+
"""Return the class values of the boxes."""
|
425 |
+
return self.data[:, -1]
|
426 |
+
|
427 |
+
@property
|
428 |
+
def id(self):
|
429 |
+
"""Return the track IDs of the boxes (if available)."""
|
430 |
+
return self.data[:, -3] if self.is_track else None
|
431 |
+
|
432 |
+
@property
|
433 |
+
@lru_cache(maxsize=2) # maxsize 1 should suffice
|
434 |
+
def xywh(self):
|
435 |
+
"""Return the boxes in xywh format."""
|
436 |
+
return ops.xyxy2xywh(self.xyxy)
|
437 |
+
|
438 |
+
@property
|
439 |
+
@lru_cache(maxsize=2)
|
440 |
+
def xyxyn(self):
|
441 |
+
"""Return the boxes in xyxy format normalized by original image size."""
|
442 |
+
xyxy = self.xyxy.clone() if isinstance(self.xyxy, torch.Tensor) else np.copy(self.xyxy)
|
443 |
+
xyxy[..., [0, 2]] /= self.orig_shape[1]
|
444 |
+
xyxy[..., [1, 3]] /= self.orig_shape[0]
|
445 |
+
return xyxy
|
446 |
+
|
447 |
+
@property
|
448 |
+
@lru_cache(maxsize=2)
|
449 |
+
def xywhn(self):
|
450 |
+
"""Return the boxes in xywh format normalized by original image size."""
|
451 |
+
xywh = ops.xyxy2xywh(self.xyxy)
|
452 |
+
xywh[..., [0, 2]] /= self.orig_shape[1]
|
453 |
+
xywh[..., [1, 3]] /= self.orig_shape[0]
|
454 |
+
return xywh
|
455 |
+
|
456 |
+
@property
|
457 |
+
def boxes(self):
|
458 |
+
"""Return the raw bboxes tensor (deprecated)."""
|
459 |
+
LOGGER.warning("WARNING ⚠️ 'Boxes.boxes' is deprecated. Use 'Boxes.data' instead.")
|
460 |
+
return self.data
|
461 |
+
|
462 |
+
|
463 |
+
class Masks(BaseTensor):
|
464 |
+
"""
|
465 |
+
A class for storing and manipulating detection masks.
|
466 |
+
|
467 |
+
Attributes:
|
468 |
+
segments (list): Deprecated property for segments (normalized).
|
469 |
+
xy (list): A list of segments in pixel coordinates.
|
470 |
+
xyn (list): A list of normalized segments.
|
471 |
+
|
472 |
+
Methods:
|
473 |
+
cpu(): Returns the masks tensor on CPU memory.
|
474 |
+
numpy(): Returns the masks tensor as a numpy array.
|
475 |
+
cuda(): Returns the masks tensor on GPU memory.
|
476 |
+
to(device, dtype): Returns the masks tensor with the specified device and dtype.
|
477 |
+
"""
|
478 |
+
|
479 |
+
def __init__(self, masks, orig_shape) -> None:
|
480 |
+
"""Initialize the Masks class with the given masks tensor and original image shape."""
|
481 |
+
if masks.ndim == 2:
|
482 |
+
masks = masks[None, :]
|
483 |
+
super().__init__(masks, orig_shape)
|
484 |
+
|
485 |
+
@property
|
486 |
+
@lru_cache(maxsize=1)
|
487 |
+
def segments(self):
|
488 |
+
"""Return segments (normalized). Deprecated; use xyn property instead."""
|
489 |
+
LOGGER.warning(
|
490 |
+
"WARNING ⚠️ 'Masks.segments' is deprecated. Use 'Masks.xyn' for segments (normalized) and 'Masks.xy' for segments (pixels) instead."
|
491 |
+
)
|
492 |
+
return self.xyn
|
493 |
+
|
494 |
+
@property
|
495 |
+
@lru_cache(maxsize=1)
|
496 |
+
def xyn(self):
|
497 |
+
"""Return normalized segments."""
|
498 |
+
return [
|
499 |
+
ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=True)
|
500 |
+
for x in ops.masks2segments(self.data)]
|
501 |
+
|
502 |
+
@property
|
503 |
+
@lru_cache(maxsize=1)
|
504 |
+
def xy(self):
|
505 |
+
"""Return segments in pixel coordinates."""
|
506 |
+
return [
|
507 |
+
ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=False)
|
508 |
+
for x in ops.masks2segments(self.data)]
|
509 |
+
|
510 |
+
@property
|
511 |
+
def masks(self):
|
512 |
+
"""Return the raw masks tensor. Deprecated; use data attribute instead."""
|
513 |
+
LOGGER.warning("WARNING ⚠️ 'Masks.masks' is deprecated. Use 'Masks.data' instead.")
|
514 |
+
return self.data
|
515 |
+
|
516 |
+
|
517 |
+
class Keypoints(BaseTensor):
|
518 |
+
"""
|
519 |
+
A class for storing and manipulating detection keypoints.
|
520 |
+
|
521 |
+
Attributes:
|
522 |
+
xy (torch.Tensor): A collection of keypoints containing x, y coordinates for each detection.
|
523 |
+
xyn (torch.Tensor): A normalized version of xy with coordinates in the range [0, 1].
|
524 |
+
conf (torch.Tensor): Confidence values associated with keypoints if available, otherwise None.
|
525 |
+
|
526 |
+
Methods:
|
527 |
+
cpu(): Returns a copy of the keypoints tensor on CPU memory.
|
528 |
+
numpy(): Returns a copy of the keypoints tensor as a numpy array.
|
529 |
+
cuda(): Returns a copy of the keypoints tensor on GPU memory.
|
530 |
+
to(device, dtype): Returns a copy of the keypoints tensor with the specified device and dtype.
|
531 |
+
"""
|
532 |
+
|
533 |
+
def __init__(self, keypoints, orig_shape) -> None:
|
534 |
+
"""Initializes the Keypoints object with detection keypoints and original image size."""
|
535 |
+
if keypoints.ndim == 2:
|
536 |
+
keypoints = keypoints[None, :]
|
537 |
+
super().__init__(keypoints, orig_shape)
|
538 |
+
self.has_visible = self.data.shape[-1] == 3
|
539 |
+
|
540 |
+
@property
|
541 |
+
@lru_cache(maxsize=1)
|
542 |
+
def xy(self):
|
543 |
+
"""Returns x, y coordinates of keypoints."""
|
544 |
+
return self.data[..., :2]
|
545 |
+
|
546 |
+
@property
|
547 |
+
@lru_cache(maxsize=1)
|
548 |
+
def xyn(self):
|
549 |
+
"""Returns normalized x, y coordinates of keypoints."""
|
550 |
+
xy = self.xy.clone() if isinstance(self.xy, torch.Tensor) else np.copy(self.xy)
|
551 |
+
xy[..., 0] /= self.orig_shape[1]
|
552 |
+
xy[..., 1] /= self.orig_shape[0]
|
553 |
+
return xy
|
554 |
+
|
555 |
+
@property
|
556 |
+
@lru_cache(maxsize=1)
|
557 |
+
def conf(self):
|
558 |
+
"""Returns confidence values of keypoints if available, else None."""
|
559 |
+
return self.data[..., 2] if self.has_visible else None
|
560 |
+
|
561 |
+
|
562 |
+
class Probs(BaseTensor):
|
563 |
+
"""
|
564 |
+
A class for storing and manipulating classification predictions.
|
565 |
+
|
566 |
+
Attributes:
|
567 |
+
top1 (int): Index of the top 1 class.
|
568 |
+
top5 (list[int]): Indices of the top 5 classes.
|
569 |
+
top1conf (torch.Tensor): Confidence of the top 1 class.
|
570 |
+
top5conf (torch.Tensor): Confidences of the top 5 classes.
|
571 |
+
|
572 |
+
Methods:
|
573 |
+
cpu(): Returns a copy of the probs tensor on CPU memory.
|
574 |
+
numpy(): Returns a copy of the probs tensor as a numpy array.
|
575 |
+
cuda(): Returns a copy of the probs tensor on GPU memory.
|
576 |
+
to(): Returns a copy of the probs tensor with the specified device and dtype.
|
577 |
+
"""
|
578 |
+
|
579 |
+
def __init__(self, probs, orig_shape=None) -> None:
|
580 |
+
super().__init__(probs, orig_shape)
|
581 |
+
|
582 |
+
@property
|
583 |
+
@lru_cache(maxsize=1)
|
584 |
+
def top1(self):
|
585 |
+
"""Return the index of top 1."""
|
586 |
+
return int(self.data.argmax())
|
587 |
+
|
588 |
+
@property
|
589 |
+
@lru_cache(maxsize=1)
|
590 |
+
def top5(self):
|
591 |
+
"""Return the indices of top 5."""
|
592 |
+
return (-self.data).argsort(0)[:5].tolist() # this way works with both torch and numpy.
|
593 |
+
|
594 |
+
@property
|
595 |
+
@lru_cache(maxsize=1)
|
596 |
+
def top1conf(self):
|
597 |
+
"""Return the confidence of top 1."""
|
598 |
+
return self.data[self.top1]
|
599 |
+
|
600 |
+
@property
|
601 |
+
@lru_cache(maxsize=1)
|
602 |
+
def top5conf(self):
|
603 |
+
"""Return the confidences of top 5."""
|
604 |
+
return self.data[self.top5]
|
ultralytics/engine/trainer.py
ADDED
@@ -0,0 +1,664 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
"""
|
3 |
+
Train a model on a dataset
|
4 |
+
|
5 |
+
Usage:
|
6 |
+
$ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16
|
7 |
+
"""
|
8 |
+
import math
|
9 |
+
import os
|
10 |
+
import subprocess
|
11 |
+
import time
|
12 |
+
from copy import deepcopy
|
13 |
+
from datetime import datetime, timedelta
|
14 |
+
from pathlib import Path
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
import torch
|
18 |
+
from torch import distributed as dist
|
19 |
+
from torch import nn, optim
|
20 |
+
from torch.cuda import amp
|
21 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
22 |
+
from tqdm import tqdm
|
23 |
+
|
24 |
+
from ultralytics.cfg import get_cfg
|
25 |
+
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
26 |
+
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
|
27 |
+
from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, __version__, callbacks, clean_url,
|
28 |
+
colorstr, emojis, yaml_save)
|
29 |
+
from ultralytics.utils.autobatch import check_train_batch_size
|
30 |
+
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, print_args
|
31 |
+
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
|
32 |
+
from ultralytics.utils.files import get_latest_run, increment_path
|
33 |
+
from ultralytics.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, select_device,
|
34 |
+
strip_optimizer)
|
35 |
+
|
36 |
+
|
37 |
+
class BaseTrainer:
|
38 |
+
"""
|
39 |
+
BaseTrainer
|
40 |
+
|
41 |
+
A base class for creating trainers.
|
42 |
+
|
43 |
+
Attributes:
|
44 |
+
args (SimpleNamespace): Configuration for the trainer.
|
45 |
+
check_resume (method): Method to check if training should be resumed from a saved checkpoint.
|
46 |
+
validator (BaseValidator): Validator instance.
|
47 |
+
model (nn.Module): Model instance.
|
48 |
+
callbacks (defaultdict): Dictionary of callbacks.
|
49 |
+
save_dir (Path): Directory to save results.
|
50 |
+
wdir (Path): Directory to save weights.
|
51 |
+
last (Path): Path to last checkpoint.
|
52 |
+
best (Path): Path to best checkpoint.
|
53 |
+
save_period (int): Save checkpoint every x epochs (disabled if < 1).
|
54 |
+
batch_size (int): Batch size for training.
|
55 |
+
epochs (int): Number of epochs to train for.
|
56 |
+
start_epoch (int): Starting epoch for training.
|
57 |
+
device (torch.device): Device to use for training.
|
58 |
+
amp (bool): Flag to enable AMP (Automatic Mixed Precision).
|
59 |
+
scaler (amp.GradScaler): Gradient scaler for AMP.
|
60 |
+
data (str): Path to data.
|
61 |
+
trainset (torch.utils.data.Dataset): Training dataset.
|
62 |
+
testset (torch.utils.data.Dataset): Testing dataset.
|
63 |
+
ema (nn.Module): EMA (Exponential Moving Average) of the model.
|
64 |
+
lf (nn.Module): Loss function.
|
65 |
+
scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
|
66 |
+
best_fitness (float): The best fitness value achieved.
|
67 |
+
fitness (float): Current fitness value.
|
68 |
+
loss (float): Current loss value.
|
69 |
+
tloss (float): Total loss value.
|
70 |
+
loss_names (list): List of loss names.
|
71 |
+
csv (Path): Path to results CSV file.
|
72 |
+
"""
|
73 |
+
|
74 |
+
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
75 |
+
"""
|
76 |
+
Initializes the BaseTrainer class.
|
77 |
+
|
78 |
+
Args:
|
79 |
+
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
|
80 |
+
overrides (dict, optional): Configuration overrides. Defaults to None.
|
81 |
+
"""
|
82 |
+
self.args = get_cfg(cfg, overrides)
|
83 |
+
self.device = select_device(self.args.device, self.args.batch)
|
84 |
+
self.check_resume()
|
85 |
+
self.validator = None
|
86 |
+
self.model = None
|
87 |
+
self.metrics = None
|
88 |
+
self.plots = {}
|
89 |
+
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
|
90 |
+
|
91 |
+
# Dirs
|
92 |
+
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
93 |
+
name = self.args.name or f'{self.args.mode}'
|
94 |
+
if hasattr(self.args, 'save_dir'):
|
95 |
+
self.save_dir = Path(self.args.save_dir)
|
96 |
+
else:
|
97 |
+
self.save_dir = Path(
|
98 |
+
increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in (-1, 0) else True))
|
99 |
+
self.wdir = self.save_dir / 'weights' # weights dir
|
100 |
+
if RANK in (-1, 0):
|
101 |
+
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
|
102 |
+
self.args.save_dir = str(self.save_dir)
|
103 |
+
yaml_save(self.save_dir / 'args.yaml', vars(self.args)) # save run args
|
104 |
+
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
|
105 |
+
self.save_period = self.args.save_period
|
106 |
+
|
107 |
+
self.batch_size = self.args.batch
|
108 |
+
self.epochs = self.args.epochs
|
109 |
+
self.start_epoch = 0
|
110 |
+
if RANK == -1:
|
111 |
+
print_args(vars(self.args))
|
112 |
+
|
113 |
+
# Device
|
114 |
+
if self.device.type == 'cpu':
|
115 |
+
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
|
116 |
+
|
117 |
+
# Model and Dataset
|
118 |
+
self.model = self.args.model
|
119 |
+
try:
|
120 |
+
if self.args.task == 'classify':
|
121 |
+
self.data = check_cls_dataset(self.args.data)
|
122 |
+
elif self.args.data.split('.')[-1] in ('yaml', 'yml') or self.args.task in ('detect', 'segment'):
|
123 |
+
self.data = check_det_dataset(self.args.data)
|
124 |
+
if 'yaml_file' in self.data:
|
125 |
+
self.args.data = self.data['yaml_file'] # for validating 'yolo train data=url.zip' usage
|
126 |
+
except Exception as e:
|
127 |
+
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
|
128 |
+
|
129 |
+
self.trainset, self.testset = self.get_dataset(self.data)
|
130 |
+
self.ema = None
|
131 |
+
|
132 |
+
# Optimization utils init
|
133 |
+
self.lf = None
|
134 |
+
self.scheduler = None
|
135 |
+
|
136 |
+
# Epoch level metrics
|
137 |
+
self.best_fitness = None
|
138 |
+
self.fitness = None
|
139 |
+
self.loss = None
|
140 |
+
self.tloss = None
|
141 |
+
self.loss_names = ['Loss']
|
142 |
+
self.csv = self.save_dir / 'results.csv'
|
143 |
+
self.plot_idx = [0, 1, 2]
|
144 |
+
|
145 |
+
# Callbacks
|
146 |
+
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
147 |
+
if RANK in (-1, 0):
|
148 |
+
callbacks.add_integration_callbacks(self)
|
149 |
+
|
150 |
+
def add_callback(self, event: str, callback):
|
151 |
+
"""
|
152 |
+
Appends the given callback.
|
153 |
+
"""
|
154 |
+
self.callbacks[event].append(callback)
|
155 |
+
|
156 |
+
def set_callback(self, event: str, callback):
|
157 |
+
"""
|
158 |
+
Overrides the existing callbacks with the given callback.
|
159 |
+
"""
|
160 |
+
self.callbacks[event] = [callback]
|
161 |
+
|
162 |
+
def run_callbacks(self, event: str):
|
163 |
+
"""Run all existing callbacks associated with a particular event."""
|
164 |
+
for callback in self.callbacks.get(event, []):
|
165 |
+
callback(self)
|
166 |
+
|
167 |
+
def train(self):
|
168 |
+
"""Allow device='', device=None on Multi-GPU systems to default to device=0."""
|
169 |
+
if isinstance(self.args.device, int) or self.args.device: # i.e. device=0 or device=[0,1,2,3]
|
170 |
+
world_size = torch.cuda.device_count()
|
171 |
+
elif torch.cuda.is_available(): # i.e. device=None or device=''
|
172 |
+
world_size = 1 # default to device 0
|
173 |
+
else: # i.e. device='cpu' or 'mps'
|
174 |
+
world_size = 0
|
175 |
+
|
176 |
+
# Run subprocess if DDP training, else train normally
|
177 |
+
if world_size > 1 and 'LOCAL_RANK' not in os.environ:
|
178 |
+
# Argument checks
|
179 |
+
if self.args.rect:
|
180 |
+
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting rect=False")
|
181 |
+
self.args.rect = False
|
182 |
+
# Command
|
183 |
+
cmd, file = generate_ddp_command(world_size, self)
|
184 |
+
try:
|
185 |
+
LOGGER.info(f'DDP command: {cmd}')
|
186 |
+
subprocess.run(cmd, check=True)
|
187 |
+
except Exception as e:
|
188 |
+
raise e
|
189 |
+
finally:
|
190 |
+
ddp_cleanup(self, str(file))
|
191 |
+
else:
|
192 |
+
self._do_train(world_size)
|
193 |
+
|
194 |
+
def _setup_ddp(self, world_size):
|
195 |
+
"""Initializes and sets the DistributedDataParallel parameters for training."""
|
196 |
+
torch.cuda.set_device(RANK)
|
197 |
+
self.device = torch.device('cuda', RANK)
|
198 |
+
LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
|
199 |
+
os.environ['NCCL_BLOCKING_WAIT'] = '1' # set to enforce timeout
|
200 |
+
dist.init_process_group(
|
201 |
+
'nccl' if dist.is_nccl_available() else 'gloo',
|
202 |
+
timeout=timedelta(seconds=10800), # 3 hours
|
203 |
+
rank=RANK,
|
204 |
+
world_size=world_size)
|
205 |
+
|
206 |
+
def _setup_train(self, world_size):
|
207 |
+
"""
|
208 |
+
Builds dataloaders and optimizer on correct rank process.
|
209 |
+
"""
|
210 |
+
# Model
|
211 |
+
self.run_callbacks('on_pretrain_routine_start')
|
212 |
+
ckpt = self.setup_model()
|
213 |
+
self.model = self.model.to(self.device)
|
214 |
+
self.set_model_attributes()
|
215 |
+
# Check AMP
|
216 |
+
self.amp = torch.tensor(self.args.amp).to(self.device) # True or False
|
217 |
+
if self.amp and RANK in (-1, 0): # Single-GPU and DDP
|
218 |
+
callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them
|
219 |
+
self.amp = torch.tensor(check_amp(self.model), device=self.device)
|
220 |
+
callbacks.default_callbacks = callbacks_backup # restore callbacks
|
221 |
+
if RANK > -1 and world_size > 1: # DDP
|
222 |
+
dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None)
|
223 |
+
self.amp = bool(self.amp) # as boolean
|
224 |
+
self.scaler = amp.GradScaler(enabled=self.amp)
|
225 |
+
if world_size > 1:
|
226 |
+
self.model = DDP(self.model, device_ids=[RANK])
|
227 |
+
# Check imgsz
|
228 |
+
gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32) # grid size (max stride)
|
229 |
+
self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
|
230 |
+
# Batch size
|
231 |
+
if self.batch_size == -1:
|
232 |
+
if RANK == -1: # single-GPU only, estimate best batch size
|
233 |
+
self.args.batch = self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)
|
234 |
+
else:
|
235 |
+
SyntaxError('batch=-1 to use AutoBatch is only available in Single-GPU training. '
|
236 |
+
'Please pass a valid batch size value for Multi-GPU DDP training, i.e. batch=16')
|
237 |
+
|
238 |
+
# Dataloaders
|
239 |
+
batch_size = self.batch_size // max(world_size, 1)
|
240 |
+
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
|
241 |
+
if RANK in (-1, 0):
|
242 |
+
self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
|
243 |
+
self.validator = self.get_validator()
|
244 |
+
metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
|
245 |
+
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
|
246 |
+
self.ema = ModelEMA(self.model)
|
247 |
+
if self.args.plots:
|
248 |
+
self.plot_training_labels()
|
249 |
+
|
250 |
+
# Optimizer
|
251 |
+
self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing
|
252 |
+
weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
|
253 |
+
iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
|
254 |
+
self.optimizer = self.build_optimizer(model=self.model,
|
255 |
+
name=self.args.optimizer,
|
256 |
+
lr=self.args.lr0,
|
257 |
+
momentum=self.args.momentum,
|
258 |
+
decay=weight_decay,
|
259 |
+
iterations=iterations)
|
260 |
+
# Scheduler
|
261 |
+
if self.args.cos_lr:
|
262 |
+
self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf']
|
263 |
+
else:
|
264 |
+
self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf # linear
|
265 |
+
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
|
266 |
+
self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
|
267 |
+
self.resume_training(ckpt)
|
268 |
+
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
|
269 |
+
self.run_callbacks('on_pretrain_routine_end')
|
270 |
+
|
271 |
+
def _do_train(self, world_size=1):
|
272 |
+
"""Train completed, evaluate and plot if specified by arguments."""
|
273 |
+
if world_size > 1:
|
274 |
+
self._setup_ddp(world_size)
|
275 |
+
|
276 |
+
self._setup_train(world_size)
|
277 |
+
|
278 |
+
self.epoch_time = None
|
279 |
+
self.epoch_time_start = time.time()
|
280 |
+
self.train_time_start = time.time()
|
281 |
+
nb = len(self.train_loader) # number of batches
|
282 |
+
nw = max(round(self.args.warmup_epochs *
|
283 |
+
nb), 100) if self.args.warmup_epochs > 0 else -1 # number of warmup iterations
|
284 |
+
last_opt_step = -1
|
285 |
+
self.run_callbacks('on_train_start')
|
286 |
+
LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
|
287 |
+
f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
|
288 |
+
f"Logging results to {colorstr('bold', self.save_dir)}\n"
|
289 |
+
f'Starting training for {self.epochs} epochs...')
|
290 |
+
if self.args.close_mosaic:
|
291 |
+
base_idx = (self.epochs - self.args.close_mosaic) * nb
|
292 |
+
self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
|
293 |
+
epoch = self.epochs # predefine for resume fully trained model edge cases
|
294 |
+
for epoch in range(self.start_epoch, self.epochs):
|
295 |
+
self.epoch = epoch
|
296 |
+
self.run_callbacks('on_train_epoch_start')
|
297 |
+
self.model.train()
|
298 |
+
if RANK != -1:
|
299 |
+
self.train_loader.sampler.set_epoch(epoch)
|
300 |
+
pbar = enumerate(self.train_loader)
|
301 |
+
# Update dataloader attributes (optional)
|
302 |
+
if epoch == (self.epochs - self.args.close_mosaic):
|
303 |
+
LOGGER.info('Closing dataloader mosaic')
|
304 |
+
if hasattr(self.train_loader.dataset, 'mosaic'):
|
305 |
+
self.train_loader.dataset.mosaic = False
|
306 |
+
if hasattr(self.train_loader.dataset, 'close_mosaic'):
|
307 |
+
self.train_loader.dataset.close_mosaic(hyp=self.args)
|
308 |
+
self.train_loader.reset()
|
309 |
+
|
310 |
+
if RANK in (-1, 0):
|
311 |
+
LOGGER.info(self.progress_string())
|
312 |
+
pbar = tqdm(enumerate(self.train_loader), total=nb, bar_format=TQDM_BAR_FORMAT)
|
313 |
+
self.tloss = None
|
314 |
+
self.optimizer.zero_grad()
|
315 |
+
for i, batch in pbar:
|
316 |
+
self.run_callbacks('on_train_batch_start')
|
317 |
+
# Warmup
|
318 |
+
ni = i + nb * epoch
|
319 |
+
if ni <= nw:
|
320 |
+
xi = [0, nw] # x interp
|
321 |
+
self.accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round())
|
322 |
+
for j, x in enumerate(self.optimizer.param_groups):
|
323 |
+
# Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
|
324 |
+
x['lr'] = np.interp(
|
325 |
+
ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x['initial_lr'] * self.lf(epoch)])
|
326 |
+
if 'momentum' in x:
|
327 |
+
x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
|
328 |
+
|
329 |
+
# Forward
|
330 |
+
with torch.cuda.amp.autocast(self.amp):
|
331 |
+
batch = self.preprocess_batch(batch)
|
332 |
+
self.loss, self.loss_items = self.model(batch)
|
333 |
+
if RANK != -1:
|
334 |
+
self.loss *= world_size
|
335 |
+
self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
|
336 |
+
else self.loss_items
|
337 |
+
|
338 |
+
# Backward
|
339 |
+
self.scaler.scale(self.loss).backward()
|
340 |
+
|
341 |
+
# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
|
342 |
+
if ni - last_opt_step >= self.accumulate:
|
343 |
+
self.optimizer_step()
|
344 |
+
last_opt_step = ni
|
345 |
+
|
346 |
+
# Log
|
347 |
+
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
|
348 |
+
loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
|
349 |
+
losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
|
350 |
+
if RANK in (-1, 0):
|
351 |
+
pbar.set_description(
|
352 |
+
('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
|
353 |
+
(f'{epoch + 1}/{self.epochs}', mem, *losses, batch['cls'].shape[0], batch['img'].shape[-1]))
|
354 |
+
self.run_callbacks('on_batch_end')
|
355 |
+
if self.args.plots and ni in self.plot_idx:
|
356 |
+
self.plot_training_samples(batch, ni)
|
357 |
+
|
358 |
+
self.run_callbacks('on_train_batch_end')
|
359 |
+
|
360 |
+
self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
|
361 |
+
|
362 |
+
self.scheduler.step()
|
363 |
+
self.run_callbacks('on_train_epoch_end')
|
364 |
+
|
365 |
+
if RANK in (-1, 0):
|
366 |
+
|
367 |
+
# Validation
|
368 |
+
self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
|
369 |
+
final_epoch = (epoch + 1 == self.epochs) or self.stopper.possible_stop
|
370 |
+
|
371 |
+
if self.args.val or final_epoch:
|
372 |
+
self.metrics, self.fitness = self.validate()
|
373 |
+
self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
|
374 |
+
self.stop = self.stopper(epoch + 1, self.fitness)
|
375 |
+
|
376 |
+
# Save model
|
377 |
+
if self.args.save or (epoch + 1 == self.epochs):
|
378 |
+
self.save_model()
|
379 |
+
self.run_callbacks('on_model_save')
|
380 |
+
|
381 |
+
tnow = time.time()
|
382 |
+
self.epoch_time = tnow - self.epoch_time_start
|
383 |
+
self.epoch_time_start = tnow
|
384 |
+
self.run_callbacks('on_fit_epoch_end')
|
385 |
+
torch.cuda.empty_cache() # clears GPU vRAM at end of epoch, can help with out of memory errors
|
386 |
+
|
387 |
+
# Early Stopping
|
388 |
+
if RANK != -1: # if DDP training
|
389 |
+
broadcast_list = [self.stop if RANK == 0 else None]
|
390 |
+
dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks
|
391 |
+
if RANK != 0:
|
392 |
+
self.stop = broadcast_list[0]
|
393 |
+
if self.stop:
|
394 |
+
break # must break all DDP ranks
|
395 |
+
|
396 |
+
if RANK in (-1, 0):
|
397 |
+
# Do final val with best.pt
|
398 |
+
LOGGER.info(f'\n{epoch - self.start_epoch + 1} epochs completed in '
|
399 |
+
f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
|
400 |
+
self.final_eval()
|
401 |
+
if self.args.plots:
|
402 |
+
self.plot_metrics()
|
403 |
+
self.run_callbacks('on_train_end')
|
404 |
+
torch.cuda.empty_cache()
|
405 |
+
self.run_callbacks('teardown')
|
406 |
+
|
407 |
+
def save_model(self):
|
408 |
+
"""Save model checkpoints based on various conditions."""
|
409 |
+
ckpt = {
|
410 |
+
'epoch': self.epoch,
|
411 |
+
'best_fitness': self.best_fitness,
|
412 |
+
'model': deepcopy(de_parallel(self.model)).half(),
|
413 |
+
'ema': deepcopy(self.ema.ema).half(),
|
414 |
+
'updates': self.ema.updates,
|
415 |
+
'optimizer': self.optimizer.state_dict(),
|
416 |
+
'train_args': vars(self.args), # save as dict
|
417 |
+
'date': datetime.now().isoformat(),
|
418 |
+
'version': __version__}
|
419 |
+
|
420 |
+
# Use dill (if exists) to serialize the lambda functions where pickle does not do this
|
421 |
+
try:
|
422 |
+
import dill as pickle
|
423 |
+
except ImportError:
|
424 |
+
import pickle
|
425 |
+
|
426 |
+
# Save last, best and delete
|
427 |
+
torch.save(ckpt, self.last, pickle_module=pickle)
|
428 |
+
if self.best_fitness == self.fitness:
|
429 |
+
torch.save(ckpt, self.best, pickle_module=pickle)
|
430 |
+
if (self.epoch > 0) and (self.save_period > 0) and (self.epoch % self.save_period == 0):
|
431 |
+
torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt', pickle_module=pickle)
|
432 |
+
del ckpt
|
433 |
+
|
434 |
+
@staticmethod
|
435 |
+
def get_dataset(data):
|
436 |
+
"""
|
437 |
+
Get train, val path from data dict if it exists. Returns None if data format is not recognized.
|
438 |
+
"""
|
439 |
+
return data['train'], data.get('val') or data.get('test')
|
440 |
+
|
441 |
+
def setup_model(self):
|
442 |
+
"""
|
443 |
+
load/create/download model for any task.
|
444 |
+
"""
|
445 |
+
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
446 |
+
return
|
447 |
+
|
448 |
+
model, weights = self.model, None
|
449 |
+
ckpt = None
|
450 |
+
if str(model).endswith('.pt'):
|
451 |
+
weights, ckpt = attempt_load_one_weight(model)
|
452 |
+
cfg = ckpt['model'].yaml
|
453 |
+
else:
|
454 |
+
cfg = model
|
455 |
+
self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) # calls Model(cfg, weights)
|
456 |
+
return ckpt
|
457 |
+
|
458 |
+
def optimizer_step(self):
|
459 |
+
"""Perform a single step of the training optimizer with gradient clipping and EMA update."""
|
460 |
+
self.scaler.unscale_(self.optimizer) # unscale gradients
|
461 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0) # clip gradients
|
462 |
+
self.scaler.step(self.optimizer)
|
463 |
+
self.scaler.update()
|
464 |
+
self.optimizer.zero_grad()
|
465 |
+
if self.ema:
|
466 |
+
self.ema.update(self.model)
|
467 |
+
|
468 |
+
def preprocess_batch(self, batch):
|
469 |
+
"""
|
470 |
+
Allows custom preprocessing model inputs and ground truths depending on task type.
|
471 |
+
"""
|
472 |
+
return batch
|
473 |
+
|
474 |
+
def validate(self):
|
475 |
+
"""
|
476 |
+
Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
|
477 |
+
"""
|
478 |
+
metrics = self.validator(self)
|
479 |
+
fitness = metrics.pop('fitness', -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
|
480 |
+
if not self.best_fitness or self.best_fitness < fitness:
|
481 |
+
self.best_fitness = fitness
|
482 |
+
return metrics, fitness
|
483 |
+
|
484 |
+
def get_model(self, cfg=None, weights=None, verbose=True):
|
485 |
+
"""Get model and raise NotImplementedError for loading cfg files."""
|
486 |
+
raise NotImplementedError("This task trainer doesn't support loading cfg files")
|
487 |
+
|
488 |
+
def get_validator(self):
|
489 |
+
"""Returns a NotImplementedError when the get_validator function is called."""
|
490 |
+
raise NotImplementedError('get_validator function not implemented in trainer')
|
491 |
+
|
492 |
+
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
|
493 |
+
"""
|
494 |
+
Returns dataloader derived from torch.data.Dataloader.
|
495 |
+
"""
|
496 |
+
raise NotImplementedError('get_dataloader function not implemented in trainer')
|
497 |
+
|
498 |
+
def build_dataset(self, img_path, mode='train', batch=None):
|
499 |
+
"""Build dataset"""
|
500 |
+
raise NotImplementedError('build_dataset function not implemented in trainer')
|
501 |
+
|
502 |
+
def label_loss_items(self, loss_items=None, prefix='train'):
|
503 |
+
"""
|
504 |
+
Returns a loss dict with labelled training loss items tensor
|
505 |
+
"""
|
506 |
+
# Not needed for classification but necessary for segmentation & detection
|
507 |
+
return {'loss': loss_items} if loss_items is not None else ['loss']
|
508 |
+
|
509 |
+
def set_model_attributes(self):
|
510 |
+
"""
|
511 |
+
To set or update model parameters before training.
|
512 |
+
"""
|
513 |
+
self.model.names = self.data['names']
|
514 |
+
|
515 |
+
def build_targets(self, preds, targets):
|
516 |
+
"""Builds target tensors for training YOLO model."""
|
517 |
+
pass
|
518 |
+
|
519 |
+
def progress_string(self):
|
520 |
+
"""Returns a string describing training progress."""
|
521 |
+
return ''
|
522 |
+
|
523 |
+
# TODO: may need to put these following functions into callback
|
524 |
+
def plot_training_samples(self, batch, ni):
|
525 |
+
"""Plots training samples during YOLOv5 training."""
|
526 |
+
pass
|
527 |
+
|
528 |
+
def plot_training_labels(self):
|
529 |
+
"""Plots training labels for YOLO model."""
|
530 |
+
pass
|
531 |
+
|
532 |
+
def save_metrics(self, metrics):
|
533 |
+
"""Saves training metrics to a CSV file."""
|
534 |
+
keys, vals = list(metrics.keys()), list(metrics.values())
|
535 |
+
n = len(metrics) + 1 # number of cols
|
536 |
+
s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # header
|
537 |
+
with open(self.csv, 'a') as f:
|
538 |
+
f.write(s + ('%23.5g,' * n % tuple([self.epoch] + vals)).rstrip(',') + '\n')
|
539 |
+
|
540 |
+
def plot_metrics(self):
|
541 |
+
"""Plot and display metrics visually."""
|
542 |
+
pass
|
543 |
+
|
544 |
+
def on_plot(self, name, data=None):
|
545 |
+
"""Registers plots (e.g. to be consumed in callbacks)"""
|
546 |
+
self.plots[name] = {'data': data, 'timestamp': time.time()}
|
547 |
+
|
548 |
+
def final_eval(self):
|
549 |
+
"""Performs final evaluation and validation for object detection YOLO model."""
|
550 |
+
for f in self.last, self.best:
|
551 |
+
if f.exists():
|
552 |
+
strip_optimizer(f) # strip optimizers
|
553 |
+
if f is self.best:
|
554 |
+
LOGGER.info(f'\nValidating {f}...')
|
555 |
+
self.metrics = self.validator(model=f)
|
556 |
+
self.metrics.pop('fitness', None)
|
557 |
+
self.run_callbacks('on_fit_epoch_end')
|
558 |
+
|
559 |
+
def check_resume(self):
|
560 |
+
"""Check if resume checkpoint exists and update arguments accordingly."""
|
561 |
+
resume = self.args.resume
|
562 |
+
if resume:
|
563 |
+
try:
|
564 |
+
exists = isinstance(resume, (str, Path)) and Path(resume).exists()
|
565 |
+
last = Path(check_file(resume) if exists else get_latest_run())
|
566 |
+
|
567 |
+
# Check that resume data YAML exists, otherwise strip to force re-download of dataset
|
568 |
+
ckpt_args = attempt_load_weights(last).args
|
569 |
+
if not Path(ckpt_args['data']).exists():
|
570 |
+
ckpt_args['data'] = self.args.data
|
571 |
+
|
572 |
+
self.args = get_cfg(ckpt_args)
|
573 |
+
self.args.model, resume = str(last), True # reinstate
|
574 |
+
except Exception as e:
|
575 |
+
raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, '
|
576 |
+
"i.e. 'yolo train resume model=path/to/last.pt'") from e
|
577 |
+
self.resume = resume
|
578 |
+
|
579 |
+
def resume_training(self, ckpt):
|
580 |
+
"""Resume YOLO training from given epoch and best fitness."""
|
581 |
+
if ckpt is None:
|
582 |
+
return
|
583 |
+
best_fitness = 0.0
|
584 |
+
start_epoch = ckpt['epoch'] + 1
|
585 |
+
if ckpt['optimizer'] is not None:
|
586 |
+
self.optimizer.load_state_dict(ckpt['optimizer']) # optimizer
|
587 |
+
best_fitness = ckpt['best_fitness']
|
588 |
+
if self.ema and ckpt.get('ema'):
|
589 |
+
self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) # EMA
|
590 |
+
self.ema.updates = ckpt['updates']
|
591 |
+
if self.resume:
|
592 |
+
assert start_epoch > 0, \
|
593 |
+
f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
|
594 |
+
f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
|
595 |
+
LOGGER.info(
|
596 |
+
f'Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs')
|
597 |
+
if self.epochs < start_epoch:
|
598 |
+
LOGGER.info(
|
599 |
+
f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs.")
|
600 |
+
self.epochs += ckpt['epoch'] # finetune additional epochs
|
601 |
+
self.best_fitness = best_fitness
|
602 |
+
self.start_epoch = start_epoch
|
603 |
+
if start_epoch > (self.epochs - self.args.close_mosaic):
|
604 |
+
LOGGER.info('Closing dataloader mosaic')
|
605 |
+
if hasattr(self.train_loader.dataset, 'mosaic'):
|
606 |
+
self.train_loader.dataset.mosaic = False
|
607 |
+
if hasattr(self.train_loader.dataset, 'close_mosaic'):
|
608 |
+
self.train_loader.dataset.close_mosaic(hyp=self.args)
|
609 |
+
|
610 |
+
def build_optimizer(self, model, name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
|
611 |
+
"""
|
612 |
+
Constructs an optimizer for the given model, based on the specified optimizer name, learning rate,
|
613 |
+
momentum, weight decay, and number of iterations.
|
614 |
+
|
615 |
+
Args:
|
616 |
+
model (torch.nn.Module): The model for which to build an optimizer.
|
617 |
+
name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
|
618 |
+
based on the number of iterations. Default: 'auto'.
|
619 |
+
lr (float, optional): The learning rate for the optimizer. Default: 0.001.
|
620 |
+
momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
|
621 |
+
decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
|
622 |
+
iterations (float, optional): The number of iterations, which determines the optimizer if
|
623 |
+
name is 'auto'. Default: 1e5.
|
624 |
+
|
625 |
+
Returns:
|
626 |
+
(torch.optim.Optimizer): The constructed optimizer.
|
627 |
+
"""
|
628 |
+
|
629 |
+
g = [], [], [] # optimizer parameter groups
|
630 |
+
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
|
631 |
+
if name == 'auto':
|
632 |
+
nc = getattr(model, 'nc', 10) # number of classes
|
633 |
+
lr_fit = round(0.002 * 5 / (4 + nc), 6) # lr0 fit equation to 6 decimal places
|
634 |
+
name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 10000 else ('AdamW', lr_fit, 0.9)
|
635 |
+
self.args.warmup_bias_lr = 0.0 # no higher than 0.01 for Adam
|
636 |
+
|
637 |
+
for module_name, module in model.named_modules():
|
638 |
+
for param_name, param in module.named_parameters(recurse=False):
|
639 |
+
fullname = f'{module_name}.{param_name}' if module_name else param_name
|
640 |
+
if 'bias' in fullname: # bias (no decay)
|
641 |
+
g[2].append(param)
|
642 |
+
elif isinstance(module, bn): # weight (no decay)
|
643 |
+
g[1].append(param)
|
644 |
+
else: # weight (with decay)
|
645 |
+
g[0].append(param)
|
646 |
+
|
647 |
+
if name in ('Adam', 'Adamax', 'AdamW', 'NAdam', 'RAdam'):
|
648 |
+
optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
|
649 |
+
elif name == 'RMSProp':
|
650 |
+
optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
|
651 |
+
elif name == 'SGD':
|
652 |
+
optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
|
653 |
+
else:
|
654 |
+
raise NotImplementedError(
|
655 |
+
f"Optimizer '{name}' not found in list of available optimizers "
|
656 |
+
f'[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto].'
|
657 |
+
'To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics.')
|
658 |
+
|
659 |
+
optimizer.add_param_group({'params': g[0], 'weight_decay': decay}) # add g0 with weight_decay
|
660 |
+
optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights)
|
661 |
+
LOGGER.info(
|
662 |
+
f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
|
663 |
+
f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)')
|
664 |
+
return optimizer
|
ultralytics/engine/validator.py
ADDED
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
"""
|
3 |
+
Check a model's accuracy on a test or val split of a dataset
|
4 |
+
|
5 |
+
Usage:
|
6 |
+
$ yolo mode=val model=yolov8n.pt data=coco128.yaml imgsz=640
|
7 |
+
|
8 |
+
Usage - formats:
|
9 |
+
$ yolo mode=val model=yolov8n.pt # PyTorch
|
10 |
+
yolov8n.torchscript # TorchScript
|
11 |
+
yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
|
12 |
+
yolov8n_openvino_model # OpenVINO
|
13 |
+
yolov8n.engine # TensorRT
|
14 |
+
yolov8n.mlmodel # CoreML (macOS-only)
|
15 |
+
yolov8n_saved_model # TensorFlow SavedModel
|
16 |
+
yolov8n.pb # TensorFlow GraphDef
|
17 |
+
yolov8n.tflite # TensorFlow Lite
|
18 |
+
yolov8n_edgetpu.tflite # TensorFlow Edge TPU
|
19 |
+
yolov8n_paddle_model # PaddlePaddle
|
20 |
+
"""
|
21 |
+
import json
|
22 |
+
import time
|
23 |
+
from pathlib import Path
|
24 |
+
|
25 |
+
import torch
|
26 |
+
from tqdm import tqdm
|
27 |
+
|
28 |
+
from ultralytics.cfg import get_cfg
|
29 |
+
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
30 |
+
from ultralytics.nn.autobackend import AutoBackend
|
31 |
+
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
|
32 |
+
from ultralytics.utils.checks import check_imgsz
|
33 |
+
from ultralytics.utils.files import increment_path
|
34 |
+
from ultralytics.utils.ops import Profile
|
35 |
+
from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode
|
36 |
+
|
37 |
+
|
38 |
+
class BaseValidator:
|
39 |
+
"""
|
40 |
+
BaseValidator
|
41 |
+
|
42 |
+
A base class for creating validators.
|
43 |
+
|
44 |
+
Attributes:
|
45 |
+
dataloader (DataLoader): Dataloader to use for validation.
|
46 |
+
pbar (tqdm): Progress bar to update during validation.
|
47 |
+
args (SimpleNamespace): Configuration for the validator.
|
48 |
+
model (nn.Module): Model to validate.
|
49 |
+
data (dict): Data dictionary.
|
50 |
+
device (torch.device): Device to use for validation.
|
51 |
+
batch_i (int): Current batch index.
|
52 |
+
training (bool): Whether the model is in training mode.
|
53 |
+
speed (float): Batch processing speed in seconds.
|
54 |
+
jdict (dict): Dictionary to store validation results.
|
55 |
+
save_dir (Path): Directory to save results.
|
56 |
+
"""
|
57 |
+
|
58 |
+
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
|
59 |
+
"""
|
60 |
+
Initializes a BaseValidator instance.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
|
64 |
+
save_dir (Path): Directory to save results.
|
65 |
+
pbar (tqdm.tqdm): Progress bar for displaying progress.
|
66 |
+
args (SimpleNamespace): Configuration for the validator.
|
67 |
+
"""
|
68 |
+
self.dataloader = dataloader
|
69 |
+
self.pbar = pbar
|
70 |
+
self.args = args or get_cfg(DEFAULT_CFG)
|
71 |
+
self.model = None
|
72 |
+
self.data = None
|
73 |
+
self.device = None
|
74 |
+
self.batch_i = None
|
75 |
+
self.training = True
|
76 |
+
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
77 |
+
self.jdict = None
|
78 |
+
|
79 |
+
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
80 |
+
name = self.args.name or f'{self.args.mode}'
|
81 |
+
self.save_dir = save_dir or increment_path(Path(project) / name,
|
82 |
+
exist_ok=self.args.exist_ok if RANK in (-1, 0) else True)
|
83 |
+
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
84 |
+
|
85 |
+
if self.args.conf is None:
|
86 |
+
self.args.conf = 0.001 # default conf=0.001
|
87 |
+
|
88 |
+
self.plots = {}
|
89 |
+
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
90 |
+
|
91 |
+
@smart_inference_mode()
|
92 |
+
def __call__(self, trainer=None, model=None):
|
93 |
+
"""
|
94 |
+
Supports validation of a pre-trained model if passed or a model being trained
|
95 |
+
if trainer is passed (trainer gets priority).
|
96 |
+
"""
|
97 |
+
self.training = trainer is not None
|
98 |
+
augment = self.args.augment and (not self.training)
|
99 |
+
if self.training:
|
100 |
+
self.device = trainer.device
|
101 |
+
self.data = trainer.data
|
102 |
+
model = trainer.ema.ema or trainer.model
|
103 |
+
self.args.half = self.device.type != 'cpu' # force FP16 val during training
|
104 |
+
model = model.half() if self.args.half else model.float()
|
105 |
+
self.model = model
|
106 |
+
self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
|
107 |
+
self.args.plots = trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
|
108 |
+
model.eval()
|
109 |
+
else:
|
110 |
+
callbacks.add_integration_callbacks(self)
|
111 |
+
self.run_callbacks('on_val_start')
|
112 |
+
assert model is not None, 'Either trainer or model is needed for validation'
|
113 |
+
model = AutoBackend(model,
|
114 |
+
device=select_device(self.args.device, self.args.batch),
|
115 |
+
dnn=self.args.dnn,
|
116 |
+
data=self.args.data,
|
117 |
+
fp16=self.args.half)
|
118 |
+
self.model = model
|
119 |
+
self.device = model.device # update device
|
120 |
+
self.args.half = model.fp16 # update half
|
121 |
+
stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
|
122 |
+
imgsz = check_imgsz(self.args.imgsz, stride=stride)
|
123 |
+
if engine:
|
124 |
+
self.args.batch = model.batch_size
|
125 |
+
elif not pt and not jit:
|
126 |
+
self.args.batch = 1 # export.py models default to batch-size 1
|
127 |
+
LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
|
128 |
+
|
129 |
+
if isinstance(self.args.data, str) and self.args.data.split('.')[-1] in ('yaml', 'yml'):
|
130 |
+
self.data = check_det_dataset(self.args.data)
|
131 |
+
elif self.args.task == 'classify':
|
132 |
+
self.data = check_cls_dataset(self.args.data, split=self.args.split)
|
133 |
+
else:
|
134 |
+
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
|
135 |
+
|
136 |
+
if self.device.type == 'cpu':
|
137 |
+
self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
|
138 |
+
if not pt:
|
139 |
+
self.args.rect = False
|
140 |
+
self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)
|
141 |
+
|
142 |
+
model.eval()
|
143 |
+
model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz)) # warmup
|
144 |
+
|
145 |
+
dt = Profile(), Profile(), Profile(), Profile()
|
146 |
+
n_batches = len(self.dataloader)
|
147 |
+
desc = self.get_desc()
|
148 |
+
# NOTE: keeping `not self.training` in tqdm will eliminate pbar after segmentation evaluation during training,
|
149 |
+
# which may affect classification task since this arg is in yolov5/classify/val.py.
|
150 |
+
# bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT)
|
151 |
+
bar = tqdm(self.dataloader, desc, n_batches, bar_format=TQDM_BAR_FORMAT)
|
152 |
+
self.init_metrics(de_parallel(model))
|
153 |
+
self.jdict = [] # empty before each val
|
154 |
+
for batch_i, batch in enumerate(bar):
|
155 |
+
self.run_callbacks('on_val_batch_start')
|
156 |
+
self.batch_i = batch_i
|
157 |
+
# Preprocess
|
158 |
+
with dt[0]:
|
159 |
+
batch = self.preprocess(batch)
|
160 |
+
|
161 |
+
# Inference
|
162 |
+
with dt[1]:
|
163 |
+
preds = model(batch['img'], augment=augment)
|
164 |
+
|
165 |
+
# Loss
|
166 |
+
with dt[2]:
|
167 |
+
if self.training:
|
168 |
+
self.loss += model.loss(batch, preds)[1]
|
169 |
+
|
170 |
+
# Postprocess
|
171 |
+
with dt[3]:
|
172 |
+
preds = self.postprocess(preds)
|
173 |
+
|
174 |
+
self.update_metrics(preds, batch)
|
175 |
+
if self.args.plots and batch_i < 3:
|
176 |
+
self.plot_val_samples(batch, batch_i)
|
177 |
+
self.plot_predictions(batch, preds, batch_i)
|
178 |
+
|
179 |
+
self.run_callbacks('on_val_batch_end')
|
180 |
+
stats = self.get_stats()
|
181 |
+
self.check_stats(stats)
|
182 |
+
self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1E3 for x in dt)))
|
183 |
+
self.finalize_metrics()
|
184 |
+
self.print_results()
|
185 |
+
self.run_callbacks('on_val_end')
|
186 |
+
if self.training:
|
187 |
+
model.float()
|
188 |
+
results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')}
|
189 |
+
return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats
|
190 |
+
else:
|
191 |
+
LOGGER.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
|
192 |
+
tuple(self.speed.values()))
|
193 |
+
if self.args.save_json and self.jdict:
|
194 |
+
with open(str(self.save_dir / 'predictions.json'), 'w') as f:
|
195 |
+
LOGGER.info(f'Saving {f.name}...')
|
196 |
+
json.dump(self.jdict, f) # flatten and save
|
197 |
+
stats = self.eval_json(stats) # update stats
|
198 |
+
if self.args.plots or self.args.save_json:
|
199 |
+
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
|
200 |
+
return stats
|
201 |
+
|
202 |
+
def add_callback(self, event: str, callback):
|
203 |
+
"""Appends the given callback."""
|
204 |
+
self.callbacks[event].append(callback)
|
205 |
+
|
206 |
+
def run_callbacks(self, event: str):
|
207 |
+
"""Runs all callbacks associated with a specified event."""
|
208 |
+
for callback in self.callbacks.get(event, []):
|
209 |
+
callback(self)
|
210 |
+
|
211 |
+
def get_dataloader(self, dataset_path, batch_size):
|
212 |
+
"""Get data loader from dataset path and batch size."""
|
213 |
+
raise NotImplementedError('get_dataloader function not implemented for this validator')
|
214 |
+
|
215 |
+
def build_dataset(self, img_path):
|
216 |
+
"""Build dataset"""
|
217 |
+
raise NotImplementedError('build_dataset function not implemented in validator')
|
218 |
+
|
219 |
+
def preprocess(self, batch):
|
220 |
+
"""Preprocesses an input batch."""
|
221 |
+
return batch
|
222 |
+
|
223 |
+
def postprocess(self, preds):
|
224 |
+
"""Describes and summarizes the purpose of 'postprocess()' but no details mentioned."""
|
225 |
+
return preds
|
226 |
+
|
227 |
+
def init_metrics(self, model):
|
228 |
+
"""Initialize performance metrics for the YOLO model."""
|
229 |
+
pass
|
230 |
+
|
231 |
+
def update_metrics(self, preds, batch):
|
232 |
+
"""Updates metrics based on predictions and batch."""
|
233 |
+
pass
|
234 |
+
|
235 |
+
def finalize_metrics(self, *args, **kwargs):
|
236 |
+
"""Finalizes and returns all metrics."""
|
237 |
+
pass
|
238 |
+
|
239 |
+
def get_stats(self):
|
240 |
+
"""Returns statistics about the model's performance."""
|
241 |
+
return {}
|
242 |
+
|
243 |
+
def check_stats(self, stats):
|
244 |
+
"""Checks statistics."""
|
245 |
+
pass
|
246 |
+
|
247 |
+
def print_results(self):
|
248 |
+
"""Prints the results of the model's predictions."""
|
249 |
+
pass
|
250 |
+
|
251 |
+
def get_desc(self):
|
252 |
+
"""Get description of the YOLO model."""
|
253 |
+
pass
|
254 |
+
|
255 |
+
@property
|
256 |
+
def metric_keys(self):
|
257 |
+
"""Returns the metric keys used in YOLO training/validation."""
|
258 |
+
return []
|
259 |
+
|
260 |
+
def on_plot(self, name, data=None):
|
261 |
+
"""Registers plots (e.g. to be consumed in callbacks)"""
|
262 |
+
self.plots[name] = {'data': data, 'timestamp': time.time()}
|
263 |
+
|
264 |
+
# TODO: may need to put these following functions into callback
|
265 |
+
def plot_val_samples(self, batch, ni):
|
266 |
+
"""Plots validation samples during training."""
|
267 |
+
pass
|
268 |
+
|
269 |
+
def plot_predictions(self, batch, preds, ni):
|
270 |
+
"""Plots YOLO model predictions on batch images."""
|
271 |
+
pass
|
272 |
+
|
273 |
+
def pred_to_json(self, preds, batch):
|
274 |
+
"""Convert predictions to JSON format."""
|
275 |
+
pass
|
276 |
+
|
277 |
+
def eval_json(self, stats):
|
278 |
+
"""Evaluate and return JSON format of prediction statistics."""
|
279 |
+
pass
|
ultralytics/hub/__init__.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
import requests
|
4 |
+
|
5 |
+
from ultralytics.data.utils import HUBDatasetStats
|
6 |
+
from ultralytics.hub.auth import Auth
|
7 |
+
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX
|
8 |
+
from ultralytics.utils import LOGGER, SETTINGS, USER_CONFIG_DIR, yaml_save
|
9 |
+
|
10 |
+
|
11 |
+
def login(api_key=''):
|
12 |
+
"""
|
13 |
+
Log in to the Ultralytics HUB API using the provided API key.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
|
17 |
+
|
18 |
+
Example:
|
19 |
+
```python
|
20 |
+
from ultralytics import hub
|
21 |
+
hub.login('API_KEY')
|
22 |
+
```
|
23 |
+
"""
|
24 |
+
Auth(api_key, verbose=True)
|
25 |
+
|
26 |
+
|
27 |
+
def logout():
|
28 |
+
"""
|
29 |
+
Log out of Ultralytics HUB by removing the API key from the settings file. To log in again, use 'yolo hub login'.
|
30 |
+
|
31 |
+
Example:
|
32 |
+
```python
|
33 |
+
from ultralytics import hub
|
34 |
+
hub.logout()
|
35 |
+
```
|
36 |
+
"""
|
37 |
+
SETTINGS['api_key'] = ''
|
38 |
+
yaml_save(USER_CONFIG_DIR / 'settings.yaml', SETTINGS)
|
39 |
+
LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo hub login'.")
|
40 |
+
|
41 |
+
|
42 |
+
def start(key=''):
|
43 |
+
"""
|
44 |
+
Start training models with Ultralytics HUB (DEPRECATED).
|
45 |
+
|
46 |
+
Args:
|
47 |
+
key (str, optional): A string containing either the API key and model ID combination (apikey_modelid),
|
48 |
+
or the full model URL (https://hub.ultralytics.com/models/apikey_modelid).
|
49 |
+
"""
|
50 |
+
api_key, model_id = key.split('_')
|
51 |
+
LOGGER.warning(f"""
|
52 |
+
WARNING ⚠️ ultralytics.start() is deprecated after 8.0.60. Updated usage to train Ultralytics HUB models is:
|
53 |
+
|
54 |
+
from ultralytics import YOLO, hub
|
55 |
+
|
56 |
+
hub.login('{api_key}')
|
57 |
+
model = YOLO('{HUB_WEB_ROOT}/models/{model_id}')
|
58 |
+
model.train()""")
|
59 |
+
|
60 |
+
|
61 |
+
def reset_model(model_id=''):
|
62 |
+
"""Reset a trained model to an untrained state."""
|
63 |
+
r = requests.post(f'{HUB_API_ROOT}/model-reset', json={'apiKey': Auth().api_key, 'modelId': model_id})
|
64 |
+
if r.status_code == 200:
|
65 |
+
LOGGER.info(f'{PREFIX}Model reset successfully')
|
66 |
+
return
|
67 |
+
LOGGER.warning(f'{PREFIX}Model reset failure {r.status_code} {r.reason}')
|
68 |
+
|
69 |
+
|
70 |
+
def export_fmts_hub():
|
71 |
+
"""Returns a list of HUB-supported export formats."""
|
72 |
+
from ultralytics.engine.exporter import export_formats
|
73 |
+
return list(export_formats()['Argument'][1:]) + ['ultralytics_tflite', 'ultralytics_coreml']
|
74 |
+
|
75 |
+
|
76 |
+
def export_model(model_id='', format='torchscript'):
|
77 |
+
"""Export a model to all formats."""
|
78 |
+
assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
|
79 |
+
r = requests.post(f'{HUB_API_ROOT}/v1/models/{model_id}/export',
|
80 |
+
json={'format': format},
|
81 |
+
headers={'x-api-key': Auth().api_key})
|
82 |
+
assert r.status_code == 200, f'{PREFIX}{format} export failure {r.status_code} {r.reason}'
|
83 |
+
LOGGER.info(f'{PREFIX}{format} export started ✅')
|
84 |
+
|
85 |
+
|
86 |
+
def get_export(model_id='', format='torchscript'):
|
87 |
+
"""Get an exported model dictionary with download URL."""
|
88 |
+
assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
|
89 |
+
r = requests.post(f'{HUB_API_ROOT}/get-export',
|
90 |
+
json={
|
91 |
+
'apiKey': Auth().api_key,
|
92 |
+
'modelId': model_id,
|
93 |
+
'format': format})
|
94 |
+
assert r.status_code == 200, f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}'
|
95 |
+
return r.json()
|
96 |
+
|
97 |
+
|
98 |
+
def check_dataset(path='', task='detect'):
|
99 |
+
"""
|
100 |
+
Function for error-checking HUB dataset Zip file before upload. It checks a dataset for errors before it is
|
101 |
+
uploaded to the HUB. Usage examples are given below.
|
102 |
+
|
103 |
+
Args:
|
104 |
+
path (str, optional): Path to data.zip (with data.yaml inside data.zip). Defaults to ''.
|
105 |
+
task (str, optional): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Defaults to 'detect'.
|
106 |
+
|
107 |
+
Example:
|
108 |
+
```python
|
109 |
+
from ultralytics.hub import check_dataset
|
110 |
+
|
111 |
+
check_dataset('path/to/coco8.zip', task='detect') # detect dataset
|
112 |
+
check_dataset('path/to/coco8-seg.zip', task='segment') # segment dataset
|
113 |
+
check_dataset('path/to/coco8-pose.zip', task='pose') # pose dataset
|
114 |
+
```
|
115 |
+
"""
|
116 |
+
HUBDatasetStats(path=path, task=task).get_json()
|
117 |
+
LOGGER.info(f'Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.')
|
118 |
+
|
119 |
+
|
120 |
+
if __name__ == '__main__':
|
121 |
+
start()
|
ultralytics/hub/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (4.91 kB). View file
|
|
ultralytics/hub/__pycache__/auth.cpython-39.pyc
ADDED
Binary file (4.2 kB). View file
|
|