llzzyy233 committed (verified)
Commit 0f09c5d · 1 parent: 89fd19c

Upload improved YOLOv8

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.

Files changed (50)
  1. ultralytics/__init__.py +13 -0
  2. ultralytics/__pycache__/__init__.cpython-39.pyc +0 -0
  3. ultralytics/cfg/__init__.py +441 -0
  4. ultralytics/cfg/__pycache__/__init__.cpython-39.pyc +0 -0
  5. ultralytics/cfg/default.yaml +114 -0
  6. ultralytics/cfg/models/v8/yolov8.yaml +46 -0
  7. ultralytics/cfg/models/v8/yolov8_ECA.yaml +50 -0
  8. ultralytics/cfg/models/v8/yolov8_GAM.yaml +50 -0
  9. ultralytics/cfg/models/v8/yolov8_ResBlock_CBAM.yaml +50 -0
  10. ultralytics/cfg/models/v8/yolov8_SA.yaml +50 -0
  11. ultralytics/cfg/trackers/botsort.yaml +18 -0
  12. ultralytics/cfg/trackers/bytetrack.yaml +11 -0
  13. ultralytics/data/__init__.py +8 -0
  14. ultralytics/data/__pycache__/__init__.cpython-39.pyc +0 -0
  15. ultralytics/data/__pycache__/augment.cpython-39.pyc +0 -0
  16. ultralytics/data/__pycache__/base.cpython-39.pyc +0 -0
  17. ultralytics/data/__pycache__/build.cpython-39.pyc +0 -0
  18. ultralytics/data/__pycache__/dataset.cpython-39.pyc +0 -0
  19. ultralytics/data/__pycache__/loaders.cpython-39.pyc +0 -0
  20. ultralytics/data/__pycache__/utils.cpython-39.pyc +0 -0
  21. ultralytics/data/annotator.py +39 -0
  22. ultralytics/data/augment.py +906 -0
  23. ultralytics/data/base.py +287 -0
  24. ultralytics/data/build.py +170 -0
  25. ultralytics/data/converter.py +230 -0
  26. ultralytics/data/dataloaders/__init__.py +0 -0
  27. ultralytics/data/dataset.py +275 -0
  28. ultralytics/data/loaders.py +407 -0
  29. ultralytics/data/scripts/download_weights.sh +18 -0
  30. ultralytics/data/scripts/get_coco.sh +60 -0
  31. ultralytics/data/scripts/get_coco128.sh +17 -0
  32. ultralytics/data/scripts/get_imagenet.sh +51 -0
  33. ultralytics/data/utils.py +557 -0
  34. ultralytics/engine/__init__.py +0 -0
  35. ultralytics/engine/__pycache__/__init__.cpython-39.pyc +0 -0
  36. ultralytics/engine/__pycache__/exporter.cpython-39.pyc +0 -0
  37. ultralytics/engine/__pycache__/model.cpython-39.pyc +0 -0
  38. ultralytics/engine/__pycache__/predictor.cpython-39.pyc +0 -0
  39. ultralytics/engine/__pycache__/results.cpython-39.pyc +0 -0
  40. ultralytics/engine/__pycache__/trainer.cpython-39.pyc +0 -0
  41. ultralytics/engine/__pycache__/validator.cpython-39.pyc +0 -0
  42. ultralytics/engine/exporter.py +969 -0
  43. ultralytics/engine/model.py +465 -0
  44. ultralytics/engine/predictor.py +359 -0
  45. ultralytics/engine/results.py +604 -0
  46. ultralytics/engine/trainer.py +664 -0
  47. ultralytics/engine/validator.py +279 -0
  48. ultralytics/hub/__init__.py +121 -0
  49. ultralytics/hub/__pycache__/__init__.cpython-39.pyc +0 -0
  50. ultralytics/hub/__pycache__/auth.cpython-39.pyc +0 -0
ultralytics/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ __version__ = '8.0.147'
4
+
5
+ from ultralytics.hub import start
6
+ from ultralytics.models import RTDETR, SAM, YOLO
7
+ from ultralytics.models.fastsam import FastSAM
8
+ from ultralytics.models.nas import NAS
9
+ from ultralytics.utils import SETTINGS as settings
10
+ from ultralytics.utils.checks import check_yolo as checks
11
+ from ultralytics.utils.downloads import download
12
+
13
+ __all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'FastSAM', 'RTDETR', 'checks', 'download', 'start', 'settings' # allow simpler import
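The flat `__all__` tuple above is what enables the simplified top-level imports noted in its trailing comment. A minimal usage sketch (assuming the yolov8n.pt weights are present locally or can be auto-downloaded):

    from ultralytics import YOLO, checks

    checks()                    # print environment and dependency summary
    model = YOLO('yolov8n.pt')  # load a pretrained detection model
    results = model('https://ultralytics.com/images/bus.jpg')  # run inference on a sample image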
ultralytics/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (681 Bytes).
ultralytics/cfg/__init__.py ADDED
@@ -0,0 +1,441 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import contextlib
4
+ import re
5
+ import shutil
6
+ import sys
7
+ from difflib import get_close_matches
8
+ from pathlib import Path
9
+ from types import SimpleNamespace
10
+ from typing import Dict, List, Union
11
+
12
+ from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, ROOT, SETTINGS, SETTINGS_YAML,
13
+ IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn, yaml_load,
14
+ yaml_print)
15
+
16
+ # Define valid tasks and modes
17
+ MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
18
+ TASKS = 'detect', 'segment', 'classify', 'pose'
19
+ TASK2DATA = {'detect': 'coco8.yaml', 'segment': 'coco8-seg.yaml', 'classify': 'imagenet100', 'pose': 'coco8-pose.yaml'}
20
+ TASK2MODEL = {
21
+ 'detect': 'yolov8n.pt',
22
+ 'segment': 'yolov8n-seg.pt',
23
+ 'classify': 'yolov8n-cls.pt',
24
+ 'pose': 'yolov8n-pose.pt'}
25
+ TASK2METRIC = {
26
+ 'detect': 'metrics/mAP50-95(B)',
27
+ 'segment': 'metrics/mAP50-95(M)',
28
+ 'classify': 'metrics/accuracy_top1',
29
+ 'pose': 'metrics/mAP50-95(P)'}
30
+
31
+ CLI_HELP_MSG = \
32
+ f"""
33
+ Arguments received: {str(['yolo'] + sys.argv[1:])}. Ultralytics 'yolo' commands use the following syntax:
34
+
35
+ yolo TASK MODE ARGS
36
+
37
+ Where TASK (optional) is one of {TASKS}
38
+ MODE (required) is one of {MODES}
39
+ ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
40
+ See all ARGS at https://docs.ultralytics.com/usage/cfg or with 'yolo cfg'
41
+
42
+ 1. Train a detection model for 10 epochs with an initial learning_rate of 0.01
43
+ yolo train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
44
+
45
+ 2. Predict a YouTube video using a pretrained segmentation model at image size 320:
46
+ yolo predict model=yolov8n-seg.pt source='https://youtu.be/Zgi9g1ksQHc' imgsz=320
47
+
48
+ 3. Val a pretrained detection model at batch-size 1 and image size 640:
49
+ yolo val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
50
+
51
+ 4. Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
52
+ yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
53
+
54
+ 5. Run special commands:
55
+ yolo help
56
+ yolo checks
57
+ yolo version
58
+ yolo settings
59
+ yolo copy-cfg
60
+ yolo cfg
61
+
62
+ Docs: https://docs.ultralytics.com
63
+ Community: https://community.ultralytics.com
64
+ GitHub: https://github.com/ultralytics/ultralytics
65
+ """
66
+
67
+ # Define keys for arg type checks
68
+ CFG_FLOAT_KEYS = 'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear'
69
+ CFG_FRACTION_KEYS = ('dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr',
70
+ 'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud',
71
+ 'fliplr', 'mosaic', 'mixup', 'copy_paste', 'conf', 'iou', 'fraction') # fraction floats 0.0 - 1.0
72
+ CFG_INT_KEYS = ('epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
73
+ 'line_width', 'workspace', 'nbs', 'save_period')
74
+ CFG_BOOL_KEYS = ('save', 'exist_ok', 'verbose', 'deterministic', 'single_cls', 'rect', 'cos_lr', 'overlap_mask', 'val',
75
+ 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', 'save_conf', 'save_crop',
76
+ 'show_labels', 'show_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks', 'boxes', 'keras',
77
+ 'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'profile')
78
+
79
+
80
+ def cfg2dict(cfg):
81
+ """
82
+ Convert a configuration object to a dictionary, whether it is a file path, a string, or a SimpleNamespace object.
83
+
84
+ Args:
85
+ cfg (str | Path | SimpleNamespace): Configuration object to be converted to a dictionary.
86
+
87
+ Returns:
88
+ cfg (dict): Configuration object in dictionary format.
89
+ """
90
+ if isinstance(cfg, (str, Path)):
91
+ cfg = yaml_load(cfg) # load dict
92
+ elif isinstance(cfg, SimpleNamespace):
93
+ cfg = vars(cfg) # convert to dict
94
+ return cfg
95
+
96
+
97
+ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, overrides: Dict = None):
98
+ """
99
+ Load and merge configuration data from a file or dictionary.
100
+
101
+ Args:
102
+ cfg (str | Path | Dict | SimpleNamespace): Configuration data.
103
+ overrides (str | Dict | optional): Overrides in the form of a file name or a dictionary. Default is None.
104
+
105
+ Returns:
106
+ (SimpleNamespace): Training arguments namespace.
107
+ """
108
+ cfg = cfg2dict(cfg)
109
+
110
+ # Merge overrides
111
+ if overrides:
112
+ overrides = cfg2dict(overrides)
113
+ check_dict_alignment(cfg, overrides)
114
+ cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides)
115
+
116
+ # Special handling for numeric project/name
117
+ for k in 'project', 'name':
118
+ if k in cfg and isinstance(cfg[k], (int, float)):
119
+ cfg[k] = str(cfg[k])
120
+ if cfg.get('name') == 'model': # assign model to 'name' arg
121
+ cfg['name'] = cfg.get('model', '').split('.')[0]
122
+ LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.")
123
+
124
+ # Type and Value checks
125
+ for k, v in cfg.items():
126
+ if v is not None: # None values may be from optional args
127
+ if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
128
+ raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
129
+ f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
130
+ elif k in CFG_FRACTION_KEYS:
131
+ if not isinstance(v, (int, float)):
132
+ raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
133
+ f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
134
+ if not (0.0 <= v <= 1.0):
135
+ raise ValueError(f"'{k}={v}' is an invalid value. "
136
+ f"Valid '{k}' values are between 0.0 and 1.0.")
137
+ elif k in CFG_INT_KEYS and not isinstance(v, int):
138
+ raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
139
+ f"'{k}' must be an int (i.e. '{k}=8')")
140
+ elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
141
+ raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
142
+ f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')")
143
+
144
+ # Return instance
145
+ return IterableSimpleNamespace(**cfg)
146
+
147
+
148
+ def _handle_deprecation(custom):
149
+ """Hardcoded function to handle deprecated config keys"""
150
+
151
+ for key in custom.copy().keys():
152
+ if key == 'hide_labels':
153
+ deprecation_warn(key, 'show_labels')
154
+ custom['show_labels'] = custom.pop('hide_labels') == 'False'
155
+ if key == 'hide_conf':
156
+ deprecation_warn(key, 'show_conf')
157
+ custom['show_conf'] = custom.pop('hide_conf') == 'False'
158
+ if key == 'line_thickness':
159
+ deprecation_warn(key, 'line_width')
160
+ custom['line_width'] = custom.pop('line_thickness')
161
+
162
+ return custom
163
+
164
+
165
+ def check_dict_alignment(base: Dict, custom: Dict, e=None):
166
+ """
167
+ This function checks for any mismatched keys between a custom configuration list and a base configuration list.
168
+ If any mismatched keys are found, the function prints out similar keys from the base list and exits the program.
169
+
170
+ Args:
171
+ custom (dict): a dictionary of custom configuration options
172
+ base (dict): a dictionary of base configuration options
173
+ """
174
+ custom = _handle_deprecation(custom)
175
+ base_keys, custom_keys = (set(x.keys()) for x in (base, custom))
176
+ mismatched = [k for k in custom_keys if k not in base_keys]
177
+ if mismatched:
178
+ string = ''
179
+ for x in mismatched:
180
+ matches = get_close_matches(x, base_keys) # key list
181
+ matches = [f'{k}={base[k]}' if base.get(k) is not None else k for k in matches]
182
+ match_str = f'Similar arguments are i.e. {matches}.' if matches else ''
183
+ string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n"
184
+ raise SyntaxError(string + CLI_HELP_MSG) from e
185
+
186
+
187
+ def merge_equals_args(args: List[str]) -> List[str]:
188
+ """
189
+ Merges arguments around isolated '=' args in a list of strings.
190
+ The function considers cases where the first argument ends with '=' or the second starts with '=',
191
+ as well as when the middle one is an equals sign.
192
+
193
+ Args:
194
+ args (List[str]): A list of strings where each element is an argument.
195
+
196
+ Returns:
197
+ List[str]: A list of strings where the arguments around isolated '=' are merged.
198
+ """
199
+ new_args = []
200
+ for i, arg in enumerate(args):
201
+ if arg == '=' and 0 < i < len(args) - 1: # merge ['arg', '=', 'val']
202
+ new_args[-1] += f'={args[i + 1]}'
203
+ del args[i + 1]
204
+ elif arg.endswith('=') and i < len(args) - 1 and '=' not in args[i + 1]: # merge ['arg=', 'val']
205
+ new_args.append(f'{arg}{args[i + 1]}')
206
+ del args[i + 1]
207
+ elif arg.startswith('=') and i > 0: # merge ['arg', '=val']
208
+ new_args[-1] += arg
209
+ else:
210
+ new_args.append(arg)
211
+ return new_args
212
+
213
+
214
+ def handle_yolo_hub(args: List[str]) -> None:
215
+ """
216
+ Handle Ultralytics HUB command-line interface (CLI) commands.
217
+
218
+ This function processes Ultralytics HUB CLI commands such as login and logout.
219
+ It should be called when executing a script with arguments related to HUB authentication.
220
+
221
+ Args:
222
+ args (List[str]): A list of command line arguments
223
+
224
+ Example:
225
+ ```python
226
+ python my_script.py hub login your_api_key
227
+ ```
228
+ """
229
+ from ultralytics import hub
230
+
231
+ if args[0] == 'login':
232
+ key = args[1] if len(args) > 1 else ''
233
+ # Log in to Ultralytics HUB using the provided API key
234
+ hub.login(key)
235
+ elif args[0] == 'logout':
236
+ # Log out from Ultralytics HUB
237
+ hub.logout()
238
+
239
+
240
+ def handle_yolo_settings(args: List[str]) -> None:
241
+ """
242
+ Handle YOLO settings command-line interface (CLI) commands.
243
+
244
+ This function processes YOLO settings CLI commands such as reset.
245
+ It should be called when executing a script with arguments related to YOLO settings management.
246
+
247
+ Args:
248
+ args (List[str]): A list of command line arguments for YOLO settings management.
249
+
250
+ Example:
251
+ ```python
252
+ python my_script.py yolo settings reset
253
+ ```
254
+ """
255
+ if any(args):
256
+ if args[0] == 'reset':
257
+ SETTINGS_YAML.unlink() # delete the settings file
258
+ SETTINGS.reset() # create new settings
259
+ LOGGER.info('Settings reset successfully') # inform the user that settings have been reset
260
+ else: # save a new setting
261
+ new = dict(parse_key_value_pair(a) for a in args)
262
+ check_dict_alignment(SETTINGS, new)
263
+ SETTINGS.update(new)
264
+
265
+ yaml_print(SETTINGS_YAML) # print the current settings
266
+
267
+
268
+ def parse_key_value_pair(pair):
269
+ """Parse one 'key=value' pair and return key and value."""
270
+ pair = re.sub(r' *= *', '=', pair) # remove spaces around equals sign
271
+ k, v = pair.split('=', 1) # split on first '=' sign
272
+ assert v, f"missing '{k}' value"
273
+ return k, smart_value(v)
274
+
275
+
276
+ def smart_value(v):
277
+ """Convert a string to an underlying type such as int, float, bool, etc."""
278
+ if v.lower() == 'none':
279
+ return None
280
+ elif v.lower() == 'true':
281
+ return True
282
+ elif v.lower() == 'false':
283
+ return False
284
+ else:
285
+ with contextlib.suppress(Exception):
286
+ return eval(v)
287
+ return v
288
+
289
+
290
+ def entrypoint(debug=''):
291
+ """
292
+ This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
293
+ to the package.
294
+
295
+ This function allows for:
296
+ - passing mandatory YOLO args as a list of strings
297
+ - specifying the task to be performed, either 'detect', 'segment' or 'classify'
298
+ - specifying the mode, either 'train', 'val', 'test', or 'predict'
299
+ - running special modes like 'checks'
300
+ - passing overrides to the package's configuration
301
+
302
+ It uses the package's default cfg and initializes it using the passed overrides.
303
+ Then it calls the CLI function with the composed cfg
304
+ """
305
+ args = (debug.split(' ') if debug else sys.argv)[1:]
306
+ if not args: # no arguments passed
307
+ LOGGER.info(CLI_HELP_MSG)
308
+ return
309
+
310
+ special = {
311
+ 'help': lambda: LOGGER.info(CLI_HELP_MSG),
312
+ 'checks': checks.check_yolo,
313
+ 'version': lambda: LOGGER.info(__version__),
314
+ 'settings': lambda: handle_yolo_settings(args[1:]),
315
+ 'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
316
+ 'hub': lambda: handle_yolo_hub(args[1:]),
317
+ 'login': lambda: handle_yolo_hub(args),
318
+ 'copy-cfg': copy_default_cfg}
319
+ full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}
320
+
321
+ # Define common mis-uses of special commands, i.e. -h, -help, --help
322
+ special.update({k[0]: v for k, v in special.items()}) # singular
323
+ special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith('s')}) # singular
324
+ special = {**special, **{f'-{k}': v for k, v in special.items()}, **{f'--{k}': v for k, v in special.items()}}
325
+
326
+ overrides = {} # basic overrides, i.e. imgsz=320
327
+ for a in merge_equals_args(args): # merge spaces around '=' sign
328
+ if a.startswith('--'):
329
+ LOGGER.warning(f"WARNING ⚠️ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
330
+ a = a[2:]
331
+ if a.endswith(','):
332
+ LOGGER.warning(f"WARNING ⚠️ '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.")
333
+ a = a[:-1]
334
+ if '=' in a:
335
+ try:
336
+ k, v = parse_key_value_pair(a)
337
+ if k == 'cfg': # custom.yaml passed
338
+ LOGGER.info(f'Overriding {DEFAULT_CFG_PATH} with {v}')
339
+ overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != 'cfg'}
340
+ else:
341
+ overrides[k] = v
342
+ except (NameError, SyntaxError, ValueError, AssertionError) as e:
343
+ check_dict_alignment(full_args_dict, {a: ''}, e)
344
+
345
+ elif a in TASKS:
346
+ overrides['task'] = a
347
+ elif a in MODES:
348
+ overrides['mode'] = a
349
+ elif a.lower() in special:
350
+ special[a.lower()]()
351
+ return
352
+ elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool):
353
+ overrides[a] = True # auto-True for default bool args, i.e. 'yolo show' sets show=True
354
+ elif a in DEFAULT_CFG_DICT:
355
+ raise SyntaxError(f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign "
356
+ f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}")
357
+ else:
358
+ check_dict_alignment(full_args_dict, {a: ''})
359
+
360
+ # Check keys
361
+ check_dict_alignment(full_args_dict, overrides)
362
+
363
+ # Mode
364
+ mode = overrides.get('mode')
365
+ if mode is None:
366
+ mode = DEFAULT_CFG.mode or 'predict'
367
+ LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")
368
+ elif mode not in MODES:
369
+ if mode not in ('checks', checks):
370
+ raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}")
371
+ LOGGER.warning("WARNING ⚠️ 'yolo mode=checks' is deprecated. Use 'yolo checks' instead.")
372
+ checks.check_yolo()
373
+ return
374
+
375
+ # Task
376
+ task = overrides.pop('task', None)
377
+ if task:
378
+ if task not in TASKS:
379
+ raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
380
+ if 'model' not in overrides:
381
+ overrides['model'] = TASK2MODEL[task]
382
+
383
+ # Model
384
+ model = overrides.pop('model', DEFAULT_CFG.model)
385
+ if model is None:
386
+ model = 'yolov8n.pt'
387
+ LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
388
+ overrides['model'] = model
389
+ if 'rtdetr' in model.lower(): # guess architecture
390
+ from ultralytics import RTDETR
391
+ model = RTDETR(model) # no task argument
392
+ elif 'fastsam' in model.lower():
393
+ from ultralytics import FastSAM
394
+ model = FastSAM(model)
395
+ elif 'sam' in model.lower():
396
+ from ultralytics import SAM
397
+ model = SAM(model)
398
+ else:
399
+ from ultralytics import YOLO
400
+ model = YOLO(model, task=task)
401
+ if isinstance(overrides.get('pretrained'), str):
402
+ model.load(overrides['pretrained'])
403
+
404
+ # Task Update
405
+ if task != model.task:
406
+ if task:
407
+ LOGGER.warning(f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. "
408
+ f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model.")
409
+ task = model.task
410
+
411
+ # Mode
412
+ if mode in ('predict', 'track') and 'source' not in overrides:
413
+ overrides['source'] = DEFAULT_CFG.source or ROOT / 'assets' if (ROOT / 'assets').exists() \
414
+ else 'https://ultralytics.com/images/bus.jpg'
415
+ LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
416
+ elif mode in ('train', 'val'):
417
+ if 'data' not in overrides:
418
+ overrides['data'] = TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
419
+ LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
420
+ elif mode == 'export':
421
+ if 'format' not in overrides:
422
+ overrides['format'] = DEFAULT_CFG.format or 'torchscript'
423
+ LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.")
424
+
425
+ # Run command in python
426
+ # getattr(model, mode)(**vars(get_cfg(overrides=overrides))) # default args using default.yaml
427
+ getattr(model, mode)(**overrides) # default args from model
428
+
429
+
430
+ # Special modes --------------------------------------------------------------------------------------------------------
431
+ def copy_default_cfg():
432
+ """Copy and create a new default configuration file with '_copy' appended to its name."""
433
+ new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
434
+ shutil.copy2(DEFAULT_CFG_PATH, new_file)
435
+ LOGGER.info(f'{DEFAULT_CFG_PATH} copied to {new_file}\n'
436
+ f"Example YOLO command with this new custom cfg:\n yolo cfg='{new_file}' imgsz=320 batch=8")
437
+
438
+
439
+ if __name__ == '__main__':
440
+ # Example Usage: entrypoint(debug='yolo predict model=yolov8n.pt')
441
+ entrypoint(debug='')
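Beyond the CLI entrypoint, `get_cfg()` above can also be used programmatically to merge overrides onto the packaged defaults. A short sketch (only keys already defined in default.yaml are overridden, so `check_dict_alignment()` passes):

    from ultralytics.cfg import get_cfg

    # invalid keys, wrong types, or fraction values outside 0-1 raise the errors defined above
    args = get_cfg(overrides={'imgsz': 320, 'epochs': 10, 'lr0': 0.01})
    print(args.imgsz, args.epochs, args.lr0)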
ultralytics/cfg/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (16.3 kB).
ultralytics/cfg/default.yaml ADDED
@@ -0,0 +1,114 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # Default training settings and hyperparameters for medium-augmentation COCO training
3
+
4
+ task: detect # (str) YOLO task, i.e. detect, segment, classify, pose
5
+ mode: train # (str) YOLO mode, i.e. train, val, predict, export, track, benchmark
6
+
7
+ # Train settings -------------------------------------------------------------------------------------------------------
8
+ model: # (str, optional) path to model file, i.e. yolov8n.pt, yolov8n.yaml
9
+ data: # (str, optional) path to data file, i.e. coco128.yaml
10
+ epochs: 100 # (int) number of epochs to train for
11
+ patience: 50 # (int) epochs to wait for no observable improvement for early stopping of training
12
+ batch: -1 # (int) number of images per batch (-1 for AutoBatch)
13
+ imgsz: 640 # (int | list) input images size as int for train and val modes, or list[w,h] for predict and export modes
14
+ save: True # (bool) save train checkpoints and predict results
15
+ save_period: -1 # (int) Save checkpoint every x epochs (disabled if < 1)
16
+ cache: False # (bool) True/ram, disk or False. Use cache for data loading
17
+ device: 0 # (int | str | list, optional) device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu
18
+ workers: 2 # (int) number of worker threads for data loading (per RANK if DDP)
19
+ project: # (str, optional) project name
20
+ name: # (str, optional) experiment name, results saved to 'project/name' directory
21
+ exist_ok: True # (bool) whether to overwrite existing experiment
22
+ pretrained: True # (bool | str) whether to use a pretrained model (bool) or a model to load weights from (str)
23
+ optimizer: auto # (str) optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto]
24
+ verbose: True # (bool) whether to print verbose output
25
+ seed: 0 # (int) random seed for reproducibility
26
+ deterministic: True # (bool) whether to enable deterministic mode
27
+ single_cls: False # (bool) train multi-class data as single-class
28
+ rect: False # (bool) rectangular training if mode='train' or rectangular validation if mode='val'
29
+ cos_lr: False # (bool) use cosine learning rate scheduler
30
+ close_mosaic: 10 # (int) disable mosaic augmentation for final epochs
31
+ resume: False # (bool) resume training from last checkpoint
32
+ amp: False # (bool) Automatic Mixed Precision (AMP) training, choices=[True, False], True runs AMP check
33
+ fraction: 1.0 # (float) dataset fraction to train on (default is 1.0, all images in train set)
34
+ profile: False # (bool) profile ONNX and TensorRT speeds during training for loggers
35
+ # Segmentation
36
+ overlap_mask: True # (bool) masks should overlap during training (segment train only)
37
+ mask_ratio: 4 # (int) mask downsample ratio (segment train only)
38
+ # Classification
39
+ dropout: 0.0 # (float) use dropout regularization (classify train only)
40
+
41
+ # Val/Test settings ----------------------------------------------------------------------------------------------------
42
+ val: True # (bool) validate/test during training
43
+ split: val # (str) dataset split to use for validation, i.e. 'val', 'test' or 'train'
44
+ save_json: True # (bool) save results to JSON file
45
+ save_hybrid: False # (bool) save hybrid version of labels (labels + additional predictions)
46
+ conf: # (float, optional) object confidence threshold for detection (default 0.25 predict, 0.001 val)
47
+ iou: 0.7 # (float) intersection over union (IoU) threshold for NMS
48
+ max_det: 300 # (int) maximum number of detections per image
49
+ half: False # (bool) use half precision (FP16)
50
+ dnn: False # (bool) use OpenCV DNN for ONNX inference
51
+ plots: True # (bool) save plots during train/val
52
+
53
+ # Prediction settings --------------------------------------------------------------------------------------------------
54
+ source: # (str, optional) source directory for images or videos
55
+ show: False # (bool) show results if possible
56
+ save_txt: False # (bool) save results as .txt file
57
+ save_conf: False # (bool) save results with confidence scores
58
+ save_crop: False # (bool) save cropped images with results
59
+ show_labels: True # (bool) show object labels in plots
60
+ show_conf: True # (bool) show object confidence scores in plots
61
+ vid_stride: 1 # (int) video frame-rate stride
62
+ line_width: # (int, optional) line width of the bounding boxes, auto if missing
63
+ visualize: False # (bool) visualize model features
64
+ augment: False # (bool) apply image augmentation to prediction sources
65
+ agnostic_nms: False # (bool) class-agnostic NMS
66
+ classes: # (int | list[int], optional) filter results by class, i.e. class=0, or class=[0,2,3]
67
+ retina_masks: False # (bool) use high-resolution segmentation masks
68
+ boxes: True # (bool) Show boxes in segmentation predictions
69
+
70
+ # Export settings ------------------------------------------------------------------------------------------------------
71
+ format: torchscript # (str) format to export to, choices at https://docs.ultralytics.com/modes/export/#export-formats
72
+ keras: False # (bool) use Keras
73
+ optimize: False # (bool) TorchScript: optimize for mobile
74
+ int8: False # (bool) CoreML/TF INT8 quantization
75
+ dynamic: False # (bool) ONNX/TF/TensorRT: dynamic axes
76
+ simplify: False # (bool) ONNX: simplify model
77
+ opset: # (int, optional) ONNX: opset version
78
+ workspace: 4 # (int) TensorRT: workspace size (GB)
79
+ nms: False # (bool) CoreML: add NMS
80
+
81
+ # Hyperparameters ------------------------------------------------------------------------------------------------------
82
+ lr0: 0.01 # (float) initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
83
+ lrf: 0.01 # (float) final learning rate (lr0 * lrf)
84
+ momentum: 0.937 # (float) SGD momentum/Adam beta1
85
+ weight_decay: 0.0005 # (float) optimizer weight decay 5e-4
86
+ warmup_epochs: 3.0 # (float) warmup epochs (fractions ok)
87
+ warmup_momentum: 0.8 # (float) warmup initial momentum
88
+ warmup_bias_lr: 0.1 # (float) warmup initial bias lr
89
+ box: 7.5 # (float) box loss gain
90
+ cls: 0.5 # (float) cls loss gain (scale with pixels)
91
+ dfl: 1.5 # (float) dfl loss gain
92
+ pose: 12.0 # (float) pose loss gain
93
+ kobj: 1.0 # (float) keypoint obj loss gain
94
+ label_smoothing: 0.0 # (float) label smoothing (fraction)
95
+ nbs: 64 # (int) nominal batch size
96
+ hsv_h: 0.015 # (float) image HSV-Hue augmentation (fraction)
97
+ hsv_s: 0.7 # (float) image HSV-Saturation augmentation (fraction)
98
+ hsv_v: 0.4 # (float) image HSV-Value augmentation (fraction)
99
+ degrees: 0.0 # (float) image rotation (+/- deg)
100
+ translate: 0.1 # (float) image translation (+/- fraction)
101
+ scale: 0.5 # (float) image scale (+/- gain)
102
+ shear: 0.0 # (float) image shear (+/- deg)
103
+ perspective: 0.0 # (float) image perspective (+/- fraction), range 0-0.001
104
+ flipud: 0.0 # (float) image flip up-down (probability)
105
+ fliplr: 0.5 # (float) image flip left-right (probability)
106
+ mosaic: 1.0 # (float) image mosaic (probability)
107
+ mixup: 0.0 # (float) image mixup (probability)
108
+ copy_paste: 0.0 # (float) segment copy-paste (probability)
109
+
110
+ # Custom config.yaml ---------------------------------------------------------------------------------------------------
111
+ cfg: # (str, optional) for overriding defaults.yaml
112
+ save_dir: ./runs/train1 # set your own save path
113
+ # Tracker settings ------------------------------------------------------------------------------------------------------
114
+ tracker: botsort.yaml # (str) tracker type, choices=[botsort.yaml, bytetrack.yaml]
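Any of these defaults can be overridden per run; keys that are not overridden fall back to this file (note that save_dir appears to be an addition relative to the stock default.yaml). A sketch using the Python API with the COCO128 sample dataset:

    from ultralytics import YOLO

    # per-run overrides take precedence over default.yaml
    model = YOLO('yolov8n.pt')
    model.train(data='coco128.yaml', epochs=10, imgsz=640, batch=16)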
ultralytics/cfg/models/v8/yolov8.yaml ADDED
@@ -0,0 +1,46 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
3
+
4
+ # Parameters
5
+ nc: 1 # number of classes
6
+ scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
7
+ # [depth, width, max_channels]
8
+ n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
9
+ s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
10
+ m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
11
+ l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
12
+ x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
13
+
14
+ # YOLOv8.0n backbone
15
+ backbone:
16
+ # [from, repeats, module, args]
17
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19
+ - [-1, 3, C2f, [128, True]]
20
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21
+ - [-1, 6, C2f, [256, True]]
22
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23
+ - [-1, 6, C2f, [512, True]]
24
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
25
+ - [-1, 3, C2f, [1024, True]]
26
+ - [-1, 1, SPPF, [1024, 5]] # 9
27
+
28
+ # YOLOv8.0n head
29
+ head:
30
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
31
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
32
+ - [-1, 3, C2f, [512]] # 12
33
+
34
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
35
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
36
+ - [-1, 3, C2f, [256]] # 15 (P3/8-small)
37
+
38
+ - [-1, 1, Conv, [256, 3, 2]]
39
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
40
+ - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
41
+
42
+ - [-1, 1, Conv, [512, 3, 2]]
43
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
44
+ - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
45
+
46
+ - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)
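A detection model can be built straight from this config (note nc is set to 1 here rather than the stock 80). A sketch:

    from ultralytics import YOLO

    # build the n-scale architecture from the YAML (randomly initialized) ...
    model = YOLO('yolov8n.yaml')
    # ... or build it and transfer matching pretrained weights
    model = YOLO('yolov8n.yaml').load('yolov8n.pt')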
ultralytics/cfg/models/v8/yolov8_ECA.yaml ADDED
@@ -0,0 +1,50 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
3
+
4
+ # Parameters
5
+ nc: 9 # number of classes
6
+ scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
7
+ # [depth, width, max_channels]
8
+ n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
9
+ s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
10
+ m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
11
+ l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
12
+ x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
13
+
14
+ # YOLOv8.0n backbone
15
+ backbone:
16
+ # [from, repeats, module, args]
17
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19
+ - [-1, 3, C2f, [128, True]]
20
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21
+ - [-1, 6, C2f, [256, True]]
22
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23
+ - [-1, 6, C2f, [512, True]]
24
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
25
+ - [-1, 3, C2f, [1024, True]]
26
+ - [-1, 1, SPPF, [1024, 5]] # 9
27
+
28
+ # YOLOv8.0n head
29
+ head:
30
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
31
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
32
+ - [-1, 3, C2f, [512]] # 12
33
+ - [-1, 1, ECAAttention, [512]]
34
+
35
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
36
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
37
+ - [-1, 3, C2f, [256]] # 16 (P3/8-small)
38
+ - [-1, 1, ECAAttention, [256]]
39
+
40
+ - [-1, 1, Conv, [256, 3, 2]]
41
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
42
+ - [-1, 3, C2f, [512]] # 20 (P4/16-medium)
43
+ - [-1, 1, ECAAttention, [512]]
44
+
45
+ - [-1, 1, Conv, [512, 3, 2]]
46
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
47
+ - [-1, 3, C2f, [1024]] # 24 (P5/32-large)
48
+ - [-1, 1, ECAAttention, [1024]]
49
+
50
+ - [[17, 21, 25], 1, Detect, [nc]] # Detect(P3, P4, P5)
ultralytics/cfg/models/v8/yolov8_GAM.yaml ADDED
@@ -0,0 +1,50 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
3
+
4
+ # Parameters
5
+ nc: 9 # number of classes
6
+ scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
7
+ # [depth, width, max_channels]
8
+ n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
9
+ s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
10
+ m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
11
+ l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
12
+ x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
13
+
14
+ # YOLOv8.0n backbone
15
+ backbone:
16
+ # [from, repeats, module, args]
17
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19
+ - [-1, 3, C2f, [128, True]]
20
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21
+ - [-1, 6, C2f, [256, True]]
22
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23
+ - [-1, 6, C2f, [512, True]]
24
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
25
+ - [-1, 3, C2f, [1024, True]]
26
+ - [-1, 1, SPPF, [1024, 5]] # 9
27
+
28
+ # YOLOv8.0n head
29
+ head:
30
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
31
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
32
+ - [-1, 3, C2f, [512]] # 12
33
+ - [-1, 1, GAM_Attention, [512,512]]
34
+
35
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
36
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
37
+ - [-1, 3, C2f, [256]] # 16 (P3/8-small)
38
+ - [-1, 1, GAM_Attention, [256,256]]
39
+
40
+ - [-1, 1, Conv, [256, 3, 2]]
41
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
42
+ - [-1, 3, C2f, [512]] # 20 (P4/16-medium)
43
+ - [-1, 1, GAM_Attention, [512,512]]
44
+
45
+ - [-1, 1, Conv, [512, 3, 2]]
46
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
47
+ - [-1, 3, C2f, [1024]] # 24 (P5/32-large)
48
+ - [-1, 1, GAM_Attention, [1024,1024]]
49
+
50
+ - [[17, 21, 25], 1, Detect, [nc]] # Detect(P3, P4, P5)
ultralytics/cfg/models/v8/yolov8_ResBlock_CBAM.yaml ADDED
@@ -0,0 +1,50 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
3
+
4
+ # Parameters
5
+ nc: 9 # number of classes
6
+ scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
7
+ # [depth, width, max_channels]
8
+ n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
9
+ s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
10
+ m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
11
+ l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
12
+ x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
13
+
14
+ # YOLOv8.0n backbone
15
+ backbone:
16
+ # [from, repeats, module, args]
17
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19
+ - [-1, 3, C2f, [128, True]]
20
+ - [-1, 1, GhostConv, [256, 3, 2]] # 3-P3/8
21
+ - [-1, 6, C2f, [256, True]]
22
+ - [-1, 1, GhostConv, [512, 3, 2]] # 5-P4/16
23
+ - [-1, 6, C2f, [512, True]]
24
+ - [-1, 1, GhostConv, [1024, 3, 2]] # 7-P5/32
25
+ - [-1, 3, C2f, [1024, True]]
26
+ - [-1, 1, SPPF, [1024, 5]] # 9
27
+
28
+ # YOLOv8.0n head
29
+ head:
30
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
31
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
32
+ - [-1, 3, C2f, [512]] # 12
33
+ - [-1, 1, ResBlock_CBAM, [512]]
34
+
35
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
36
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
37
+ - [-1, 3, C2f, [256]] # 16 (P3/8-small)
38
+ - [-1, 1, ResBlock_CBAM, [256]]
39
+
40
+ - [-1, 1, Conv, [256, 3, 2]]
41
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
42
+ - [-1, 3, C2f, [512]] # 20 (P4/16-medium)
43
+ - [-1, 1, ResBlock_CBAM, [512]]
44
+
45
+ - [-1, 1, Conv, [512, 3, 2]]
46
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
47
+ - [-1, 3, C2f, [1024]] # 24 (P5/32-large)
48
+ - [-1, 1, ResBlock_CBAM, [1024]]
49
+
50
+ - [[17, 21, 25], 1, Detect, [nc]] # Detect(P3, P4, P5)
ultralytics/cfg/models/v8/yolov8_SA.yaml ADDED
@@ -0,0 +1,50 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
3
+
4
+ # Parameters
5
+ nc: 9 # number of classes
6
+ scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
7
+ # [depth, width, max_channels]
8
+ n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
9
+ s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
10
+ m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
11
+ l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
12
+ x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
13
+
14
+ # YOLOv8.0n backbone
15
+ backbone:
16
+ # [from, repeats, module, args]
17
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19
+ - [-1, 3, C2f, [128, True]]
20
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21
+ - [-1, 6, C2f, [256, True]]
22
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23
+ - [-1, 6, C2f, [512, True]]
24
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
25
+ - [-1, 3, C2f, [1024, True]]
26
+ - [-1, 1, SPPF, [1024, 5]] # 9
27
+
28
+ # YOLOv8.0n head
29
+ head:
30
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
31
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
32
+ - [-1, 3, C2f, [512]] # 12
33
+ - [-1, 1, ShuffleAttention, [512]]
34
+
35
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
36
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
37
+ - [-1, 3, C2f, [256]] # 16 (P3/8-small)
38
+ - [-1, 1, ShuffleAttention, [256]]
39
+
40
+ - [-1, 1, Conv, [256, 3, 2]]
41
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
42
+ - [-1, 3, C2f, [512]] # 20 (P4/16-medium)
43
+ - [-1, 1, ShuffleAttention, [512]]
44
+
45
+ - [-1, 1, Conv, [512, 3, 2]]
46
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
47
+ - [-1, 3, C2f, [1024]] # 24 (P5/32-large)
48
+ - [-1, 1, ShuffleAttention, [1024]]
49
+
50
+ - [[17, 21, 25], 1, Detect, [nc]] # Detect(P3, P4, P5)
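The four *_ECA, *_GAM, *_ResBlock_CBAM and *_SA configs above all follow the same pattern: an attention module (ECAAttention, GAM_Attention, ResBlock_CBAM or ShuffleAttention) is appended after each head C2f block, shifting the Detect inputs from [15, 18, 21] to [17, 21, 25]. These module names are not part of stock Ultralytics 8.0.147, so they presumably come from the modified nn/module code elsewhere in this upload (not visible in the 50-file view). A training sketch, assuming those modules are registered with the model parser and using a placeholder dataset file:

    from ultralytics import YOLO

    # build the ECA-augmented variant from the uploaded config and train it
    model = YOLO('ultralytics/cfg/models/v8/yolov8_ECA.yaml')
    model.train(data='my_dataset.yaml', epochs=100, imgsz=640)  # 'my_dataset.yaml' is a placeholder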
ultralytics/cfg/trackers/botsort.yaml ADDED
@@ -0,0 +1,18 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # Default YOLO tracker settings for BoT-SORT tracker https://github.com/NirAharon/BoT-SORT
3
+
4
+ tracker_type: botsort # tracker type, ['botsort', 'bytetrack']
5
+ track_high_thresh: 0.5 # threshold for the first association
6
+ track_low_thresh: 0.1 # threshold for the second association
7
+ new_track_thresh: 0.6 # threshold for init new track if the detection does not match any tracks
8
+ track_buffer: 30 # buffer to calculate the time when to remove tracks
9
+ match_thresh: 0.8 # threshold for matching tracks
10
+ # min_box_area: 10 # threshold for min box area (for tracker evaluation, not used for now)
11
+ # mot20: False # for tracker evaluation (not used for now)
12
+
13
+ # BoT-SORT settings
14
+ cmc_method: sparseOptFlow # method of global motion compensation
15
+ # ReID model related thresh (not supported yet)
16
+ proximity_thresh: 0.5
17
+ appearance_thresh: 0.25
18
+ with_reid: False
ultralytics/cfg/trackers/bytetrack.yaml ADDED
@@ -0,0 +1,11 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # Default YOLO tracker settings for ByteTrack tracker https://github.com/ifzhang/ByteTrack
3
+
4
+ tracker_type: bytetrack # tracker type, ['botsort', 'bytetrack']
5
+ track_high_thresh: 0.5 # threshold for the first association
6
+ track_low_thresh: 0.1 # threshold for the second association
7
+ new_track_thresh: 0.6 # threshold for init new track if the detection does not match any tracks
8
+ track_buffer: 30 # buffer to calculate the time when to remove tracks
9
+ match_thresh: 0.8 # threshold for matching tracks
10
+ # min_box_area: 10 # threshold for min box area (for tracker evaluation, not used for now)
11
+ # mot20: False # for tracker evaluation (not used for now)
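Either tracker config can be selected at predict time via the tracker argument (default.yaml above sets botsort.yaml as the default). A minimal sketch with a placeholder video source:

    from ultralytics import YOLO

    model = YOLO('yolov8n.pt')
    # run multi-object tracking; 'video.mp4' is a placeholder path
    results = model.track(source='video.mp4', tracker='bytetrack.yaml', conf=0.25)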
ultralytics/data/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ from .base import BaseDataset
4
+ from .build import build_dataloader, build_yolo_dataset, load_inference_source
5
+ from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
6
+
7
+ __all__ = ('BaseDataset', 'ClassificationDataset', 'SemanticDataset', 'YOLODataset', 'build_yolo_dataset',
8
+ 'build_dataloader', 'load_inference_source')
ultralytics/data/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (473 Bytes).
ultralytics/data/__pycache__/augment.cpython-39.pyc ADDED
Binary file (31.6 kB).
ultralytics/data/__pycache__/base.cpython-39.pyc ADDED
Binary file (11.3 kB).
ultralytics/data/__pycache__/build.cpython-39.pyc ADDED
Binary file (6.2 kB).
ultralytics/data/__pycache__/dataset.cpython-39.pyc ADDED
Binary file (11.3 kB).
ultralytics/data/__pycache__/loaders.cpython-39.pyc ADDED
Binary file (15.7 kB).
ultralytics/data/__pycache__/utils.cpython-39.pyc ADDED
Binary file (24.1 kB).
ultralytics/data/annotator.py ADDED
@@ -0,0 +1,39 @@
1
+ from pathlib import Path
2
+
3
+ from ultralytics import SAM, YOLO
4
+
5
+
6
+ def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir=None):
7
+ """
8
+ Automatically annotates images using a YOLO object detection model and a SAM segmentation model.
9
+ Args:
10
+ data (str): Path to a folder containing images to be annotated.
11
+ det_model (str, optional): Pre-trained YOLO detection model. Defaults to 'yolov8x.pt'.
12
+ sam_model (str, optional): Pre-trained SAM segmentation model. Defaults to 'sam_b.pt'.
13
+ device (str, optional): Device to run the models on. Defaults to an empty string (CPU or GPU, if available).
14
+ output_dir (str | None | optional): Directory to save the annotated results.
15
+ Defaults to a 'labels' folder in the same directory as 'data'.
16
+ """
17
+ det_model = YOLO(det_model)
18
+ sam_model = SAM(sam_model)
19
+
20
+ if not output_dir:
21
+ output_dir = Path(str(data)).parent / 'labels'
22
+ Path(output_dir).mkdir(exist_ok=True, parents=True)
23
+
24
+ det_results = det_model(data, stream=True, device=device)
25
+
26
+ for result in det_results:
27
+ boxes = result.boxes.xyxy # Boxes object for bbox outputs
28
+ class_ids = result.boxes.cls.int().tolist() # noqa
29
+ if len(class_ids):
30
+ sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device=device)
31
+ segments = sam_results[0].masks.xyn # noqa
32
+
33
+ with open(str(Path(output_dir) / Path(result.path).stem) + '.txt', 'w') as f:
34
+ for i in range(len(segments)):
35
+ s = segments[i]
36
+ if len(s) == 0:
37
+ continue
38
+ segment = map(str, segments[i].reshape(-1).tolist())
39
+ f.write(f'{class_ids[i]} ' + ' '.join(segment) + '\n')
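A usage sketch for the helper above; both checkpoints would typically be auto-downloaded on first use, and 'path/to/images' is a placeholder:

    from ultralytics.data.annotator import auto_annotate

    # writes one YOLO-format segmentation label .txt per image into a sibling 'labels' folder
    auto_annotate(data='path/to/images', det_model='yolov8x.pt', sam_model='sam_b.pt')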
ultralytics/data/augment.py ADDED
@@ -0,0 +1,906 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import math
4
+ import random
5
+ from copy import deepcopy
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import torch
10
+ import torchvision.transforms as T
11
+
12
+ from ultralytics.utils import LOGGER, colorstr
13
+ from ultralytics.utils.checks import check_version
14
+ from ultralytics.utils.instance import Instances
15
+ from ultralytics.utils.metrics import bbox_ioa
16
+ from ultralytics.utils.ops import segment2box
17
+
18
+ from .utils import polygons2masks, polygons2masks_overlap
19
+
20
+ POSE_FLIPLR_INDEX = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
21
+
22
+
23
+ # TODO: we might need a BaseTransform to make all these augments be compatible with both classification and semantic
24
+ class BaseTransform:
25
+
26
+ def __init__(self) -> None:
27
+ pass
28
+
29
+ def apply_image(self, labels):
30
+ """Applies image transformation to labels."""
31
+ pass
32
+
33
+ def apply_instances(self, labels):
34
+ """Applies transformations to input 'labels' and returns object instances."""
35
+ pass
36
+
37
+ def apply_semantic(self, labels):
38
+ """Applies semantic segmentation to an image."""
39
+ pass
40
+
41
+ def __call__(self, labels):
42
+ """Applies label transformations to an image, instances and semantic masks."""
43
+ self.apply_image(labels)
44
+ self.apply_instances(labels)
45
+ self.apply_semantic(labels)
46
+
47
+
48
+ class Compose:
49
+
50
+ def __init__(self, transforms):
51
+ """Initializes the Compose object with a list of transforms."""
52
+ self.transforms = transforms
53
+
54
+ def __call__(self, data):
55
+ """Applies a series of transformations to input data."""
56
+ for t in self.transforms:
57
+ data = t(data)
58
+ return data
59
+
60
+ def append(self, transform):
61
+ """Appends a new transform to the existing list of transforms."""
62
+ self.transforms.append(transform)
63
+
64
+ def tolist(self):
65
+ """Converts list of transforms to a standard Python list."""
66
+ return self.transforms
67
+
68
+ def __repr__(self):
69
+ """Return string representation of object."""
70
+ format_string = f'{self.__class__.__name__}('
71
+ for t in self.transforms:
72
+ format_string += '\n'
73
+ format_string += f' {t}'
74
+ format_string += '\n)'
75
+ return format_string
76
+
77
+
78
+ class BaseMixTransform:
79
+ """This implementation is from mmyolo."""
80
+
81
+ def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
82
+ self.dataset = dataset
83
+ self.pre_transform = pre_transform
84
+ self.p = p
85
+
86
+ def __call__(self, labels):
87
+ """Applies pre-processing transforms and mixup/mosaic transforms to labels data."""
88
+ if random.uniform(0, 1) > self.p:
89
+ return labels
90
+
91
+ # Get index of one or three other images
92
+ indexes = self.get_indexes()
93
+ if isinstance(indexes, int):
94
+ indexes = [indexes]
95
+
96
+ # Get images information will be used for Mosaic or MixUp
97
+ mix_labels = [self.dataset.get_image_and_label(i) for i in indexes]
98
+
99
+ if self.pre_transform is not None:
100
+ for i, data in enumerate(mix_labels):
101
+ mix_labels[i] = self.pre_transform(data)
102
+ labels['mix_labels'] = mix_labels
103
+
104
+ # Mosaic or MixUp
105
+ labels = self._mix_transform(labels)
106
+ labels.pop('mix_labels', None)
107
+ return labels
108
+
109
+ def _mix_transform(self, labels):
110
+ """Applies MixUp or Mosaic augmentation to the label dictionary."""
111
+ raise NotImplementedError
112
+
113
+ def get_indexes(self):
114
+ """Gets a list of shuffled indexes for mosaic augmentation."""
115
+ raise NotImplementedError
116
+
117
+
118
+ class Mosaic(BaseMixTransform):
119
+ """
120
+ Mosaic augmentation.
121
+
122
+ This class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image.
123
+ The augmentation is applied to a dataset with a given probability.
124
+
125
+ Attributes:
126
+ dataset: The dataset on which the mosaic augmentation is applied.
127
+ imgsz (int, optional): Image size (height and width) after mosaic pipeline of a single image. Default to 640.
128
+ p (float, optional): Probability of applying the mosaic augmentation. Must be in the range 0-1. Default to 1.0.
129
+ n (int, optional): The grid size, either 4 (for 2x2) or 9 (for 3x3).
130
+ """
131
+
132
+ def __init__(self, dataset, imgsz=640, p=1.0, n=4):
133
+ """Initializes the object with a dataset, image size, probability, and border."""
134
+ assert 0 <= p <= 1.0, f'The probability should be in range [0, 1], but got {p}.'
135
+ assert n in (4, 9), 'grid must be equal to 4 or 9.'
136
+ super().__init__(dataset=dataset, p=p)
137
+ self.dataset = dataset
138
+ self.imgsz = imgsz
139
+ self.border = (-imgsz // 2, -imgsz // 2) # width, height
140
+ self.n = n
141
+
142
+ def get_indexes(self, buffer=True):
143
+ """Return a list of random indexes from the dataset."""
144
+ if buffer: # select images from buffer
145
+ return random.choices(list(self.dataset.buffer), k=self.n - 1)
146
+ else: # select any images
147
+ return [random.randint(0, len(self.dataset) - 1) for _ in range(self.n - 1)]
148
+
149
+ def _mix_transform(self, labels):
150
+ """Apply mixup transformation to the input image and labels."""
151
+ assert labels.get('rect_shape', None) is None, 'rect and mosaic are mutually exclusive.'
152
+ assert len(labels.get('mix_labels', [])), 'There are no other images for mosaic augment.'
153
+ return self._mosaic4(labels) if self.n == 4 else self._mosaic9(labels)
154
+
155
+ def _mosaic4(self, labels):
156
+ """Create a 2x2 image mosaic."""
157
+ mosaic_labels = []
158
+ s = self.imgsz
159
+ yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border) # mosaic center x, y
160
+ for i in range(4):
161
+ labels_patch = labels if i == 0 else labels['mix_labels'][i - 1]
162
+ # Load image
163
+ img = labels_patch['img']
164
+ h, w = labels_patch.pop('resized_shape')
165
+
166
+ # Place img in img4
167
+ if i == 0: # top left
168
+ img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
169
+ x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
170
+ x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
171
+ elif i == 1: # top right
172
+ x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
173
+ x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
174
+ elif i == 2: # bottom left
175
+ x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
176
+ x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
177
+ elif i == 3: # bottom right
178
+ x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
179
+ x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
180
+
181
+ img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
182
+ padw = x1a - x1b
183
+ padh = y1a - y1b
184
+
185
+ labels_patch = self._update_labels(labels_patch, padw, padh)
186
+ mosaic_labels.append(labels_patch)
187
+ final_labels = self._cat_labels(mosaic_labels)
188
+ final_labels['img'] = img4
189
+ return final_labels
190
+
191
+ def _mosaic9(self, labels):
192
+ """Create a 3x3 image mosaic."""
193
+ mosaic_labels = []
194
+ s = self.imgsz
195
+ hp, wp = -1, -1 # height, width previous
196
+ for i in range(9):
197
+ labels_patch = labels if i == 0 else labels['mix_labels'][i - 1]
198
+ # Load image
199
+ img = labels_patch['img']
200
+ h, w = labels_patch.pop('resized_shape')
201
+
202
+ # Place img in img9
203
+ if i == 0: # center
204
+ img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 9 tiles
205
+ h0, w0 = h, w
206
+ c = s, s, s + w, s + h # xmin, ymin, xmax, ymax (base) coordinates
207
+ elif i == 1: # top
208
+ c = s, s - h, s + w, s
209
+ elif i == 2: # top right
210
+ c = s + wp, s - h, s + wp + w, s
211
+ elif i == 3: # right
212
+ c = s + w0, s, s + w0 + w, s + h
213
+ elif i == 4: # bottom right
214
+ c = s + w0, s + hp, s + w0 + w, s + hp + h
215
+ elif i == 5: # bottom
216
+ c = s + w0 - w, s + h0, s + w0, s + h0 + h
217
+ elif i == 6: # bottom left
218
+ c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
219
+ elif i == 7: # left
220
+ c = s - w, s + h0 - h, s, s + h0
221
+ elif i == 8: # top left
222
+ c = s - w, s + h0 - hp - h, s, s + h0 - hp
223
+
224
+ padw, padh = c[:2]
225
+ x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords
226
+
227
+ # Image
228
+ img9[y1:y2, x1:x2] = img[y1 - padh:, x1 - padw:] # img9[ymin:ymax, xmin:xmax]
229
+ hp, wp = h, w # height, width previous for next iteration
230
+
231
+ # Labels assuming imgsz*2 mosaic size
232
+ labels_patch = self._update_labels(labels_patch, padw + self.border[0], padh + self.border[1])
233
+ mosaic_labels.append(labels_patch)
234
+ final_labels = self._cat_labels(mosaic_labels)
235
+
236
+ final_labels['img'] = img9[-self.border[0]:self.border[0], -self.border[1]:self.border[1]]
237
+ return final_labels
238
+
239
+ @staticmethod
240
+ def _update_labels(labels, padw, padh):
241
+ """Update labels."""
242
+ nh, nw = labels['img'].shape[:2]
243
+ labels['instances'].convert_bbox(format='xyxy')
244
+ labels['instances'].denormalize(nw, nh)
245
+ labels['instances'].add_padding(padw, padh)
246
+ return labels
247
+
248
+ def _cat_labels(self, mosaic_labels):
249
+ """Return labels with mosaic border instances clipped."""
250
+ if len(mosaic_labels) == 0:
251
+ return {}
252
+ cls = []
253
+ instances = []
254
+ imgsz = self.imgsz * 2 # mosaic imgsz
255
+ for labels in mosaic_labels:
256
+ cls.append(labels['cls'])
257
+ instances.append(labels['instances'])
258
+ final_labels = {
259
+ 'im_file': mosaic_labels[0]['im_file'],
260
+ 'ori_shape': mosaic_labels[0]['ori_shape'],
261
+ 'resized_shape': (imgsz, imgsz),
262
+ 'cls': np.concatenate(cls, 0),
263
+ 'instances': Instances.concatenate(instances, axis=0),
264
+ 'mosaic_border': self.border} # final_labels
265
+ final_labels['instances'].clip(imgsz, imgsz)
266
+ good = final_labels['instances'].remove_zero_area_boxes()
267
+ final_labels['cls'] = final_labels['cls'][good]
268
+ return final_labels
269
+
270
+
271
+ class MixUp(BaseMixTransform):
272
+
273
+ def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
274
+ super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)
275
+
276
+ def get_indexes(self):
277
+ """Get a random index from the dataset."""
278
+ return random.randint(0, len(self.dataset) - 1)
279
+
280
+ def _mix_transform(self, labels):
281
+ """Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf."""
282
+ r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
283
+ labels2 = labels['mix_labels'][0]
284
+ labels['img'] = (labels['img'] * r + labels2['img'] * (1 - r)).astype(np.uint8)
285
+ labels['instances'] = Instances.concatenate([labels['instances'], labels2['instances']], axis=0)
286
+ labels['cls'] = np.concatenate([labels['cls'], labels2['cls']], 0)
287
+ return labels
288
+
289
+
290
+ class RandomPerspective:
291
+
292
+ def __init__(self,
293
+ degrees=0.0,
294
+ translate=0.1,
295
+ scale=0.5,
296
+ shear=0.0,
297
+ perspective=0.0,
298
+ border=(0, 0),
299
+ pre_transform=None):
300
+ self.degrees = degrees
301
+ self.translate = translate
302
+ self.scale = scale
303
+ self.shear = shear
304
+ self.perspective = perspective
305
+ # Mosaic border
306
+ self.border = border
307
+ self.pre_transform = pre_transform
308
+
309
+ def affine_transform(self, img, border):
310
+ """Center."""
311
+ C = np.eye(3, dtype=np.float32)
312
+
313
+ C[0, 2] = -img.shape[1] / 2 # x translation (pixels)
314
+ C[1, 2] = -img.shape[0] / 2 # y translation (pixels)
315
+
316
+ # Perspective
317
+ P = np.eye(3, dtype=np.float32)
318
+ P[2, 0] = random.uniform(-self.perspective, self.perspective) # x perspective (about y)
319
+ P[2, 1] = random.uniform(-self.perspective, self.perspective) # y perspective (about x)
320
+
321
+ # Rotation and Scale
322
+ R = np.eye(3, dtype=np.float32)
323
+ a = random.uniform(-self.degrees, self.degrees)
324
+ # a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations
325
+ s = random.uniform(1 - self.scale, 1 + self.scale)
326
+ # s = 2 ** random.uniform(-scale, scale)
327
+ R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
328
+
329
+ # Shear
330
+ S = np.eye(3, dtype=np.float32)
331
+ S[0, 1] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180) # x shear (deg)
332
+ S[1, 0] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180) # y shear (deg)
333
+
334
+ # Translation
335
+ T = np.eye(3, dtype=np.float32)
336
+ T[0, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[0] # x translation (pixels)
337
+ T[1, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[1] # y translation (pixels)
338
+
339
+ # Combined rotation matrix
340
+ M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT
341
+ # Affine image
342
+ if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed
343
+ if self.perspective:
344
+ img = cv2.warpPerspective(img, M, dsize=self.size, borderValue=(114, 114, 114))
345
+ else: # affine
346
+ img = cv2.warpAffine(img, M[:2], dsize=self.size, borderValue=(114, 114, 114))
347
+ return img, M, s
348
+
349
+ def apply_bboxes(self, bboxes, M):
350
+ """
351
+ Apply affine to bboxes only.
352
+
353
+ Args:
354
+ bboxes (ndarray): list of bboxes, xyxy format, with shape (num_bboxes, 4).
355
+ M (ndarray): affine matrix.
356
+
357
+ Returns:
358
+ new_bboxes (ndarray): bboxes after affine, [num_bboxes, 4].
359
+ """
360
+ n = len(bboxes)
361
+ if n == 0:
362
+ return bboxes
363
+
364
+ xy = np.ones((n * 4, 3), dtype=bboxes.dtype)
365
+ xy[:, :2] = bboxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1
366
+ xy = xy @ M.T # transform
367
+ xy = (xy[:, :2] / xy[:, 2:3] if self.perspective else xy[:, :2]).reshape(n, 8) # perspective rescale or affine
368
+
369
+ # Create new boxes
370
+ x = xy[:, [0, 2, 4, 6]]
371
+ y = xy[:, [1, 3, 5, 7]]
372
+ return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1)), dtype=bboxes.dtype).reshape(4, n).T
373
+
374
+ def apply_segments(self, segments, M):
375
+ """
376
+ Apply affine to segments and generate new bboxes from segments.
377
+
378
+ Args:
379
+ segments (ndarray): list of segments, [num_samples, 500, 2].
380
+ M (ndarray): affine matrix.
381
+
382
+ Returns:
383
+ new_segments (ndarray): list of segments after affine, [num_samples, 500, 2].
384
+ new_bboxes (ndarray): bboxes after affine, [N, 4].
385
+ """
386
+ n, num = segments.shape[:2]
387
+ if n == 0:
388
+ return [], segments
389
+
390
+ xy = np.ones((n * num, 3), dtype=segments.dtype)
391
+ segments = segments.reshape(-1, 2)
392
+ xy[:, :2] = segments
393
+ xy = xy @ M.T # transform
394
+ xy = xy[:, :2] / xy[:, 2:3]
395
+ segments = xy.reshape(n, -1, 2)
396
+ bboxes = np.stack([segment2box(xy, self.size[0], self.size[1]) for xy in segments], 0)
397
+ return bboxes, segments
398
+
399
+ def apply_keypoints(self, keypoints, M):
400
+ """
401
+ Apply affine to keypoints.
402
+
403
+ Args:
404
+ keypoints (ndarray): keypoints, [N, 17, 3].
405
+ M (ndarray): affine matrix.
406
+
407
+ Returns:
408
+ new_keypoints (ndarray): keypoints after affine, [N, 17, 3].
409
+ """
410
+ n, nkpt = keypoints.shape[:2]
411
+ if n == 0:
412
+ return keypoints
413
+ xy = np.ones((n * nkpt, 3), dtype=keypoints.dtype)
414
+ visible = keypoints[..., 2].reshape(n * nkpt, 1)
415
+ xy[:, :2] = keypoints[..., :2].reshape(n * nkpt, 2)
416
+ xy = xy @ M.T # transform
417
+ xy = xy[:, :2] / xy[:, 2:3] # perspective rescale or affine
418
+ out_mask = (xy[:, 0] < 0) | (xy[:, 1] < 0) | (xy[:, 0] > self.size[0]) | (xy[:, 1] > self.size[1])
419
+ visible[out_mask] = 0
420
+ return np.concatenate([xy, visible], axis=-1).reshape(n, nkpt, 3)
421
+
422
+ def __call__(self, labels):
423
+ """
424
+ Affine images and targets.
425
+
426
+ Args:
427
+ labels (dict): a dict of `bboxes`, `segments`, `keypoints`.
428
+ """
429
+ if self.pre_transform and 'mosaic_border' not in labels:
430
+ labels = self.pre_transform(labels)
431
+ labels.pop('ratio_pad', None) # do not need ratio pad
432
+
433
+ img = labels['img']
434
+ cls = labels['cls']
435
+ instances = labels.pop('instances')
436
+ # Make sure the coord formats are right
437
+ instances.convert_bbox(format='xyxy')
438
+ instances.denormalize(*img.shape[:2][::-1])
439
+
440
+ border = labels.pop('mosaic_border', self.border)
441
+ self.size = img.shape[1] + border[1] * 2, img.shape[0] + border[0] * 2 # w, h
442
+ # M is affine matrix
443
+ # scale for func:`box_candidates`
444
+ img, M, scale = self.affine_transform(img, border)
445
+
446
+ bboxes = self.apply_bboxes(instances.bboxes, M)
447
+
448
+ segments = instances.segments
449
+ keypoints = instances.keypoints
450
+ # Update bboxes if there are segments.
451
+ if len(segments):
452
+ bboxes, segments = self.apply_segments(segments, M)
453
+
454
+ if keypoints is not None:
455
+ keypoints = self.apply_keypoints(keypoints, M)
456
+ new_instances = Instances(bboxes, segments, keypoints, bbox_format='xyxy', normalized=False)
457
+ # Clip
458
+ new_instances.clip(*self.size)
459
+
460
+ # Filter instances
461
+ instances.scale(scale_w=scale, scale_h=scale, bbox_only=True)
462
+ # Make the bboxes have the same scale with new_bboxes
463
+ i = self.box_candidates(box1=instances.bboxes.T,
464
+ box2=new_instances.bboxes.T,
465
+ area_thr=0.01 if len(segments) else 0.10)
466
+ labels['instances'] = new_instances[i]
467
+ labels['cls'] = cls[i]
468
+ labels['img'] = img
469
+ labels['resized_shape'] = img.shape[:2]
470
+ return labels
471
+
472
+ def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
473
+ # Compute box candidates: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
474
+ w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
475
+ w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
476
+ ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
477
+ return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
478
+
479
+
480
+ class RandomHSV:
481
+
482
+ def __init__(self, hgain=0.5, sgain=0.5, vgain=0.5) -> None:
483
+ self.hgain = hgain
484
+ self.sgain = sgain
485
+ self.vgain = vgain
486
+
487
+ def __call__(self, labels):
488
+ """Applies random horizontal or vertical flip to an image with a given probability."""
489
+ img = labels['img']
490
+ if self.hgain or self.sgain or self.vgain:
491
+ r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1 # random gains
492
+ hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
493
+ dtype = img.dtype # uint8
494
+
495
+ x = np.arange(0, 256, dtype=r.dtype)
496
+ lut_hue = ((x * r[0]) % 180).astype(dtype)
497
+ lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
498
+ lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
499
+
500
+ im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
501
+ cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=img) # no return needed
502
+ return labels
503
+
504
+
505
+ class RandomFlip:
506
+
507
+ def __init__(self, p=0.5, direction='horizontal', flip_idx=None) -> None:
508
+ assert direction in ['horizontal', 'vertical'], f'Support direction `horizontal` or `vertical`, got {direction}'
509
+ assert 0 <= p <= 1.0
510
+
511
+ self.p = p
512
+ self.direction = direction
513
+ self.flip_idx = flip_idx
514
+
515
+ def __call__(self, labels):
516
+ """Resize image and padding for detection, instance segmentation, pose."""
517
+ img = labels['img']
518
+ instances = labels.pop('instances')
519
+ instances.convert_bbox(format='xywh')
520
+ h, w = img.shape[:2]
521
+ h = 1 if instances.normalized else h
522
+ w = 1 if instances.normalized else w
523
+
524
+ # Flip up-down
525
+ if self.direction == 'vertical' and random.random() < self.p:
526
+ img = np.flipud(img)
527
+ instances.flipud(h)
528
+ if self.direction == 'horizontal' and random.random() < self.p:
529
+ img = np.fliplr(img)
530
+ instances.fliplr(w)
531
+ # For keypoints
532
+ if self.flip_idx is not None and instances.keypoints is not None:
533
+ instances.keypoints = np.ascontiguousarray(instances.keypoints[:, self.flip_idx, :])
534
+ labels['img'] = np.ascontiguousarray(img)
535
+ labels['instances'] = instances
536
+ return labels
537
+
538
+
539
+ class LetterBox:
540
+ """Resize image and padding for detection, instance segmentation, pose."""
541
+
542
+ def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, center=True, stride=32):
543
+ """Initialize LetterBox object with specific parameters."""
544
+ self.new_shape = new_shape
545
+ self.auto = auto
546
+ self.scaleFill = scaleFill
547
+ self.scaleup = scaleup
548
+ self.stride = stride
549
+ self.center = center # Put the image in the middle or top-left
550
+
551
+ def __call__(self, labels=None, image=None):
552
+ """Return updated labels and image with added border."""
553
+ if labels is None:
554
+ labels = {}
555
+ img = labels.get('img') if image is None else image
556
+ shape = img.shape[:2] # current shape [height, width]
557
+ new_shape = labels.pop('rect_shape', self.new_shape)
558
+ if isinstance(new_shape, int):
559
+ new_shape = (new_shape, new_shape)
560
+
561
+ # Scale ratio (new / old)
562
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
563
+ if not self.scaleup: # only scale down, do not scale up (for better val mAP)
564
+ r = min(r, 1.0)
565
+
566
+ # Compute padding
567
+ ratio = r, r # width, height ratios
568
+ new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
569
+ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
570
+ if self.auto: # minimum rectangle
571
+ dw, dh = np.mod(dw, self.stride), np.mod(dh, self.stride) # wh padding
572
+ elif self.scaleFill: # stretch
573
+ dw, dh = 0.0, 0.0
574
+ new_unpad = (new_shape[1], new_shape[0])
575
+ ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
576
+
577
+ if self.center:
578
+ dw /= 2 # divide padding into 2 sides
579
+ dh /= 2
580
+ if labels.get('ratio_pad'):
581
+ labels['ratio_pad'] = (labels['ratio_pad'], (dw, dh)) # for evaluation
582
+
583
+ if shape[::-1] != new_unpad: # resize
584
+ img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
585
+ top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
586
+ left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
587
+ img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT,
588
+ value=(114, 114, 114)) # add border
589
+
590
+ if len(labels):
591
+ labels = self._update_labels(labels, ratio, dw, dh)
592
+ labels['img'] = img
593
+ labels['resized_shape'] = new_shape
594
+ return labels
595
+ else:
596
+ return img
597
+
598
+ def _update_labels(self, labels, ratio, padw, padh):
599
+ """Update labels."""
600
+ labels['instances'].convert_bbox(format='xyxy')
601
+ labels['instances'].denormalize(*labels['img'].shape[:2][::-1])
602
+ labels['instances'].scale(*ratio)
603
+ labels['instances'].add_padding(padw, padh)
604
+ return labels
605
+
606
+
607
+ class CopyPaste:
608
+
609
+ def __init__(self, p=0.5) -> None:
610
+ self.p = p
611
+
612
+ def __call__(self, labels):
613
+ """Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)."""
614
+ im = labels['img']
615
+ cls = labels['cls']
616
+ h, w = im.shape[:2]
617
+ instances = labels.pop('instances')
618
+ instances.convert_bbox(format='xyxy')
619
+ instances.denormalize(w, h)
620
+ if self.p and len(instances.segments):
621
+ n = len(instances)
622
+ _, w, _ = im.shape # height, width, channels
623
+ im_new = np.zeros(im.shape, np.uint8)
624
+
625
+ # Calculate ioa first then select indexes randomly
626
+ ins_flip = deepcopy(instances)
627
+ ins_flip.fliplr(w)
628
+
629
+ ioa = bbox_ioa(ins_flip.bboxes, instances.bboxes) # intersection over area, (N, M)
630
+ indexes = np.nonzero((ioa < 0.30).all(1))[0] # (N, )
631
+ n = len(indexes)
632
+ for j in random.sample(list(indexes), k=round(self.p * n)):
633
+ cls = np.concatenate((cls, cls[[j]]), axis=0)
634
+ instances = Instances.concatenate((instances, ins_flip[[j]]), axis=0)
635
+ cv2.drawContours(im_new, instances.segments[[j]].astype(np.int32), -1, (1, 1, 1), cv2.FILLED)
636
+
637
+ result = cv2.flip(im, 1) # augment segments (flip left-right)
638
+ i = cv2.flip(im_new, 1).astype(bool)
639
+ im[i] = result[i] # cv2.imwrite('debug.jpg', im) # debug
640
+
641
+ labels['img'] = im
642
+ labels['cls'] = cls
643
+ labels['instances'] = instances
644
+ return labels
645
+
646
+
647
+ class Albumentations:
648
+ """YOLOv8 Albumentations class (optional, only used if package is installed)"""
649
+
650
+ def __init__(self, p=1.0):
651
+ """Initialize the transform object for YOLO bbox formatted params."""
652
+ self.p = p
653
+ self.transform = None
654
+ prefix = colorstr('albumentations: ')
655
+ try:
656
+ import albumentations as A
657
+
658
+ check_version(A.__version__, '1.0.3', hard=True) # version requirement
659
+
660
+ T = [
661
+ A.Blur(p=0.01),
662
+ A.MedianBlur(p=0.01),
663
+ A.ToGray(p=0.01),
664
+ A.CLAHE(p=0.01),
665
+ A.RandomBrightnessContrast(p=0.0),
666
+ A.RandomGamma(p=0.0),
667
+ A.ImageCompression(quality_lower=75, p=0.0)] # transforms
668
+ self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
669
+
670
+ LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
671
+ except ImportError: # package not installed, skip
672
+ pass
673
+ except Exception as e:
674
+ LOGGER.info(f'{prefix}{e}')
675
+
676
+ def __call__(self, labels):
677
+ """Generates object detections and returns a dictionary with detection results."""
678
+ im = labels['img']
679
+ cls = labels['cls']
680
+ if len(cls):
681
+ labels['instances'].convert_bbox('xywh')
682
+ labels['instances'].normalize(*im.shape[:2][::-1])
683
+ bboxes = labels['instances'].bboxes
684
+ # TODO: add supports of segments and keypoints
685
+ if self.transform and random.random() < self.p:
686
+ new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed
687
+ if len(new['class_labels']) > 0: # skip update if no bbox in new im
688
+ labels['img'] = new['image']
689
+ labels['cls'] = np.array(new['class_labels'])
690
+ bboxes = np.array(new['bboxes'], dtype=np.float32)
691
+ labels['instances'].update(bboxes=bboxes)
692
+ return labels
693
+
694
+
695
+ # TODO: technically this is not an augmentation, maybe we should put this to another files
696
+ class Format:
697
+
698
+ def __init__(self,
699
+ bbox_format='xywh',
700
+ normalize=True,
701
+ return_mask=False,
702
+ return_keypoint=False,
703
+ mask_ratio=4,
704
+ mask_overlap=True,
705
+ batch_idx=True):
706
+ self.bbox_format = bbox_format
707
+ self.normalize = normalize
708
+ self.return_mask = return_mask # set False when training detection only
709
+ self.return_keypoint = return_keypoint
710
+ self.mask_ratio = mask_ratio
711
+ self.mask_overlap = mask_overlap
712
+ self.batch_idx = batch_idx # keep the batch indexes
713
+
714
+ def __call__(self, labels):
715
+ """Return formatted image, classes, bounding boxes & keypoints to be used by 'collate_fn'."""
716
+ img = labels.pop('img')
717
+ h, w = img.shape[:2]
718
+ cls = labels.pop('cls')
719
+ instances = labels.pop('instances')
720
+ instances.convert_bbox(format=self.bbox_format)
721
+ instances.denormalize(w, h)
722
+ nl = len(instances)
723
+
724
+ if self.return_mask:
725
+ if nl:
726
+ masks, instances, cls = self._format_segments(instances, cls, w, h)
727
+ masks = torch.from_numpy(masks)
728
+ else:
729
+ masks = torch.zeros(1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio,
730
+ img.shape[1] // self.mask_ratio)
731
+ labels['masks'] = masks
732
+ if self.normalize:
733
+ instances.normalize(w, h)
734
+ labels['img'] = self._format_img(img)
735
+ labels['cls'] = torch.from_numpy(cls) if nl else torch.zeros(nl)
736
+ labels['bboxes'] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
737
+ if self.return_keypoint:
738
+ labels['keypoints'] = torch.from_numpy(instances.keypoints)
739
+ # Then we can use collate_fn
740
+ if self.batch_idx:
741
+ labels['batch_idx'] = torch.zeros(nl)
742
+ return labels
743
+
744
+ def _format_img(self, img):
745
+ """Format the image for YOLOv5 from Numpy array to PyTorch tensor."""
746
+ if len(img.shape) < 3:
747
+ img = np.expand_dims(img, -1)
748
+ img = np.ascontiguousarray(img.transpose(2, 0, 1)[::-1])
749
+ img = torch.from_numpy(img)
750
+ return img
751
+
752
+ def _format_segments(self, instances, cls, w, h):
753
+ """convert polygon points to bitmap."""
754
+ segments = instances.segments
755
+ if self.mask_overlap:
756
+ masks, sorted_idx = polygons2masks_overlap((h, w), segments, downsample_ratio=self.mask_ratio)
757
+ masks = masks[None] # (640, 640) -> (1, 640, 640)
758
+ instances = instances[sorted_idx]
759
+ cls = cls[sorted_idx]
760
+ else:
761
+ masks = polygons2masks((h, w), segments, color=1, downsample_ratio=self.mask_ratio)
762
+
763
+ return masks, instances, cls
764
+
765
+
766
+ def v8_transforms(dataset, imgsz, hyp, stretch=False):
767
+ """Convert images to a size suitable for YOLOv8 training."""
768
+ pre_transform = Compose([
769
+ Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic),
770
+ CopyPaste(p=hyp.copy_paste),
771
+ RandomPerspective(
772
+ degrees=hyp.degrees,
773
+ translate=hyp.translate,
774
+ scale=hyp.scale,
775
+ shear=hyp.shear,
776
+ perspective=hyp.perspective,
777
+ pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)),
778
+ )])
779
+ flip_idx = dataset.data.get('flip_idx', []) # for keypoints augmentation
780
+ if dataset.use_keypoints:
781
+ kpt_shape = dataset.data.get('kpt_shape', None)
782
+ if len(flip_idx) == 0 and hyp.fliplr > 0.0:
783
+ hyp.fliplr = 0.0
784
+ LOGGER.warning("WARNING ⚠️ No 'flip_idx' array defined in data.yaml, setting augmentation 'fliplr=0.0'")
785
+ elif flip_idx and (len(flip_idx) != kpt_shape[0]):
786
+ raise ValueError(f'data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}')
787
+
788
+ return Compose([
789
+ pre_transform,
790
+ MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
791
+ Albumentations(p=1.0),
792
+ RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
793
+ RandomFlip(direction='vertical', p=hyp.flipud),
794
+ RandomFlip(direction='horizontal', p=hyp.fliplr, flip_idx=flip_idx)]) # transforms
795
+
796
+
797
+ # Classification augmentations -----------------------------------------------------------------------------------------
798
+ def classify_transforms(size=224, mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)): # IMAGENET_MEAN, IMAGENET_STD
799
+ # Transforms to apply if albumentations not installed
800
+ if not isinstance(size, int):
801
+ raise TypeError(f'classify_transforms() size {size} must be integer, not (list, tuple)')
802
+ if any(mean) or any(std):
803
+ return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(mean, std, inplace=True)])
804
+ else:
805
+ return T.Compose([CenterCrop(size), ToTensor()])
806
+
807
+
808
+ def hsv2colorjitter(h, s, v):
809
+ """Map HSV (hue, saturation, value) jitter into ColorJitter values (brightness, contrast, saturation, hue)"""
810
+ return v, v, s, h
811
+
812
+
813
+ def classify_albumentations(
814
+ augment=True,
815
+ size=224,
816
+ scale=(0.08, 1.0),
817
+ hflip=0.5,
818
+ vflip=0.0,
819
+ hsv_h=0.015, # image HSV-Hue augmentation (fraction)
820
+ hsv_s=0.7, # image HSV-Saturation augmentation (fraction)
821
+ hsv_v=0.4, # image HSV-Value augmentation (fraction)
822
+ mean=(0.0, 0.0, 0.0), # IMAGENET_MEAN
823
+ std=(1.0, 1.0, 1.0), # IMAGENET_STD
824
+ auto_aug=False,
825
+ ):
826
+ """YOLOv8 classification Albumentations (optional, only used if package is installed)."""
827
+ prefix = colorstr('albumentations: ')
828
+ try:
829
+ import albumentations as A
830
+ from albumentations.pytorch import ToTensorV2
831
+
832
+ check_version(A.__version__, '1.0.3', hard=True) # version requirement
833
+ if augment: # Resize and crop
834
+ T = [A.RandomResizedCrop(height=size, width=size, scale=scale)]
835
+ if auto_aug:
836
+ # TODO: implement AugMix, AutoAug & RandAug in albumentations
837
+ LOGGER.info(f'{prefix}auto augmentations are currently not supported')
838
+ else:
839
+ if hflip > 0:
840
+ T += [A.HorizontalFlip(p=hflip)]
841
+ if vflip > 0:
842
+ T += [A.VerticalFlip(p=vflip)]
843
+ if any((hsv_h, hsv_s, hsv_v)):
844
+ T += [A.ColorJitter(*hsv2colorjitter(hsv_h, hsv_s, hsv_v))] # brightness, contrast, saturation, hue
845
+ else: # Use fixed crop for eval set (reproducibility)
846
+ T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
847
+ T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor
848
+ LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
849
+ return A.Compose(T)
850
+
851
+ except ImportError: # package not installed, skip
852
+ pass
853
+ except Exception as e:
854
+ LOGGER.info(f'{prefix}{e}')
855
+
856
+
857
+ class ClassifyLetterBox:
858
+ """YOLOv8 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])"""
859
+
860
+ def __init__(self, size=(640, 640), auto=False, stride=32):
861
+ """Resizes image and crops it to center with max dimensions 'h' and 'w'."""
862
+ super().__init__()
863
+ self.h, self.w = (size, size) if isinstance(size, int) else size
864
+ self.auto = auto # pass max size integer, automatically solve for short side using stride
865
+ self.stride = stride # used with auto
866
+
867
+ def __call__(self, im): # im = np.array HWC
868
+ imh, imw = im.shape[:2]
869
+ r = min(self.h / imh, self.w / imw) # ratio of new/old
870
+ h, w = round(imh * r), round(imw * r) # resized image
871
+ hs, ws = ((math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else (self.h, self.w))
872
+ top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)
873
+ im_out = np.full((hs, ws, 3), 114, dtype=im.dtype)
874
+ im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
875
+ return im_out
876
+
877
+
878
+ class CenterCrop:
879
+ """YOLOv8 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])"""
880
+
881
+ def __init__(self, size=640):
882
+ """Converts an image from numpy array to PyTorch tensor."""
883
+ super().__init__()
884
+ self.h, self.w = (size, size) if isinstance(size, int) else size
885
+
886
+ def __call__(self, im): # im = np.array HWC
887
+ imh, imw = im.shape[:2]
888
+ m = min(imh, imw) # min dimension
889
+ top, left = (imh - m) // 2, (imw - m) // 2
890
+ return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
891
+
892
+
893
+ class ToTensor:
894
+ """YOLOv8 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])."""
895
+
896
+ def __init__(self, half=False):
897
+ """Initialize YOLOv8 ToTensor object with optional half-precision support."""
898
+ super().__init__()
899
+ self.half = half
900
+
901
+ def __call__(self, im): # im = np.array HWC in BGR order
902
+ im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1]) # HWC to CHW -> BGR to RGB -> contiguous
903
+ im = torch.from_numpy(im) # to torch
904
+ im = im.half() if self.half else im.float() # uint8 to fp16/32
905
+ im /= 255.0 # 0-255 to 0.0-1.0
906
+ return im
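A quick, hedged usage sketch of the classes defined above (it assumes this commit's ultralytics package is importable; the image is a dummy array, not real data):

import numpy as np
from ultralytics.data.augment import LetterBox, ToTensor

im = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)  # dummy BGR image
padded = LetterBox(new_shape=(640, 640))(image=im)  # pads/resizes; returns the image directly when no labels dict is passed
tensor = ToTensor()(padded)  # BGR -> RGB, HWC -> CHW, float in [0, 1]
print(padded.shape, tensor.shape)  # (640, 640, 3) and torch.Size([3, 640, 640])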
ultralytics/data/base.py ADDED
@@ -0,0 +1,287 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import glob
4
+ import math
5
+ import os
6
+ import random
7
+ from copy import deepcopy
8
+ from multiprocessing.pool import ThreadPool
9
+ from pathlib import Path
10
+ from typing import Optional
11
+
12
+ import cv2
13
+ import numpy as np
14
+ import psutil
15
+ from torch.utils.data import Dataset
16
+ from tqdm import tqdm
17
+
18
+ from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT
19
+
20
+ from .utils import HELP_URL, IMG_FORMATS
21
+
22
+
23
+ class BaseDataset(Dataset):
24
+ """
25
+ Base dataset class for loading and processing image data.
26
+
27
+ Args:
28
+ img_path (str): Path to the folder containing images.
29
+ imgsz (int, optional): Image size. Defaults to 640.
30
+ cache (bool, optional): Cache images to RAM or disk during training. Defaults to False.
31
+ augment (bool, optional): If True, data augmentation is applied. Defaults to True.
32
+ hyp (dict, optional): Hyperparameters to apply data augmentation. Defaults to None.
33
+ prefix (str, optional): Prefix to print in log messages. Defaults to ''.
34
+ rect (bool, optional): If True, rectangular training is used. Defaults to False.
35
+ batch_size (int, optional): Size of batches. Defaults to None.
36
+ stride (int, optional): Stride. Defaults to 32.
37
+ pad (float, optional): Padding. Defaults to 0.0.
38
+ single_cls (bool, optional): If True, single class training is used. Defaults to False.
39
+ classes (list): List of included classes. Default is None.
40
+ fraction (float): Fraction of dataset to utilize. Default is 1.0 (use all data).
41
+
42
+ Attributes:
43
+ im_files (list): List of image file paths.
44
+ labels (list): List of label data dictionaries.
45
+ ni (int): Number of images in the dataset.
46
+ ims (list): List of loaded images.
47
+ npy_files (list): List of numpy file paths.
48
+ transforms (callable): Image transformation function.
49
+ """
50
+
51
+ def __init__(self,
52
+ img_path,
53
+ imgsz=640,
54
+ cache=False,
55
+ augment=True,
56
+ hyp=DEFAULT_CFG,
57
+ prefix='',
58
+ rect=False,
59
+ batch_size=16,
60
+ stride=32,
61
+ pad=0.5,
62
+ single_cls=False,
63
+ classes=None,
64
+ fraction=1.0):
65
+ super().__init__()
66
+ self.img_path = img_path
67
+ self.imgsz = imgsz
68
+ self.augment = augment
69
+ self.single_cls = single_cls
70
+ self.prefix = prefix
71
+ self.fraction = fraction
72
+ self.im_files = self.get_img_files(self.img_path)
73
+ self.labels = self.get_labels()
74
+ self.update_labels(include_class=classes) # single_cls and include_class
75
+ self.ni = len(self.labels) # number of images
76
+ self.rect = rect
77
+ self.batch_size = batch_size
78
+ self.stride = stride
79
+ self.pad = pad
80
+ if self.rect:
81
+ assert self.batch_size is not None
82
+ self.set_rectangle()
83
+
84
+ # Buffer thread for mosaic images
85
+ self.buffer = [] # buffer size = batch size
86
+ self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0
87
+
88
+ # Cache stuff
89
+ if cache == 'ram' and not self.check_cache_ram():
90
+ cache = False
91
+ self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
92
+ self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
93
+ if cache:
94
+ self.cache_images(cache)
95
+
96
+ # Transforms
97
+ self.transforms = self.build_transforms(hyp=hyp)
98
+
99
+ def get_img_files(self, img_path):
100
+ """Read image files."""
101
+ try:
102
+ f = [] # image files
103
+ for p in img_path if isinstance(img_path, list) else [img_path]:
104
+ p = Path(p) # os-agnostic
105
+ if p.is_dir(): # dir
106
+ f += glob.glob(str(p / '**' / '*.*'), recursive=True)
107
+ # F = list(p.rglob('*.*')) # pathlib
108
+ elif p.is_file(): # file
109
+ with open(p) as t:
110
+ t = t.read().strip().splitlines()
111
+ parent = str(p.parent) + os.sep
112
+ f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path
113
+ # F += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
114
+ else:
115
+ raise FileNotFoundError(f'{self.prefix}{p} does not exist')
116
+ im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
117
+ # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
118
+ assert im_files, f'{self.prefix}No images found'
119
+ except Exception as e:
120
+ raise FileNotFoundError(f'{self.prefix}Error loading data from {img_path}\n{HELP_URL}') from e
121
+ if self.fraction < 1:
122
+ im_files = im_files[:round(len(im_files) * self.fraction)]
123
+ return im_files
124
+
125
+ def update_labels(self, include_class: Optional[list]):
126
+ """include_class, filter labels to include only these classes (optional)."""
127
+ include_class_array = np.array(include_class).reshape(1, -1)
128
+ for i in range(len(self.labels)):
129
+ if include_class is not None:
130
+ cls = self.labels[i]['cls']
131
+ bboxes = self.labels[i]['bboxes']
132
+ segments = self.labels[i]['segments']
133
+ keypoints = self.labels[i]['keypoints']
134
+ j = (cls == include_class_array).any(1)
135
+ self.labels[i]['cls'] = cls[j]
136
+ self.labels[i]['bboxes'] = bboxes[j]
137
+ if segments:
138
+ self.labels[i]['segments'] = [segments[si] for si, idx in enumerate(j) if idx]
139
+ if keypoints is not None:
140
+ self.labels[i]['keypoints'] = keypoints[j]
141
+ if self.single_cls:
142
+ self.labels[i]['cls'][:, 0] = 0
143
+
144
+ def load_image(self, i):
145
+ """Loads 1 image from dataset index 'i', returns (im, resized hw)."""
146
+ im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
147
+ if im is None: # not cached in RAM
148
+ if fn.exists(): # load npy
149
+ im = np.load(fn)
150
+ else: # read image
151
+ im = cv2.imread(f) # BGR
152
+ if im is None:
153
+ raise FileNotFoundError(f'Image Not Found {f}')
154
+ h0, w0 = im.shape[:2] # orig hw
155
+ r = self.imgsz / max(h0, w0) # ratio
156
+ if r != 1: # if sizes are not equal
157
+ interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
158
+ im = cv2.resize(im, (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz)),
159
+ interpolation=interp)
160
+
161
+ # Add to buffer if training with augmentations
162
+ if self.augment:
163
+ self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
164
+ self.buffer.append(i)
165
+ if len(self.buffer) >= self.max_buffer_length:
166
+ j = self.buffer.pop(0)
167
+ self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None
168
+
169
+ return im, (h0, w0), im.shape[:2]
170
+
171
+ return self.ims[i], self.im_hw0[i], self.im_hw[i]
172
+
173
+ def cache_images(self, cache):
174
+ """Cache images to memory or disk."""
175
+ b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
176
+ fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image
177
+ with ThreadPool(NUM_THREADS) as pool:
178
+ results = pool.imap(fcn, range(self.ni))
179
+ pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
180
+ for i, x in pbar:
181
+ if cache == 'disk':
182
+ b += self.npy_files[i].stat().st_size
183
+ else: # 'ram'
184
+ self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
185
+ b += self.ims[i].nbytes
186
+ pbar.desc = f'{self.prefix}Caching images ({b / gb:.1f}GB {cache})'
187
+ pbar.close()
188
+
189
+ def cache_images_to_disk(self, i):
190
+ """Saves an image as an *.npy file for faster loading."""
191
+ f = self.npy_files[i]
192
+ if not f.exists():
193
+ np.save(f.as_posix(), cv2.imread(self.im_files[i]))
194
+
195
+ def check_cache_ram(self, safety_margin=0.5):
196
+ """Check image caching requirements vs available memory."""
197
+ b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
198
+ n = min(self.ni, 30) # extrapolate from 30 random images
199
+ for _ in range(n):
200
+ im = cv2.imread(random.choice(self.im_files)) # sample image
201
+ ratio = self.imgsz / max(im.shape[0], im.shape[1]) # max(h, w) # ratio
202
+ b += im.nbytes * ratio ** 2
203
+ mem_required = b * self.ni / n * (1 + safety_margin) # GB required to cache dataset into RAM
204
+ mem = psutil.virtual_memory()
205
+ cache = mem_required < mem.available # to cache or not to cache, that is the question
206
+ if not cache:
207
+ LOGGER.info(f'{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images '
208
+ f'with {int(safety_margin * 100)}% safety margin but only '
209
+ f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, '
210
+ f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
211
+ return cache
212
+
213
+ def set_rectangle(self):
214
+ """Sets the shape of bounding boxes for YOLO detections as rectangles."""
215
+ bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
216
+ nb = bi[-1] + 1 # number of batches
217
+
218
+ s = np.array([x.pop('shape') for x in self.labels]) # hw
219
+ ar = s[:, 0] / s[:, 1] # aspect ratio
220
+ irect = ar.argsort()
221
+ self.im_files = [self.im_files[i] for i in irect]
222
+ self.labels = [self.labels[i] for i in irect]
223
+ ar = ar[irect]
224
+
225
+ # Set training image shapes
226
+ shapes = [[1, 1]] * nb
227
+ for i in range(nb):
228
+ ari = ar[bi == i]
229
+ mini, maxi = ari.min(), ari.max()
230
+ if maxi < 1:
231
+ shapes[i] = [maxi, 1]
232
+ elif mini > 1:
233
+ shapes[i] = [1, 1 / mini]
234
+
235
+ self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride
236
+ self.batch = bi # batch index of image
237
+
238
+ def __getitem__(self, index):
239
+ """Returns transformed label information for given index."""
240
+ return self.transforms(self.get_image_and_label(index))
241
+
242
+ def get_image_and_label(self, index):
243
+ """Get and return label information from the dataset."""
244
+ label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
245
+ label.pop('shape', None) # shape is for rect, remove it
246
+ label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index)
247
+ label['ratio_pad'] = (label['resized_shape'][0] / label['ori_shape'][0],
248
+ label['resized_shape'][1] / label['ori_shape'][1]) # for evaluation
249
+ if self.rect:
250
+ label['rect_shape'] = self.batch_shapes[self.batch[index]]
251
+ return self.update_labels_info(label)
252
+
253
+ def __len__(self):
254
+ """Returns the length of the labels list for the dataset."""
255
+ return len(self.labels)
256
+
257
+ def update_labels_info(self, label):
258
+ """custom your label format here."""
259
+ return label
260
+
261
+ def build_transforms(self, hyp=None):
262
+ """Users can custom augmentations here
263
+ like:
264
+ if self.augment:
265
+ # Training transforms
266
+ return Compose([])
267
+ else:
268
+ # Val transforms
269
+ return Compose([])
270
+ """
271
+ raise NotImplementedError
272
+
273
+ def get_labels(self):
274
+ """Users can custom their own format here.
275
+ Make sure your output is a list with each element like below:
276
+ dict(
277
+ im_file=im_file,
278
+ shape=shape, # format: (height, width)
279
+ cls=cls,
280
+ bboxes=bboxes, # xywh
281
+ segments=segments, # xy
282
+ keypoints=keypoints, # xy
283
+ normalized=True, # or False
284
+ bbox_format="xyxy", # or xywh, ltwh
285
+ )
286
+ """
287
+ raise NotImplementedError
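The two NotImplementedError hooks above are the extension points of BaseDataset; a minimal subclass sketch (label contents, the identity transform and the image-folder path are placeholders for illustration, not part of this commit):

import numpy as np
from ultralytics.data.base import BaseDataset

class TinyDataset(BaseDataset):
    def get_labels(self):
        # one empty-label dict per image, in the structure documented in get_labels() above
        return [dict(im_file=f, shape=(640, 640), cls=np.zeros((0, 1)), bboxes=np.zeros((0, 4)),
                     segments=[], keypoints=None, normalized=True, bbox_format='xywh')
                for f in self.im_files]

    def build_transforms(self, hyp=None):
        # identity transform keeps the example minimal; real subclasses compose augmentations here
        return lambda labels: labels

# ds = TinyDataset(img_path='path/to/images', imgsz=640, augment=False)  # hypothetical image folder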
ultralytics/data/build.py ADDED
@@ -0,0 +1,170 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import os
4
+ import random
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import torch
9
+ from PIL import Image
10
+ from torch.utils.data import dataloader, distributed
11
+
12
+ from ultralytics.data.loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams, LoadTensor,
13
+ SourceTypes, autocast_list)
14
+ from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
15
+ from ultralytics.utils import RANK, colorstr
16
+ from ultralytics.utils.checks import check_file
17
+
18
+ from .dataset import YOLODataset
19
+ from .utils import PIN_MEMORY
20
+
21
+
22
+ class InfiniteDataLoader(dataloader.DataLoader):
23
+ """Dataloader that reuses workers. Uses same syntax as vanilla DataLoader."""
24
+
25
+ def __init__(self, *args, **kwargs):
26
+ """Dataloader that infinitely recycles workers, inherits from DataLoader."""
27
+ super().__init__(*args, **kwargs)
28
+ object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
29
+ self.iterator = super().__iter__()
30
+
31
+ def __len__(self):
32
+ """Returns the length of the batch sampler's sampler."""
33
+ return len(self.batch_sampler.sampler)
34
+
35
+ def __iter__(self):
36
+ """Creates a sampler that repeats indefinitely."""
37
+ for _ in range(len(self)):
38
+ yield next(self.iterator)
39
+
40
+ def reset(self):
41
+ """Reset iterator.
42
+ This is useful when we want to modify settings of the dataset while training.
43
+ """
44
+ self.iterator = self._get_iterator()
45
+
46
+
47
+ class _RepeatSampler:
48
+ """
49
+ Sampler that repeats forever.
50
+
51
+ Args:
52
+ sampler (Dataset.sampler): The sampler to repeat.
53
+ """
54
+
55
+ def __init__(self, sampler):
56
+ """Initializes an object that repeats a given sampler indefinitely."""
57
+ self.sampler = sampler
58
+
59
+ def __iter__(self):
60
+ """Iterates over the 'sampler' and yields its contents."""
61
+ while True:
62
+ yield from iter(self.sampler)
63
+
64
+
65
+ def seed_worker(worker_id): # noqa
66
+ """Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader."""
67
+ worker_seed = torch.initial_seed() % 2 ** 32
68
+ np.random.seed(worker_seed)
69
+ random.seed(worker_seed)
70
+
71
+
72
+ def build_yolo_dataset(cfg, img_path, batch, data, mode='train', rect=False, stride=32):
73
+ """Build YOLO Dataset"""
74
+ return YOLODataset(
75
+ img_path=img_path,
76
+ imgsz=cfg.imgsz,
77
+ batch_size=batch,
78
+ augment=mode == 'train', # augmentation
79
+ hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
80
+ rect=cfg.rect or rect, # rectangular batches
81
+ cache=cfg.cache or None,
82
+ single_cls=cfg.single_cls or False,
83
+ stride=int(stride),
84
+ pad=0.0 if mode == 'train' else 0.5,
85
+ prefix=colorstr(f'{mode}: '),
86
+ use_segments=cfg.task == 'segment',
87
+ use_keypoints=cfg.task == 'pose',
88
+ classes=cfg.classes,
89
+ data=data,
90
+ fraction=cfg.fraction if mode == 'train' else 1.0)
91
+
92
+
93
+ def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
94
+ """Return an InfiniteDataLoader or DataLoader for training or validation set."""
95
+ batch = min(batch, len(dataset))
96
+ nd = torch.cuda.device_count() # number of CUDA devices
97
+ nw = min([os.cpu_count() // max(nd, 1), batch if batch > 1 else 0, workers]) # number of workers
98
+ sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
99
+ generator = torch.Generator()
100
+ generator.manual_seed(6148914691236517205 + RANK)
101
+ return InfiniteDataLoader(dataset=dataset,
102
+ batch_size=batch,
103
+ shuffle=shuffle and sampler is None,
104
+ num_workers=nw,
105
+ sampler=sampler,
106
+ pin_memory=PIN_MEMORY,
107
+ collate_fn=getattr(dataset, 'collate_fn', None),
108
+ worker_init_fn=seed_worker,
109
+ generator=generator)
110
+
111
+
112
+ def check_source(source):
113
+ """Check source type and return corresponding flag values."""
114
+ webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False
115
+ if isinstance(source, (str, int, Path)): # int for local usb camera
116
+ source = str(source)
117
+ is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
118
+ is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
119
+ webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
120
+ screenshot = source.lower() == 'screen'
121
+ if is_url and is_file:
122
+ source = check_file(source) # download
123
+ elif isinstance(source, tuple(LOADERS)):
124
+ in_memory = True
125
+ elif isinstance(source, (list, tuple)):
126
+ source = autocast_list(source) # convert all list elements to PIL or np arrays
127
+ from_img = True
128
+ elif isinstance(source, (Image.Image, np.ndarray)):
129
+ from_img = True
130
+ elif isinstance(source, torch.Tensor):
131
+ tensor = True
132
+ else:
133
+ raise TypeError('Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict')
134
+
135
+ return source, webcam, screenshot, from_img, in_memory, tensor
136
+
137
+
138
+ def load_inference_source(source=None, imgsz=640, vid_stride=1):
139
+ """
140
+ Loads an inference source for object detection and applies necessary transformations.
141
+
142
+ Args:
143
+ source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference.
144
+ imgsz (int, optional): The size of the image for inference. Default is 640.
145
+ vid_stride (int, optional): The frame interval for video sources. Default is 1.
146
+
147
+ Returns:
148
+ dataset (Dataset): A dataset object for the specified input source.
149
+ """
150
+ source, webcam, screenshot, from_img, in_memory, tensor = check_source(source)
151
+ source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img, tensor)
152
+
153
+ # Dataloader
154
+ if tensor:
155
+ dataset = LoadTensor(source)
156
+ elif in_memory:
157
+ dataset = source
158
+ elif webcam:
159
+ dataset = LoadStreams(source, imgsz=imgsz, vid_stride=vid_stride)
160
+ elif screenshot:
161
+ dataset = LoadScreenshots(source, imgsz=imgsz)
162
+ elif from_img:
163
+ dataset = LoadPilAndNumpy(source, imgsz=imgsz)
164
+ else:
165
+ dataset = LoadImages(source, imgsz=imgsz, vid_stride=vid_stride)
166
+
167
+ # Attach source types to the dataset
168
+ setattr(dataset, 'source_type', source_type)
169
+
170
+ return dataset
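For context, a hedged example of check_source/load_inference_source with an in-memory frame (assumes the module imports cleanly from this repo; the frame is a dummy array):

import numpy as np
from ultralytics.data.build import load_inference_source

frame = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy BGR frame
dataset = load_inference_source(source=frame, imgsz=640)
print(dataset.source_type)  # from_img flag set; webcam/screenshot/tensor remain False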
ultralytics/data/converter.py ADDED
@@ -0,0 +1,230 @@
1
+ import json
2
+ from collections import defaultdict
3
+ from pathlib import Path
4
+
5
+ import cv2
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+
9
+ from ultralytics.utils.checks import check_requirements
10
+ from ultralytics.utils.files import make_dirs
11
+
12
+
13
+ def coco91_to_coco80_class():
14
+ """Converts 91-index COCO class IDs to 80-index COCO class IDs.
15
+
16
+ Returns:
17
+ (list): A list of 91 entries where the index is the zero-based 91-class ID and the value is the
18
+ corresponding 80-class ID (or None if the class has no 80-class equivalent).
19
+
20
+ """
21
+ return [
22
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, None, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, None, 24, 25, None,
23
+ None, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, None, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
24
+ 51, 52, 53, 54, 55, 56, 57, 58, 59, None, 60, None, None, 61, None, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72,
25
+ None, 73, 74, 75, 76, 77, 78, 79, None]
26
+
27
+
28
+ def convert_coco(labels_dir='../coco/annotations/', use_segments=False, use_keypoints=False, cls91to80=True):
29
+ """Converts COCO dataset annotations to a format suitable for training YOLOv5 models.
30
+
31
+ Args:
32
+ labels_dir (str, optional): Path to directory containing COCO dataset annotation files.
33
+ use_segments (bool, optional): Whether to include segmentation masks in the output.
34
+ use_keypoints (bool, optional): Whether to include keypoint annotations in the output.
35
+ cls91to80 (bool, optional): Whether to map 91 COCO class IDs to the corresponding 80 COCO class IDs.
36
+
37
+ Raises:
38
+ FileNotFoundError: If the labels_dir path does not exist.
39
+
40
+ Example Usage:
41
+ convert_coco(labels_dir='../coco/annotations/', use_segments=True, use_keypoints=True, cls91to80=True)
42
+
43
+ Output:
44
+ Generates output files in the specified output directory.
45
+ """
46
+
47
+ save_dir = make_dirs('yolo_labels') # output directory
48
+ coco80 = coco91_to_coco80_class()
49
+
50
+ # Import json
51
+ for json_file in sorted(Path(labels_dir).resolve().glob('*.json')):
52
+ fn = Path(save_dir) / 'labels' / json_file.stem.replace('instances_', '') # folder name
53
+ fn.mkdir(parents=True, exist_ok=True)
54
+ with open(json_file) as f:
55
+ data = json.load(f)
56
+
57
+ # Create image dict
58
+ images = {f'{x["id"]:d}': x for x in data['images']}
59
+ # Create image-annotations dict
60
+ imgToAnns = defaultdict(list)
61
+ for ann in data['annotations']:
62
+ imgToAnns[ann['image_id']].append(ann)
63
+
64
+ # Write labels file
65
+ for img_id, anns in tqdm(imgToAnns.items(), desc=f'Annotations {json_file}'):
66
+ img = images[f'{img_id:d}']
67
+ h, w, f = img['height'], img['width'], img['file_name']
68
+
69
+ bboxes = []
70
+ segments = []
71
+ keypoints = []
72
+ for ann in anns:
73
+ if ann['iscrowd']:
74
+ continue
75
+ # The COCO box format is [top left x, top left y, width, height]
76
+ box = np.array(ann['bbox'], dtype=np.float64)
77
+ box[:2] += box[2:] / 2 # xy top-left corner to center
78
+ box[[0, 2]] /= w # normalize x
79
+ box[[1, 3]] /= h # normalize y
80
+ if box[2] <= 0 or box[3] <= 0: # if w <= 0 and h <= 0
81
+ continue
82
+
83
+ cls = coco80[ann['category_id'] - 1] if cls91to80 else ann['category_id'] - 1 # class
84
+ box = [cls] + box.tolist()
85
+ if box not in bboxes:
86
+ bboxes.append(box)
87
+ if use_segments and ann.get('segmentation') is not None:
88
+ if len(ann['segmentation']) == 0:
89
+ segments.append([])
90
+ continue
91
+ if isinstance(ann['segmentation'], dict):
92
+ ann['segmentation'] = rle2polygon(ann['segmentation'])
93
+ if len(ann['segmentation']) > 1:
94
+ s = merge_multi_segment(ann['segmentation'])
95
+ s = (np.concatenate(s, axis=0) / np.array([w, h])).reshape(-1).tolist()
96
+ else:
97
+ s = [j for i in ann['segmentation'] for j in i] # all segments concatenated
98
+ s = (np.array(s).reshape(-1, 2) / np.array([w, h])).reshape(-1).tolist()
99
+ s = [cls] + s
100
+ if s not in segments:
101
+ segments.append(s)
102
+ if use_keypoints and ann.get('keypoints') is not None:
103
+ k = (np.array(ann['keypoints']).reshape(-1, 3) / np.array([w, h, 1])).reshape(-1).tolist()
104
+ k = box + k
105
+ keypoints.append(k)
106
+
107
+ # Write
108
+ with open((fn / f).with_suffix('.txt'), 'a') as file:
109
+ for i in range(len(bboxes)):
110
+ if use_keypoints:
111
+ line = *(keypoints[i]), # cls, box, keypoints
112
+ else:
113
+ line = *(segments[i]
114
+ if use_segments and len(segments[i]) > 0 else bboxes[i]), # cls, box or segments
115
+ file.write(('%g ' * len(line)).rstrip() % line + '\n')
116
+
117
+
118
+ def rle2polygon(segmentation):
119
+ """
120
+ Convert Run-Length Encoding (RLE) mask to polygon coordinates.
121
+
122
+ Args:
123
+ segmentation (dict, list): RLE mask representation of the object segmentation.
124
+
125
+ Returns:
126
+ (list): A list of lists representing the polygon coordinates for each contour.
127
+
128
+ Note:
129
+ Requires the 'pycocotools' package to be installed.
130
+ """
131
+ check_requirements('pycocotools')
132
+ from pycocotools import mask
133
+
134
+ m = mask.decode(segmentation)
135
+ m[m > 0] = 255
136
+ contours, _ = cv2.findContours(m, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_KCOS)
137
+ polygons = []
138
+ for contour in contours:
139
+ epsilon = 0.001 * cv2.arcLength(contour, True)
140
+ contour_approx = cv2.approxPolyDP(contour, epsilon, True)
141
+ polygon = contour_approx.flatten().tolist()
142
+ polygons.append(polygon)
143
+ return polygons
144
+
145
+
146
+ def min_index(arr1, arr2):
147
+ """
148
+ Find a pair of indexes with the shortest distance between two arrays of 2D points.
149
+
150
+ Args:
151
+ arr1 (np.array): A NumPy array of shape (N, 2) representing N 2D points.
152
+ arr2 (np.array): A NumPy array of shape (M, 2) representing M 2D points.
153
+
154
+ Returns:
155
+ (tuple): A tuple containing the indexes of the points with the shortest distance in arr1 and arr2 respectively.
156
+ """
157
+ dis = ((arr1[:, None, :] - arr2[None, :, :]) ** 2).sum(-1)
158
+ return np.unravel_index(np.argmin(dis, axis=None), dis.shape)
159
+
160
+
161
+ def merge_multi_segment(segments):
162
+ """
163
+ Merge multiple segments into one list by connecting the coordinates with the minimum distance between each segment.
164
+ This function connects these coordinates with a thin line to merge all segments into one.
165
+
166
+ Args:
167
+ segments (List[List]): Original segmentations in COCO's JSON file.
168
+ Each element is a list of coordinates, like [segmentation1, segmentation2,...].
169
+
170
+ Returns:
171
+ s (List[np.ndarray]): A list of connected segments represented as NumPy arrays.
172
+ """
173
+ s = []
174
+ segments = [np.array(i).reshape(-1, 2) for i in segments]
175
+ idx_list = [[] for _ in range(len(segments))]
176
+
177
+ # record the indexes with min distance between each segment
178
+ for i in range(1, len(segments)):
179
+ idx1, idx2 = min_index(segments[i - 1], segments[i])
180
+ idx_list[i - 1].append(idx1)
181
+ idx_list[i].append(idx2)
182
+
183
+ # use two rounds to connect all the segments
184
+ for k in range(2):
185
+ # forward connection
186
+ if k == 0:
187
+ for i, idx in enumerate(idx_list):
188
+ # middle segments have two indexes
189
+ # reverse the index of middle segments
190
+ if len(idx) == 2 and idx[0] > idx[1]:
191
+ idx = idx[::-1]
192
+ segments[i] = segments[i][::-1, :]
193
+
194
+ segments[i] = np.roll(segments[i], -idx[0], axis=0)
195
+ segments[i] = np.concatenate([segments[i], segments[i][:1]])
196
+ # deal with the first segment and the last one
197
+ if i in [0, len(idx_list) - 1]:
198
+ s.append(segments[i])
199
+ else:
200
+ idx = [0, idx[1] - idx[0]]
201
+ s.append(segments[i][idx[0]:idx[1] + 1])
202
+
203
+ else:
204
+ for i in range(len(idx_list) - 1, -1, -1):
205
+ if i not in [0, len(idx_list) - 1]:
206
+ idx = idx_list[i]
207
+ nidx = abs(idx[1] - idx[0])
208
+ s.append(segments[i][nidx:])
209
+ return s
210
+
211
+
212
+ def delete_dsstore(path='../datasets'):
213
+ """Delete Apple .DS_Store files in the specified directory and its subdirectories."""
214
+ from pathlib import Path
215
+
216
+ files = list(Path(path).rglob('.DS_Store'))
217
+ print(files)
218
+ for f in files:
219
+ f.unlink()
220
+
221
+
222
+ if __name__ == '__main__':
223
+ source = 'COCO'
224
+
225
+ if source == 'COCO':
226
+ convert_coco(
227
+ '../datasets/coco/annotations', # directory with *.json
228
+ use_segments=False,
229
+ use_keypoints=True,
230
+ cls91to80=False)
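A small sanity check of min_index(), the geometric helper used by merge_multi_segment() above (illustrative only, assuming the module imports cleanly from this repo):

import numpy as np
from ultralytics.data.converter import min_index

a = np.array([[0.0, 0.0], [10.0, 10.0]])
b = np.array([[9.0, 9.0], [50.0, 50.0]])
print(min_index(a, b))  # (1, 0): a[1] and b[0] are the closest pair across the two point sets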
ultralytics/data/dataloaders/__init__.py ADDED
File without changes
ultralytics/data/dataset.py ADDED
@@ -0,0 +1,275 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ from itertools import repeat
4
+ from multiprocessing.pool import ThreadPool
5
+ from pathlib import Path
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import torch
10
+ import torchvision
11
+ from tqdm import tqdm
12
+
13
+ from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT, is_dir_writeable
14
+
15
+ from .augment import Compose, Format, Instances, LetterBox, classify_albumentations, classify_transforms, v8_transforms
16
+ from .base import BaseDataset
17
+ from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image_label
18
+
19
+
20
+ class YOLODataset(BaseDataset):
21
+ """
22
+ Dataset class for loading object detection and/or segmentation labels in YOLO format.
23
+
24
+ Args:
25
+ data (dict, optional): A dataset YAML dictionary. Defaults to None.
26
+ use_segments (bool, optional): If True, segmentation masks are used as labels. Defaults to False.
27
+ use_keypoints (bool, optional): If True, keypoints are used as labels. Defaults to False.
28
+
29
+ Returns:
30
+ (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
31
+ """
32
+ cache_version = '1.0.2' # dataset labels *.cache version, >= 1.0.0 for YOLOv8
33
+ rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
34
+
35
+ def __init__(self, *args, data=None, use_segments=False, use_keypoints=False, **kwargs):
36
+ self.use_segments = use_segments
37
+ self.use_keypoints = use_keypoints
38
+ self.data = data
39
+ assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.'
40
+ super().__init__(*args, **kwargs)
41
+
42
+ def cache_labels(self, path=Path('./labels.cache')):
43
+ """Cache dataset labels, check images and read shapes.
44
+ Args:
45
+ path (Path): path where to save the cache file (default: Path('./labels.cache')).
46
+ Returns:
47
+ (dict): labels.
48
+ """
49
+ x = {'labels': []}
50
+ nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
51
+ desc = f'{self.prefix}Scanning {path.parent / path.stem}...'
52
+ total = len(self.im_files)
53
+ nkpt, ndim = self.data.get('kpt_shape', (0, 0))
54
+ if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
55
+ raise ValueError("'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
56
+ "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'")
57
+ with ThreadPool(NUM_THREADS) as pool:
58
+ results = pool.imap(func=verify_image_label,
59
+ iterable=zip(self.im_files, self.label_files, repeat(self.prefix),
60
+ repeat(self.use_keypoints), repeat(len(self.data['names'])), repeat(nkpt),
61
+ repeat(ndim)))
62
+ pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
63
+ for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
64
+ nm += nm_f
65
+ nf += nf_f
66
+ ne += ne_f
67
+ nc += nc_f
68
+ if im_file:
69
+ x['labels'].append(
70
+ dict(
71
+ im_file=im_file,
72
+ shape=shape,
73
+ cls=lb[:, 0:1], # n, 1
74
+ bboxes=lb[:, 1:], # n, 4
75
+ segments=segments,
76
+ keypoints=keypoint,
77
+ normalized=True,
78
+ bbox_format='xywh'))
79
+ if msg:
80
+ msgs.append(msg)
81
+ pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt'
82
+ pbar.close()
83
+
84
+ if msgs:
85
+ LOGGER.info('\n'.join(msgs))
86
+ if nf == 0:
87
+ LOGGER.warning(f'{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
88
+ x['hash'] = get_hash(self.label_files + self.im_files)
89
+ x['results'] = nf, nm, ne, nc, len(self.im_files)
90
+ x['msgs'] = msgs # warnings
91
+ x['version'] = self.cache_version # cache version
92
+ if is_dir_writeable(path.parent):
93
+ if path.exists():
94
+ path.unlink() # remove *.cache file if exists
95
+ np.save(str(path), x) # save cache for next time
96
+ path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
97
+ LOGGER.info(f'{self.prefix}New cache created: {path}')
98
+ else:
99
+ LOGGER.warning(f'{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
100
+ return x
101
+
102
+ def get_labels(self):
103
+ """Returns dictionary of labels for YOLO training."""
104
+ self.label_files = img2label_paths(self.im_files)
105
+ cache_path = Path(self.label_files[0]).parent.with_suffix('.cache')
106
+ try:
107
+ import gc
108
+ gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
109
+ cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True # load dict
110
+ gc.enable()
111
+ assert cache['version'] == self.cache_version # matches current version
112
+ assert cache['hash'] == get_hash(self.label_files + self.im_files) # identical hash
113
+ except (FileNotFoundError, AssertionError, AttributeError):
114
+ cache, exists = self.cache_labels(cache_path), False # run cache ops
115
+
116
+ # Display cache
117
+ nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total
118
+ if exists and LOCAL_RANK in (-1, 0):
119
+ d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
120
+ tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display cache results
121
+ if cache['msgs']:
122
+ LOGGER.info('\n'.join(cache['msgs'])) # display warnings
123
+ if nf == 0: # number of labels found
124
+ raise FileNotFoundError(f'{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}')
125
+
126
+ # Read cache
127
+ [cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
128
+ labels = cache['labels']
129
+ self.im_files = [lb['im_file'] for lb in labels] # update im_files
130
+
131
+ # Check if the dataset is all boxes or all segments
132
+ lengths = ((len(lb['cls']), len(lb['bboxes']), len(lb['segments'])) for lb in labels)
133
+ len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
134
+ if len_segments and len_boxes != len_segments:
135
+ LOGGER.warning(
136
+ f'WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, '
137
+ f'len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. '
138
+ 'To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.')
139
+ for lb in labels:
140
+ lb['segments'] = []
141
+ if len_cls == 0:
142
+ raise ValueError(f'All labels empty in {cache_path}, can not start training without labels. {HELP_URL}')
143
+ return labels
144
+
145
+ # TODO: use hyp config to set all these augmentations
146
+ def build_transforms(self, hyp=None):
147
+ """Builds and appends transforms to the list."""
148
+ if self.augment:
149
+ hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
150
+ hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
151
+ transforms = v8_transforms(self, self.imgsz, hyp)
152
+ else:
153
+ transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
154
+ transforms.append(
155
+ Format(bbox_format='xywh',
156
+ normalize=True,
157
+ return_mask=self.use_segments,
158
+ return_keypoint=self.use_keypoints,
159
+ batch_idx=True,
160
+ mask_ratio=hyp.mask_ratio,
161
+ mask_overlap=hyp.overlap_mask))
162
+ return transforms
163
+
164
+ def close_mosaic(self, hyp):
165
+ """Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations."""
166
+ hyp.mosaic = 0.0 # set mosaic ratio=0.0
167
+ hyp.copy_paste = 0.0 # keep the same behavior as previous v8 close-mosaic
168
+ hyp.mixup = 0.0 # keep the same behavior as previous v8 close-mosaic
169
+ self.transforms = self.build_transforms(hyp)
170
+
171
+ def update_labels_info(self, label):
172
+ """Customize your label format here."""
173
+ # NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label
174
+ # we can make it also support classification and semantic segmentation by adding or removing some dict keys here.
175
+ bboxes = label.pop('bboxes')
176
+ segments = label.pop('segments')
177
+ keypoints = label.pop('keypoints', None)
178
+ bbox_format = label.pop('bbox_format')
179
+ normalized = label.pop('normalized')
180
+ label['instances'] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
181
+ return label
182
+
183
+ @staticmethod
184
+ def collate_fn(batch):
185
+ """Collates data samples into batches."""
186
+ new_batch = {}
187
+ keys = batch[0].keys()
188
+ values = list(zip(*[list(b.values()) for b in batch]))
189
+ for i, k in enumerate(keys):
190
+ value = values[i]
191
+ if k == 'img':
192
+ value = torch.stack(value, 0)
193
+ if k in ['masks', 'keypoints', 'bboxes', 'cls']:
194
+ value = torch.cat(value, 0)
195
+ new_batch[k] = value
196
+ new_batch['batch_idx'] = list(new_batch['batch_idx'])
197
+ for i in range(len(new_batch['batch_idx'])):
198
+ new_batch['batch_idx'][i] += i # add target image index for build_targets()
199
+ new_batch['batch_idx'] = torch.cat(new_batch['batch_idx'], 0)
200
+ return new_batch
201
+
202
+
203
+ # Classification dataloaders -------------------------------------------------------------------------------------------
204
+ class ClassificationDataset(torchvision.datasets.ImageFolder):
205
+ """
206
+ YOLO Classification Dataset.
207
+
208
+ Args:
209
+ root (str): Dataset path.
210
+
211
+ Attributes:
212
+ cache_ram (bool): True if images should be cached in RAM, False otherwise.
213
+ cache_disk (bool): True if images should be cached on disk, False otherwise.
214
+ samples (list): List of samples containing file, index, npy, and im.
215
+ torch_transforms (callable): torchvision transforms applied to the dataset.
216
+ album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True.
217
+ """
218
+
219
+ def __init__(self, root, args, augment=False, cache=False):
220
+ """
221
+ Initialize YOLO object with root, image size, augmentations, and cache settings.
222
+
223
+ Args:
224
+ root (str): Dataset path.
225
+ args (Namespace): Argument parser containing dataset related settings.
226
+ augment (bool, optional): True if dataset should be augmented, False otherwise. Defaults to False.
227
+ cache (bool | str | optional): Cache setting, can be True, False, 'ram' or 'disk'. Defaults to False.
228
+ """
229
+ super().__init__(root=root)
230
+ if augment and args.fraction < 1.0: # reduce training fraction
231
+ self.samples = self.samples[:round(len(self.samples) * args.fraction)]
232
+ self.cache_ram = cache is True or cache == 'ram'
233
+ self.cache_disk = cache == 'disk'
234
+ self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
235
+ self.torch_transforms = classify_transforms(args.imgsz)
236
+ self.album_transforms = classify_albumentations(
237
+ augment=augment,
238
+ size=args.imgsz,
239
+ scale=(1.0 - args.scale, 1.0), # (0.08, 1.0)
240
+ hflip=args.fliplr,
241
+ vflip=args.flipud,
242
+ hsv_h=args.hsv_h, # HSV-Hue augmentation (fraction)
243
+ hsv_s=args.hsv_s, # HSV-Saturation augmentation (fraction)
244
+ hsv_v=args.hsv_v, # HSV-Value augmentation (fraction)
245
+ mean=(0.0, 0.0, 0.0), # IMAGENET_MEAN
246
+ std=(1.0, 1.0, 1.0), # IMAGENET_STD
247
+ auto_aug=False) if augment else None
248
+
249
+ def __getitem__(self, i):
250
+ """Returns subset of data and targets corresponding to given indices."""
251
+ f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
252
+ if self.cache_ram and im is None:
253
+ im = self.samples[i][3] = cv2.imread(f)
254
+ elif self.cache_disk:
255
+ if not fn.exists(): # create *.npy cache file on first use
256
+ np.save(fn.as_posix(), cv2.imread(f))
257
+ im = np.load(fn)
258
+ else: # read image
259
+ im = cv2.imread(f) # BGR
260
+ if self.album_transforms:
261
+ sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))['image']
262
+ else:
263
+ sample = self.torch_transforms(im)
264
+ return {'img': sample, 'cls': j}
265
+
266
+ def __len__(self) -> int:
267
+ return len(self.samples)
268
+
269
+
270
+ # TODO: support semantic segmentation
271
+ class SemanticDataset(BaseDataset):
272
+
273
+ def __init__(self):
274
+ """Initialize a SemanticDataset object."""
275
+ super().__init__()
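YOLODataset.collate_fn above merges per-image sample dicts into one batch: images are stacked, box-level tensors are concatenated, and batch_idx is offset so every box remembers which image it came from. A minimal sketch with hand-built samples (shapes mimic what the Format() transform emits; nothing here comes from a real dataset):

import torch
from ultralytics.data.dataset import YOLODataset

def fake_sample(n):
    # n boxes for one image; only the keys collate_fn touches
    return {'img': torch.zeros(3, 640, 640),
            'cls': torch.zeros(n, 1),
            'bboxes': torch.rand(n, 4),
            'batch_idx': torch.zeros(n)}

batch = YOLODataset.collate_fn([fake_sample(2), fake_sample(3)])
print(batch['img'].shape)   # torch.Size([2, 3, 640, 640])
print(batch['batch_idx'])   # tensor([0., 0., 1., 1., 1.]) -> image index per box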
ultralytics/data/loaders.py ADDED
@@ -0,0 +1,407 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import glob
4
+ import math
5
+ import os
6
+ import time
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+ from threading import Thread
10
+ from urllib.parse import urlparse
11
+
12
+ import cv2
13
+ import numpy as np
14
+ import requests
15
+ import torch
16
+ from PIL import Image
17
+
18
+ from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
19
+ from ultralytics.utils import LOGGER, ROOT, is_colab, is_kaggle, ops
20
+ from ultralytics.utils.checks import check_requirements
21
+
22
+
23
+ @dataclass
24
+ class SourceTypes:
25
+ webcam: bool = False
26
+ screenshot: bool = False
27
+ from_img: bool = False
28
+ tensor: bool = False
29
+
30
+
31
+ class LoadStreams:
32
+ """YOLOv8 streamloader, i.e. `yolo predict source='rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`."""
33
+
34
+ def __init__(self, sources='file.streams', imgsz=640, vid_stride=1):
35
+ """Initialize instance variables and check for consistent input stream shapes."""
36
+ torch.backends.cudnn.benchmark = True # faster for fixed-size inference
37
+ self.mode = 'stream'
38
+ self.imgsz = imgsz
39
+ self.vid_stride = vid_stride # video frame-rate stride
40
+ sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
41
+ n = len(sources)
42
+ self.sources = [ops.clean_str(x) for x in sources] # clean source names for later
43
+ self.imgs, self.fps, self.frames, self.threads, self.shape = [[] for _ in range(n)], [0] * n, [0] * n, [None] * n, [None] * n # independent per-stream buffers ([[]] * n would alias a single list)
44
+ for i, s in enumerate(sources): # index, source
45
+ # Start thread to read frames from video stream
46
+ st = f'{i + 1}/{n}: {s}... '
47
+ if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'): # if source is YouTube video
48
+ # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
49
+ s = get_best_youtube_url(s)
50
+ s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
51
+ if s == 0 and (is_colab() or is_kaggle()):
52
+ raise NotImplementedError("'source=0' webcam not supported in Colab and Kaggle notebooks. "
53
+ "Try running 'source=0' in a local environment.")
54
+ cap = cv2.VideoCapture(s)
55
+ if not cap.isOpened():
56
+ raise ConnectionError(f'{st}Failed to open {s}')
57
+ w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
58
+ h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
59
+ fps = cap.get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
60
+ self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf') # infinite stream fallback
61
+ self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback
62
+
63
+ success, im = cap.read() # guarantee first frame
64
+ if not success or im is None:
65
+ raise ConnectionError(f'{st}Failed to read images from {s}')
66
+ self.imgs[i].append(im)
67
+ self.shape[i] = im.shape
68
+ self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
69
+ LOGGER.info(f'{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)')
70
+ self.threads[i].start()
71
+ LOGGER.info('') # newline
72
+
73
+ # Check for common shapes
74
+ self.bs = self.__len__()
75
+
76
+ def update(self, i, cap, stream):
77
+ """Read stream `i` frames in daemon thread."""
78
+ n, f = 0, self.frames[i] # frame number, frame array
79
+ while cap.isOpened() and n < f:
80
+ # Only read a new frame if the buffer is empty
81
+ if not self.imgs[i]:
82
+ n += 1
83
+ cap.grab() # .read() = .grab() followed by .retrieve()
84
+ if n % self.vid_stride == 0:
85
+ success, im = cap.retrieve()
86
+ if success:
87
+ self.imgs[i].append(im) # add image to buffer
88
+ else:
89
+ LOGGER.warning('WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.')
90
+ self.imgs[i].append(np.zeros(self.shape[i]))
91
+ cap.open(stream) # re-open stream if signal was lost
92
+ else:
93
+ time.sleep(0.01) # wait until the buffer is empty
94
+
95
+ def __iter__(self):
96
+ """Iterates through YOLO image feed and re-opens unresponsive streams."""
97
+ self.count = -1
98
+ return self
99
+
100
+ def __next__(self):
101
+ """Returns source paths, transformed and original images for processing."""
102
+ self.count += 1
103
+
104
+ # Wait until a frame is available in each buffer
105
+ while not all(self.imgs):
106
+ if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
107
+ cv2.destroyAllWindows()
108
+ raise StopIteration
109
+ time.sleep(1 / min(self.fps))
110
+
111
+ # Get and remove the next frame from imgs buffer
112
+ return self.sources, [x.pop(0) for x in self.imgs], None, ''
113
+
114
+ def __len__(self):
115
+ """Return the length of the sources object."""
116
+ return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
117
+
118
+
119
+ class LoadScreenshots:
120
+ """YOLOv8 screenshot dataloader, i.e. `yolo predict source=screen`."""
121
+
122
+ def __init__(self, source, imgsz=640):
123
+ """source = [screen_number left top width height] (pixels)."""
124
+ check_requirements('mss')
125
+ import mss # noqa
126
+
127
+ source, *params = source.split()
128
+ self.screen, left, top, width, height = 0, None, None, None, None # default to full screen 0
129
+ if len(params) == 1:
130
+ self.screen = int(params[0])
131
+ elif len(params) == 4:
132
+ left, top, width, height = (int(x) for x in params)
133
+ elif len(params) == 5:
134
+ self.screen, left, top, width, height = (int(x) for x in params)
135
+ self.imgsz = imgsz
136
+ self.mode = 'stream'
137
+ self.frame = 0
138
+ self.sct = mss.mss()
139
+ self.bs = 1
140
+
141
+ # Parse monitor shape
142
+ monitor = self.sct.monitors[self.screen]
143
+ self.top = monitor['top'] if top is None else (monitor['top'] + top)
144
+ self.left = monitor['left'] if left is None else (monitor['left'] + left)
145
+ self.width = width or monitor['width']
146
+ self.height = height or monitor['height']
147
+ self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height}
148
+
149
+ def __iter__(self):
150
+ """Returns an iterator of the object."""
151
+ return self
152
+
153
+ def __next__(self):
154
+ """mss screen capture: get raw pixels from the screen as np array."""
155
+ im0 = np.array(self.sct.grab(self.monitor))[:, :, :3] # [:, :, :3] BGRA to BGR
156
+ s = f'screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: '
157
+
158
+ self.frame += 1
159
+ return [str(self.screen)], [im0], None, s # screen, img, vid_cap, string
160
+
161
+
162
+ class LoadImages:
163
+ """YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4`."""
164
+
165
+ def __init__(self, path, imgsz=640, vid_stride=1):
166
+ """Initialize the Dataloader and raise FileNotFoundError if file not found."""
167
+ parent = None
168
+ if isinstance(path, str) and Path(path).suffix == '.txt': # *.txt file with img/vid/dir on each line
169
+ parent = Path(path).parent
170
+ path = Path(path).read_text().rsplit()
171
+ files = []
172
+ for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
173
+ a = str(Path(p).absolute()) # do not use .resolve() https://github.com/ultralytics/ultralytics/issues/2912
174
+ if '*' in a:
175
+ files.extend(sorted(glob.glob(a, recursive=True))) # glob
176
+ elif os.path.isdir(a):
177
+ files.extend(sorted(glob.glob(os.path.join(a, '*.*')))) # dir
178
+ elif os.path.isfile(a):
179
+ files.append(a) # files (absolute or relative to CWD)
180
+ elif parent and (parent / p).is_file():
181
+ files.append(str((parent / p).absolute())) # files (relative to *.txt file parent)
182
+ else:
183
+ raise FileNotFoundError(f'{p} does not exist')
184
+
185
+ images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
186
+ videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
187
+ ni, nv = len(images), len(videos)
188
+
189
+ self.imgsz = imgsz
190
+ self.files = images + videos
191
+ self.nf = ni + nv # number of files
192
+ self.video_flag = [False] * ni + [True] * nv
193
+ self.mode = 'image'
194
+ self.vid_stride = vid_stride # video frame-rate stride
195
+ self.bs = 1
196
+ if any(videos):
197
+ self.orientation = None # rotation degrees
198
+ self._new_video(videos[0]) # new video
199
+ else:
200
+ self.cap = None
201
+ if self.nf == 0:
202
+ raise FileNotFoundError(f'No images or videos found in {p}. '
203
+ f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}')
204
+
205
+ def __iter__(self):
206
+ """Returns an iterator object for VideoStream or ImageFolder."""
207
+ self.count = 0
208
+ return self
209
+
210
+ def __next__(self):
211
+ """Return next image, path and metadata from dataset."""
212
+ if self.count == self.nf:
213
+ raise StopIteration
214
+ path = self.files[self.count]
215
+
216
+ if self.video_flag[self.count]:
217
+ # Read video
218
+ self.mode = 'video'
219
+ for _ in range(self.vid_stride):
220
+ self.cap.grab()
221
+ success, im0 = self.cap.retrieve()
222
+ while not success:
223
+ self.count += 1
224
+ self.cap.release()
225
+ if self.count == self.nf: # last video
226
+ raise StopIteration
227
+ path = self.files[self.count]
228
+ self._new_video(path)
229
+ success, im0 = self.cap.read()
230
+
231
+ self.frame += 1
232
+ # im0 = self._cv2_rotate(im0) # for use if cv2 autorotation is False
233
+ s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
234
+
235
+ else:
236
+ # Read image
237
+ self.count += 1
238
+ im0 = cv2.imread(path) # BGR
239
+ if im0 is None:
240
+ raise FileNotFoundError(f'Image Not Found {path}')
241
+ s = f'image {self.count}/{self.nf} {path}: '
242
+
243
+ return [path], [im0], self.cap, s
244
+
245
+ def _new_video(self, path):
246
+ """Create a new video capture object."""
247
+ self.frame = 0
248
+ self.cap = cv2.VideoCapture(path)
249
+ self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
250
+ if hasattr(cv2, 'CAP_PROP_ORIENTATION_META'): # cv2<4.6.0 compatibility
251
+ self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META)) # rotation degrees
252
+ # Disable auto-orientation due to known issues in https://github.com/ultralytics/yolov5/issues/8493
253
+ # self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0)
254
+
255
+ def _cv2_rotate(self, im):
256
+ """Rotate a cv2 video manually."""
257
+ if self.orientation == 0:
258
+ return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
259
+ elif self.orientation == 180:
260
+ return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
261
+ elif self.orientation == 90:
262
+ return cv2.rotate(im, cv2.ROTATE_180)
263
+ return im
264
+
265
+ def __len__(self):
266
+ """Returns the number of files in the object."""
267
+ return self.nf # number of files
268
+
269
+
270
+ class LoadPilAndNumpy:
271
+
272
+ def __init__(self, im0, imgsz=640):
273
+ """Initialize PIL and Numpy Dataloader."""
274
+ if not isinstance(im0, list):
275
+ im0 = [im0]
276
+ self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)]
277
+ self.im0 = [self._single_check(im) for im in im0]
278
+ self.imgsz = imgsz
279
+ self.mode = 'image'
280
+ # Generate fake paths
281
+ self.bs = len(self.im0)
282
+
283
+ @staticmethod
284
+ def _single_check(im):
285
+ """Validate and format an image to numpy array."""
286
+ assert isinstance(im, (Image.Image, np.ndarray)), f'Expected PIL/np.ndarray image type, but got {type(im)}'
287
+ if isinstance(im, Image.Image):
288
+ if im.mode != 'RGB':
289
+ im = im.convert('RGB')
290
+ im = np.asarray(im)[:, :, ::-1]
291
+ im = np.ascontiguousarray(im) # contiguous
292
+ return im
293
+
294
+ def __len__(self):
295
+ """Returns the length of the 'im0' attribute."""
296
+ return len(self.im0)
297
+
298
+ def __next__(self):
299
+ """Returns batch paths, images, processed images, None, ''."""
300
+ if self.count == 1: # loop only once as it's batch inference
301
+ raise StopIteration
302
+ self.count += 1
303
+ return self.paths, self.im0, None, ''
304
+
305
+ def __iter__(self):
306
+ """Enables iteration for class LoadPilAndNumpy."""
307
+ self.count = 0
308
+ return self
309
+
310
+
311
+ class LoadTensor:
312
+
313
+ def __init__(self, im0) -> None:
314
+ self.im0 = self._single_check(im0)
315
+ self.bs = self.im0.shape[0]
316
+ self.mode = 'image'
317
+ self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)]
318
+
319
+ @staticmethod
320
+ def _single_check(im, stride=32):
321
+ """Validate and format an image to torch.Tensor."""
322
+ s = f'WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) ' \
323
+ f'divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible.'
324
+ if len(im.shape) != 4:
325
+ if len(im.shape) != 3:
326
+ raise ValueError(s)
327
+ LOGGER.warning(s)
328
+ im = im.unsqueeze(0)
329
+ if im.shape[2] % stride or im.shape[3] % stride:
330
+ raise ValueError(s)
331
+ if im.max() > 1.0:
332
+ LOGGER.warning(f'WARNING ⚠️ torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. '
333
+ f'Dividing input by 255.')
334
+ im = im.float() / 255.0
335
+
336
+ return im
337
+
338
+ def __iter__(self):
339
+ """Returns an iterator object."""
340
+ self.count = 0
341
+ return self
342
+
343
+ def __next__(self):
344
+ """Return next item in the iterator."""
345
+ if self.count == 1:
346
+ raise StopIteration
347
+ self.count += 1
348
+ return self.paths, self.im0, None, ''
349
+
350
+ def __len__(self):
351
+ """Returns the batch size."""
352
+ return self.bs
353
+
354
+
355
+ def autocast_list(source):
356
+ """
357
+ Merges a list of sources of different types into a list of numpy arrays or PIL images.
358
+ """
359
+ files = []
360
+ for im in source:
361
+ if isinstance(im, (str, Path)): # filename or uri
362
+ files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im))
363
+ elif isinstance(im, (Image.Image, np.ndarray)): # PIL or np Image
364
+ files.append(im)
365
+ else:
366
+ raise TypeError(f'type {type(im).__name__} is not a supported Ultralytics prediction source type. \n'
367
+ f'See https://docs.ultralytics.com/modes/predict for supported source types.')
368
+
369
+ return files
370
+
371
+
372
+ LOADERS = [LoadStreams, LoadPilAndNumpy, LoadImages, LoadScreenshots]
373
+
374
+
375
+ def get_best_youtube_url(url, use_pafy=True):
376
+ """
377
+ Retrieves the URL of the best quality MP4 video stream from a given YouTube video.
378
+
379
+ This function uses the pafy or yt_dlp library to extract the video info from YouTube. It then finds the highest
380
+ quality MP4 format that has video codec but no audio codec, and returns the URL of this video stream.
381
+
382
+ Args:
383
+ url (str): The URL of the YouTube video.
384
+ use_pafy (bool): Use the pafy package, default=True, otherwise use yt_dlp package.
385
+
386
+ Returns:
387
+ (str): The URL of the best quality MP4 video stream, or None if no suitable stream is found.
388
+ """
389
+ if use_pafy:
390
+ check_requirements(('pafy', 'youtube_dl==2020.12.2'))
391
+ import pafy # noqa
392
+ return pafy.new(url).getbest(preftype='mp4').url
393
+ else:
394
+ check_requirements('yt-dlp')
395
+ import yt_dlp
396
+ with yt_dlp.YoutubeDL({'quiet': True}) as ydl:
397
+ info_dict = ydl.extract_info(url, download=False) # extract info
398
+ for f in info_dict.get('formats', None):
399
+ if f['vcodec'] != 'none' and f['acodec'] == 'none' and f['ext'] == 'mp4':
400
+ return f.get('url', None)
401
+
402
+
403
+ if __name__ == '__main__':
404
+ img = cv2.imread(str(ROOT / 'assets/bus.jpg'))
405
+ dataset = LoadPilAndNumpy(im0=img)
406
+ for d in dataset:
407
+ print(d[0])
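LoadTensor expects an already-batched BCHW float tensor with height and width divisible by the stride and values in 0-1 (larger values trigger a warning and a divide-by-255). A minimal sketch of that contract using a random tensor (illustrative values, not from the repo):

import torch
from ultralytics.data.loaders import LoadTensor

im = torch.rand(1, 3, 640, 640)        # BCHW, 0-1, H/W divisible by stride 32
loader = LoadTensor(im)
for paths, im0, vid_cap, s in loader:  # yields the whole batch exactly once
    print(paths, tuple(im0.shape))     # ['image0.jpg'] (1, 3, 640, 640)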
ultralytics/data/scripts/download_weights.sh ADDED
@@ -0,0 +1,18 @@
1
+ #!/bin/bash
2
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
3
+ # Download latest models from https://github.com/ultralytics/assets/releases
4
+ # Example usage: bash ultralytics/data/scripts/download_weights.sh
5
+ # parent
6
+ # └── weights
7
+ # ├── yolov8n.pt ← downloads here
8
+ # ├── yolov8s.pt
9
+ # └── ...
10
+
11
+ python - <<EOF
12
+ from ultralytics.utils.downloads import attempt_download_asset
13
+
14
+ assets = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '-cls', '-seg', '-pose')]
15
+ for x in assets:
16
+ attempt_download_asset(f'weights/{x}')
17
+
18
+ EOF
ultralytics/data/scripts/get_coco.sh ADDED
@@ -0,0 +1,60 @@
1
+ #!/bin/bash
2
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
3
+ # Download COCO 2017 dataset http://cocodataset.org
4
+ # Example usage: bash data/scripts/get_coco.sh
5
+ # parent
6
+ # ├── ultralytics
7
+ # └── datasets
8
+ # └── coco ← downloads here
9
+
10
+ # Arguments (optional) Usage: bash data/scripts/get_coco.sh --train --val --test --segments
11
+ if [ "$#" -gt 0 ]; then
12
+ for opt in "$@"; do
13
+ case "${opt}" in
14
+ --train) train=true ;;
15
+ --val) val=true ;;
16
+ --test) test=true ;;
17
+ --segments) segments=true ;;
18
+ --sama) sama=true ;;
19
+ esac
20
+ done
21
+ else
22
+ train=true
23
+ val=true
24
+ test=false
25
+ segments=false
26
+ sama=false
27
+ fi
28
+
29
+ # Download/unzip labels
30
+ d='../datasets' # unzip directory
31
+ url=https://github.com/ultralytics/yolov5/releases/download/v1.0/
32
+ if [ "$segments" == "true" ]; then
33
+ f='coco2017labels-segments.zip' # 169 MB
34
+ elif [ "$sama" == "true" ]; then
35
+ f='coco2017labels-segments-sama.zip' # 199 MB https://www.sama.com/sama-coco-dataset/
36
+ else
37
+ f='coco2017labels.zip' # 46 MB
38
+ fi
39
+ echo 'Downloading' $url$f ' ...'
40
+ curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
41
+
42
+ # Download/unzip images
43
+ d='../datasets/coco/images' # unzip directory
44
+ url=http://images.cocodataset.org/zips/
45
+ if [ "$train" == "true" ]; then
46
+ f='train2017.zip' # 19G, 118k images
47
+ echo 'Downloading' $url$f '...'
48
+ curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
49
+ fi
50
+ if [ "$val" == "true" ]; then
51
+ f='val2017.zip' # 1G, 5k images
52
+ echo 'Downloading' $url$f '...'
53
+ curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
54
+ fi
55
+ if [ "$test" == "true" ]; then
56
+ f='test2017.zip' # 7G, 41k images (optional)
57
+ echo 'Downloading' $url$f '...'
58
+ curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
59
+ fi
60
+ wait # finish background tasks
ultralytics/data/scripts/get_coco128.sh ADDED
@@ -0,0 +1,17 @@
1
+ #!/bin/bash
2
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
3
+ # Download COCO128 dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017)
4
+ # Example usage: bash data/scripts/get_coco128.sh
5
+ # parent
6
+ # ├── ultralytics
7
+ # └── datasets
8
+ # └── coco128 ← downloads here
9
+
10
+ # Download/unzip images and labels
11
+ d='../datasets' # unzip directory
12
+ url=https://github.com/ultralytics/yolov5/releases/download/v1.0/
13
+ f='coco128.zip' # or 'coco128-segments.zip', 68 MB
14
+ echo 'Downloading' $url$f ' ...'
15
+ curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
16
+
17
+ wait # finish background tasks
ultralytics/data/scripts/get_imagenet.sh ADDED
@@ -0,0 +1,51 @@
1
+ #!/bin/bash
2
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
3
+ # Download ILSVRC2012 ImageNet dataset https://image-net.org
4
+ # Example usage: bash data/scripts/get_imagenet.sh
5
+ # parent
6
+ # ├── ultralytics
7
+ # └── datasets
8
+ # └── imagenet ← downloads here
9
+
10
+ # Arguments (optional) Usage: bash data/scripts/get_imagenet.sh --train --val
11
+ if [ "$#" -gt 0 ]; then
12
+ for opt in "$@"; do
13
+ case "${opt}" in
14
+ --train) train=true ;;
15
+ --val) val=true ;;
16
+ esac
17
+ done
18
+ else
19
+ train=true
20
+ val=true
21
+ fi
22
+
23
+ # Make dir
24
+ d='../datasets/imagenet' # unzip directory
25
+ mkdir -p $d && cd $d
26
+
27
+ # Download/unzip train
28
+ if [ "$train" == "true" ]; then
29
+ wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar # download 138G, 1281167 images
30
+ mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train
31
+ tar -xf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar
32
+ find . -name "*.tar" | while read NAME; do
33
+ mkdir -p "${NAME%.tar}"
34
+ tar -xf "${NAME}" -C "${NAME%.tar}"
35
+ rm -f "${NAME}"
36
+ done
37
+ cd ..
38
+ fi
39
+
40
+ # Download/unzip val
41
+ if [ "$val" == "true" ]; then
42
+ wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar # download 6.3G, 50000 images
43
+ mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xf ILSVRC2012_img_val.tar
44
+ wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash # move into subdirs
45
+ fi
46
+
47
+ # Delete corrupted image (optional: PNG under JPEG name that may cause dataloaders to fail)
48
+ # rm train/n04266014/n04266014_10835.JPEG
49
+
50
+ # TFRecords (optional)
51
+ # wget https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/imagenet_lsvrc_2015_synsets.txt
ultralytics/data/utils.py ADDED
@@ -0,0 +1,557 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import contextlib
4
+ import hashlib
5
+ import json
6
+ import os
7
+ import random
8
+ import subprocess
9
+ import time
10
+ import zipfile
11
+ from multiprocessing.pool import ThreadPool
12
+ from pathlib import Path
13
+ from tarfile import is_tarfile
14
+
15
+ import cv2
16
+ import numpy as np
17
+ from PIL import ExifTags, Image, ImageOps
18
+ from tqdm import tqdm
19
+
20
+ from ultralytics.nn.autobackend import check_class_names
21
+ from ultralytics.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, SETTINGS_YAML, clean_url, colorstr, emojis,
22
+ yaml_load)
23
+ from ultralytics.utils.checks import check_file, check_font, is_ascii
24
+ from ultralytics.utils.downloads import download, safe_download, unzip_file
25
+ from ultralytics.utils.ops import segments2boxes
26
+
27
+ HELP_URL = 'See https://docs.ultralytics.com/yolov5/tutorials/train_custom_data'
28
+ IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # image suffixes
29
+ VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv', 'webm' # video suffixes
30
+ PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders
31
+ IMAGENET_MEAN = 0.485, 0.456, 0.406 # RGB mean
32
+ IMAGENET_STD = 0.229, 0.224, 0.225 # RGB standard deviation
33
+
34
+ # Get orientation exif tag
35
+ for orientation in ExifTags.TAGS.keys():
36
+ if ExifTags.TAGS[orientation] == 'Orientation':
37
+ break
38
+
39
+
40
+ def img2label_paths(img_paths):
41
+ """Define label paths as a function of image paths."""
42
+ sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}' # /images/, /labels/ substrings
43
+ return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
44
+
45
+
46
+ def get_hash(paths):
47
+ """Returns a single hash value of a list of paths (files or dirs)."""
48
+ size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
49
+ h = hashlib.sha256(str(size).encode()) # hash sizes
50
+ h.update(''.join(paths).encode()) # hash paths
51
+ return h.hexdigest() # return hash
52
+
53
+
54
+ def exif_size(img):
55
+ """Returns exif-corrected PIL size."""
56
+ s = img.size # (width, height)
57
+ with contextlib.suppress(Exception):
58
+ rotation = dict(img._getexif().items())[orientation]
59
+ if rotation in [6, 8]: # rotation 270 or 90
60
+ s = (s[1], s[0])
61
+ return s
62
+
63
+
64
+ def verify_image_label(args):
65
+ """Verify one image-label pair."""
66
+ im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
67
+ # Number (missing, found, empty, corrupt), message, segments, keypoints
68
+ nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, '', [], None
69
+ try:
70
+ # Verify images
71
+ im = Image.open(im_file)
72
+ im.verify() # PIL verify
73
+ shape = exif_size(im) # image size
74
+ shape = (shape[1], shape[0]) # hw
75
+ assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
76
+ assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
77
+ if im.format.lower() in ('jpg', 'jpeg'):
78
+ with open(im_file, 'rb') as f:
79
+ f.seek(-2, 2)
80
+ if f.read() != b'\xff\xd9': # corrupt JPEG
81
+ ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
82
+ msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
83
+
84
+ # Verify labels
85
+ if os.path.isfile(lb_file):
86
+ nf = 1 # label found
87
+ with open(lb_file) as f:
88
+ lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
89
+ if any(len(x) > 6 for x in lb) and (not keypoint): # is segment
90
+ classes = np.array([x[0] for x in lb], dtype=np.float32)
91
+ segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...)
92
+ lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
93
+ lb = np.array(lb, dtype=np.float32)
94
+ nl = len(lb)
95
+ if nl:
96
+ if keypoint:
97
+ assert lb.shape[1] == (5 + nkpt * ndim), f'labels require {(5 + nkpt * ndim)} columns each'
98
+ assert (lb[:, 5::ndim] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
99
+ assert (lb[:, 6::ndim] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
100
+ else:
101
+ assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
102
+ assert (lb[:, 1:] <= 1).all(), \
103
+ f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
104
+ assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
105
+ # All labels
106
+ max_cls = int(lb[:, 0].max()) # max class index
107
+ assert max_cls < num_cls, \
108
+ f'Label class {max_cls} exceeds dataset class count {num_cls}. ' \
109
+ f'Possible class labels are 0-{num_cls - 1}'
110
+ _, i = np.unique(lb, axis=0, return_index=True)
111
+ if len(i) < nl: # duplicate row check
112
+ lb = lb[i] # remove duplicates
113
+ if segments:
114
+ segments = [segments[x] for x in i]
115
+ msg = f'{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed'
116
+ else:
117
+ ne = 1 # label empty
118
+ lb = np.zeros((0, (5 + nkpt * ndim)), dtype=np.float32) if keypoint else np.zeros(
119
+ (0, 5), dtype=np.float32)
120
+ else:
121
+ nm = 1 # label missing
122
+ lb = np.zeros((0, (5 + nkpt * ndim)), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
123
+ if keypoint:
124
+ keypoints = lb[:, 5:].reshape(-1, nkpt, ndim)
125
+ if ndim == 2:
126
+ kpt_mask = np.ones(keypoints.shape[:2], dtype=np.float32)
127
+ kpt_mask = np.where(keypoints[..., 0] < 0, 0.0, kpt_mask)
128
+ kpt_mask = np.where(keypoints[..., 1] < 0, 0.0, kpt_mask)
129
+ keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1) # (nl, nkpt, 3)
130
+ lb = lb[:, :5]
131
+ return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
132
+ except Exception as e:
133
+ nc = 1
134
+ msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
135
+ return [None, None, None, None, None, nm, nf, ne, nc, msg]
136
+
137
+
138
+ def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
139
+ """
140
+ Args:
141
+ imgsz (tuple): The image size.
142
+ polygons (list[np.ndarray]): [N, M], N is the number of polygons, M is the number of points(Be divided by 2).
143
+ color (int): color
144
+ downsample_ratio (int): downsample ratio
145
+ """
146
+ mask = np.zeros(imgsz, dtype=np.uint8)
147
+ polygons = np.asarray(polygons)
148
+ polygons = polygons.astype(np.int32)
149
+ shape = polygons.shape
150
+ polygons = polygons.reshape(shape[0], -1, 2)
151
+ cv2.fillPoly(mask, polygons, color=color)
152
+ nh, nw = (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio)
153
+ # NOTE: fillPoly first, then resize, to keep the mask consistent with the
154
+ # loss calculation when mask_ratio=1.
155
+ mask = cv2.resize(mask, (nw, nh))
156
+ return mask
157
+
158
+
159
+ def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
160
+ """
161
+ Args:
162
+ imgsz (tuple): The image size.
163
+ polygons (list[np.ndarray]): each polygon is [N, M], N is number of polygons, M is number of points (M % 2 = 0)
164
+ color (int): color
165
+ downsample_ratio (int): downsample ratio
166
+ """
167
+ masks = []
168
+ for si in range(len(polygons)):
169
+ mask = polygon2mask(imgsz, [polygons[si].reshape(-1)], color, downsample_ratio)
170
+ masks.append(mask)
171
+ return np.array(masks)
172
+
173
+
174
+ def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
175
+ """Return a (640, 640) overlap mask."""
176
+ masks = np.zeros((imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
177
+ dtype=np.int32 if len(segments) > 255 else np.uint8)
178
+ areas = []
179
+ ms = []
180
+ for si in range(len(segments)):
181
+ mask = polygon2mask(imgsz, [segments[si].reshape(-1)], downsample_ratio=downsample_ratio, color=1)
182
+ ms.append(mask)
183
+ areas.append(mask.sum())
184
+ areas = np.asarray(areas)
185
+ index = np.argsort(-areas)
186
+ ms = np.array(ms)[index]
187
+ for i in range(len(segments)):
188
+ mask = ms[i] * (i + 1)
189
+ masks = masks + mask
190
+ masks = np.clip(masks, a_min=0, a_max=i + 1)
191
+ return masks, index
192
+
193
+
194
+ def check_det_dataset(dataset, autodownload=True):
195
+ """Download, check and/or unzip dataset if not found locally."""
196
+ data = check_file(dataset)
197
+
198
+ # Download (optional)
199
+ extract_dir = ''
200
+ if isinstance(data, (str, Path)) and (zipfile.is_zipfile(data) or is_tarfile(data)):
201
+ new_dir = safe_download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False)
202
+ data = next((DATASETS_DIR / new_dir).rglob('*.yaml'))
203
+ extract_dir, autodownload = data.parent, False
204
+
205
+ # Read yaml (optional)
206
+ if isinstance(data, (str, Path)):
207
+ data = yaml_load(data, append_filename=True) # dictionary
208
+
209
+ # Checks
210
+ for k in 'train', 'val':
211
+ if k not in data:
212
+ raise SyntaxError(
213
+ emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs."))
214
+ if 'names' not in data and 'nc' not in data:
215
+ raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs."))
216
+ if 'names' in data and 'nc' in data and len(data['names']) != data['nc']:
217
+ raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match."))
218
+ if 'names' not in data:
219
+ data['names'] = [f'class_{i}' for i in range(data['nc'])]
220
+ else:
221
+ data['nc'] = len(data['names'])
222
+
223
+ data['names'] = check_class_names(data['names'])
224
+
225
+ # Resolve paths
226
+ path = Path(extract_dir or data.get('path') or Path(data.get('yaml_file', '')).parent) # dataset root
227
+
228
+ if not path.is_absolute():
229
+ path = (DATASETS_DIR / path).resolve()
230
+ data['path'] = path # download scripts
231
+ for k in 'train', 'val', 'test':
232
+ if data.get(k): # prepend path
233
+ if isinstance(data[k], str):
234
+ x = (path / data[k]).resolve()
235
+ if not x.exists() and data[k].startswith('../'):
236
+ x = (path / data[k][3:]).resolve()
237
+ data[k] = str(x)
238
+ else:
239
+ data[k] = [str((path / x).resolve()) for x in data[k]]
240
+
241
+ # Parse yaml
242
+ train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
243
+ if val:
244
+ val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
245
+ if not all(x.exists() for x in val):
246
+ name = clean_url(dataset) # dataset name with URL auth stripped
247
+ m = f"\nDataset '{name}' images not found ⚠️, missing path '{[x for x in val if not x.exists()][0]}'"
248
+ if s and autodownload:
249
+ LOGGER.warning(m)
250
+ else:
251
+ m += f"\nNote dataset download directory is '{DATASETS_DIR}'. You can update this in '{SETTINGS_YAML}'"
252
+ raise FileNotFoundError(m)
253
+ t = time.time()
254
+ if s.startswith('http') and s.endswith('.zip'): # URL
255
+ safe_download(url=s, dir=DATASETS_DIR, delete=True)
256
+ r = None # success
257
+ elif s.startswith('bash '): # bash script
258
+ LOGGER.info(f'Running {s} ...')
259
+ r = os.system(s)
260
+ else: # python script
261
+ r = exec(s, {'yaml': data}) # return None
262
+ dt = f'({round(time.time() - t, 1)}s)'
263
+ s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f'failure {dt} ❌'
264
+ LOGGER.info(f'Dataset download {s}\n')
265
+ check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf') # download fonts
266
+
267
+ return data # dictionary
268
+
269
+
270
+ def check_cls_dataset(dataset: str, split=''):
271
+ """
272
+ Checks a classification dataset such as Imagenet.
273
+
274
+ This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information.
275
+ If the dataset is not found locally, it attempts to download the dataset from the internet and save it locally.
276
+
277
+ Args:
278
+ dataset (str): The name of the dataset.
279
+ split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''.
280
+
281
+ Returns:
282
+ (dict): A dictionary containing the following keys:
283
+ - 'train' (Path): The directory path containing the training set of the dataset.
284
+ - 'val' (Path): The directory path containing the validation set of the dataset.
285
+ - 'test' (Path): The directory path containing the test set of the dataset.
286
+ - 'nc' (int): The number of classes in the dataset.
287
+ - 'names' (dict): A dictionary of class names in the dataset.
288
+
289
+ Raises:
290
+ FileNotFoundError: If the specified dataset is not found and cannot be downloaded.
291
+ """
292
+
293
+ dataset = Path(dataset)
294
+ data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
295
+ if not data_dir.is_dir():
296
+ LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
297
+ t = time.time()
298
+ if str(dataset) == 'imagenet':
299
+ subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
300
+ else:
301
+ url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
302
+ download(url, dir=data_dir.parent)
303
+ s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
304
+ LOGGER.info(s)
305
+ train_set = data_dir / 'train'
306
+ val_set = data_dir / 'val' if (data_dir / 'val').exists() else None # data/test or data/val
307
+ test_set = data_dir / 'test' if (data_dir / 'test').exists() else None # data/val or data/test
308
+ if split == 'val' and not val_set:
309
+ LOGGER.info("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
310
+ elif split == 'test' and not test_set:
311
+ LOGGER.info("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")
312
+
313
+ nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
314
+ names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list
315
+ names = dict(enumerate(sorted(names)))
316
+ return {'train': train_set, 'val': val_set or test_set, 'test': test_set or val_set, 'nc': nc, 'names': names}
317
+
318
+
319
+ class HUBDatasetStats():
320
+ """
321
+ A class for generating HUB dataset JSON and `-hub` dataset directory.
322
+
323
+ Args:
324
+ path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip). Default is 'coco128.yaml'.
325
+ task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Default is 'detect'.
326
+ autodownload (bool): Attempt to download dataset if not found locally. Default is False.
327
+
328
+ Usage
329
+ from ultralytics.data.utils import HUBDatasetStats
330
+ stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8.zip', task='detect') # detect dataset
331
+ stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8-seg.zip', task='segment') # segment dataset
332
+ stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8-pose.zip', task='pose') # pose dataset
333
+ stats.get_json(save=False)
334
+ stats.process_images()
335
+ """
336
+
337
+ def __init__(self, path='coco128.yaml', task='detect', autodownload=False):
338
+ """Initialize class."""
339
+ LOGGER.info(f'Starting HUB dataset checks for {path}....')
340
+ zipped, data_dir, yaml_path = self._unzip(Path(path))
341
+ try:
342
+ # data = yaml_load(check_yaml(yaml_path)) # data dict
343
+ data = check_det_dataset(yaml_path, autodownload) # data dict
344
+ if zipped:
345
+ data['path'] = data_dir
346
+ except Exception as e:
347
+ raise Exception('error/HUB/dataset_stats/yaml_load') from e
348
+
349
+ self.hub_dir = Path(str(data['path']) + '-hub')
350
+ self.im_dir = self.hub_dir / 'images'
351
+ self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images
352
+ self.stats = {'nc': len(data['names']), 'names': list(data['names'].values())} # statistics dictionary
353
+ self.data = data
354
+ self.task = task # detect, segment, pose, classify
355
+
356
+ @staticmethod
357
+ def _find_yaml(dir):
358
+ """Return data.yaml file."""
359
+ files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml')) # try root level first and then recursive
360
+ assert files, f'No *.yaml file found in {dir}'
361
+ if len(files) > 1:
362
+ files = [f for f in files if f.stem == dir.stem] # prefer *.yaml files that match dir name
363
+ assert files, f'Multiple *.yaml files found in {dir}, only 1 *.yaml file allowed'
364
+ assert len(files) == 1, f'Multiple *.yaml files found: {files}, only 1 *.yaml file allowed in {dir}'
365
+ return files[0]
366
+
367
+ def _unzip(self, path):
368
+ """Unzip data.zip."""
369
+ if not str(path).endswith('.zip'): # path is data.yaml
370
+ return False, None, path
371
+ unzip_dir = unzip_file(path, path=path.parent)
372
+ assert unzip_dir.is_dir(), f'Error unzipping {path}, {unzip_dir} not found. ' \
373
+ f'path/to/abc.zip MUST unzip to path/to/abc/'
374
+ return True, str(unzip_dir), self._find_yaml(unzip_dir) # zipped, data_dir, yaml_path
375
+
376
+ def _hub_ops(self, f):
377
+ """Saves a compressed image for HUB previews."""
378
+ compress_one_image(f, self.im_dir / Path(f).name) # save to dataset-hub
379
+
380
+ def get_json(self, save=False, verbose=False):
381
+ """Return dataset JSON for Ultralytics HUB."""
382
+ from ultralytics.data import YOLODataset # ClassificationDataset
383
+
384
+ def _round(labels):
385
+ """Update labels to integer class and 4 decimal place floats."""
386
+ if self.task == 'detect':
387
+ coordinates = labels['bboxes']
388
+ elif self.task == 'segment':
389
+ coordinates = [x.flatten() for x in labels['segments']]
390
+ elif self.task == 'pose':
391
+ n = labels['keypoints'].shape[0]
392
+ coordinates = np.concatenate((labels['bboxes'], labels['keypoints'].reshape(n, -1)), 1)
393
+ else:
394
+ raise ValueError('Undefined dataset task.')
395
+ zipped = zip(labels['cls'], coordinates)
396
+ return [[int(c), *(round(float(x), 4) for x in points)] for c, points in zipped]
397
+
398
+ for split in 'train', 'val', 'test':
399
+ if self.data.get(split) is None:
400
+ self.stats[split] = None # i.e. no test set
401
+ continue
402
+
403
+ dataset = YOLODataset(img_path=self.data[split],
404
+ data=self.data,
405
+ use_segments=self.task == 'segment',
406
+ use_keypoints=self.task == 'pose')
407
+ x = np.array([
408
+ np.bincount(label['cls'].astype(int).flatten(), minlength=self.data['nc'])
409
+ for label in tqdm(dataset.labels, total=len(dataset), desc='Statistics')]) # shape(128x80)
410
+ self.stats[split] = {
411
+ 'instance_stats': {
412
+ 'total': int(x.sum()),
413
+ 'per_class': x.sum(0).tolist()},
414
+ 'image_stats': {
415
+ 'total': len(dataset),
416
+ 'unlabelled': int(np.all(x == 0, 1).sum()),
417
+ 'per_class': (x > 0).sum(0).tolist()},
418
+ 'labels': [{
419
+ Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)]}
420
+
421
+ # Save, print and return
422
+ if save:
423
+ stats_path = self.hub_dir / 'stats.json'
424
+ LOGGER.info(f'Saving {stats_path.resolve()}...')
425
+ with open(stats_path, 'w') as f:
426
+ json.dump(self.stats, f) # save stats.json
427
+ if verbose:
428
+ LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
429
+ return self.stats
430
+
431
+ def process_images(self):
432
+ """Compress images for Ultralytics HUB."""
433
+ from ultralytics.data import YOLODataset # ClassificationDataset
434
+
435
+ for split in 'train', 'val', 'test':
436
+ if self.data.get(split) is None:
437
+ continue
438
+ dataset = YOLODataset(img_path=self.data[split], data=self.data)
439
+ with ThreadPool(NUM_THREADS) as pool:
440
+ for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f'{split} images'):
441
+ pass
442
+ LOGGER.info(f'Done. All images saved to {self.im_dir}')
443
+ return self.im_dir
444
+
445
+
446
+ def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
447
+ """
448
+ Compresses a single image file to a reduced size while preserving its aspect ratio and quality using either the
449
+ Python Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it will
450
+ not be resized.
451
+
452
+ Args:
453
+ f (str): The path to the input image file.
454
+ f_new (str, optional): The path to the output image file. If not specified, the input file will be overwritten.
455
+ max_dim (int, optional): The maximum dimension (width or height) of the output image. Default is 1920 pixels.
456
+ quality (int, optional): The image compression quality as a percentage. Default is 50%.
457
+
458
+ Usage:
459
+ from pathlib import Path
460
+ from ultralytics.data.utils import compress_one_image
461
+ for f in Path('/Users/glennjocher/Downloads/dataset').rglob('*.jpg'):
462
+ compress_one_image(f)
463
+ """
464
+ try: # use PIL
465
+ im = Image.open(f)
466
+ r = max_dim / max(im.height, im.width) # ratio
467
+ if r < 1.0: # image too large
468
+ im = im.resize((int(im.width * r), int(im.height * r)))
469
+ im.save(f_new or f, 'JPEG', quality=quality, optimize=True) # save
470
+ except Exception as e: # use OpenCV
471
+ LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}')
472
+ im = cv2.imread(f)
473
+ im_height, im_width = im.shape[:2]
474
+ r = max_dim / max(im_height, im_width) # ratio
475
+ if r < 1.0: # image too large
476
+ im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
477
+ cv2.imwrite(str(f_new or f), im)
478
+
479
+
480
+ def delete_dsstore(path):
481
+ """
482
+ Deletes all ".DS_store" files under a specified directory.
483
+
484
+ Args:
485
+ path (str): The directory path under which the ".DS_store" files should be deleted.
486
+
487
+ Usage:
488
+ from ultralytics.data.utils import delete_dsstore
489
+ delete_dsstore('/Users/glennjocher/Downloads/dataset')
490
+
491
+ Note:
492
+ ".DS_store" files are created by the Apple operating system and contain metadata about folders and files. They
493
+ are hidden system files and can cause issues when transferring files between different operating systems.
494
+ """
495
+ # Delete Apple .DS_store files
496
+ files = list(Path(path).rglob('.DS_store'))
497
+ LOGGER.info(f'Deleting *.DS_store files: {files}')
498
+ for f in files:
499
+ f.unlink()
500
+
501
+
502
+ def zip_directory(dir, use_zipfile_library=True):
503
+ """
504
+ Zips a directory and saves the archive to the specified output path.
505
+
506
+ Args:
507
+ dir (str): The path to the directory to be zipped.
508
+ use_zipfile_library (bool): Whether to use zipfile library or shutil for zipping.
509
+
510
+ Usage:
511
+ from ultralytics.data.utils import zip_directory
512
+ zip_directory('/Users/glennjocher/Downloads/playground')
513
+
514
+ zip -r coco8-pose.zip coco8-pose
515
+ """
516
+ delete_dsstore(dir)
517
+ if use_zipfile_library:
518
+ dir = Path(dir)
519
+ with zipfile.ZipFile(dir.with_suffix('.zip'), 'w', zipfile.ZIP_DEFLATED) as zip_file:
520
+ for file_path in dir.glob('**/*'):
521
+ if file_path.is_file():
522
+ zip_file.write(file_path, file_path.relative_to(dir))
523
+ else:
524
+ import shutil
525
+ shutil.make_archive(dir, 'zip', dir)
526
+
527
+
528
+ def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
529
+ """
530
+ Autosplit a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.
531
+
532
+ Args:
533
+ path (Path, optional): Path to images directory. Defaults to DATASETS_DIR / 'coco128/images'.
534
+ weights (list | tuple, optional): Train, validation, and test split fractions. Defaults to (0.9, 0.1, 0.0).
535
+ annotated_only (bool, optional): If True, only images with an associated txt file are used. Defaults to False.
536
+
537
+ Usage:
538
+ from ultralytics.data.utils import autosplit
539
+ autosplit()
540
+ """
541
+
542
+ path = Path(path) # images dir
543
+ files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS) # image files only
544
+ n = len(files) # number of files
545
+ random.seed(0) # for reproducibility
546
+ indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
547
+
548
+ txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
549
+ for x in txt:
550
+ if (path.parent / x).exists():
551
+ (path.parent / x).unlink() # remove existing
552
+
553
+ LOGGER.info(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
554
+ for i, img in tqdm(zip(indices, files), total=n):
555
+ if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
556
+ with open(path.parent / txt[i], 'a') as f:
557
+ f.write(f'./{img.relative_to(path.parent).as_posix()}' + '\n') # add image to txt file
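Note: the helpers above (compress_one_image, autosplit, zip_directory) compose into a simple dataset-preparation pass. A minimal sketch, assuming a local dataset rooted at a hypothetical /path/to/dataset with an images/ subfolder (the path is a placeholder, not one from this repository):

    # Sketch only: chain the utilities defined in ultralytics/data/utils.py above.
    from pathlib import Path
    from ultralytics.data.utils import autosplit, compress_one_image, zip_directory

    root = Path('/path/to/dataset')  # hypothetical dataset root
    for f in (root / 'images').rglob('*.jpg'):
        compress_one_image(f)  # shrink sides larger than 1920 px, re-save as JPEG quality 50
    autosplit(root / 'images', weights=(0.9, 0.1, 0.0))  # writes autosplit_*.txt next to images/
    zip_directory(root)  # strips .DS_store files, then writes <root>.zip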
ultralytics/engine/__init__.py ADDED
File without changes
ultralytics/engine/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (175 Bytes).
ultralytics/engine/__pycache__/exporter.cpython-39.pyc ADDED
Binary file (34.4 kB).
ultralytics/engine/__pycache__/model.cpython-39.pyc ADDED
Binary file (17.8 kB).
ultralytics/engine/__pycache__/predictor.cpython-39.pyc ADDED
Binary file (14.1 kB).
ultralytics/engine/__pycache__/results.cpython-39.pyc ADDED
Binary file (24.8 kB).
ultralytics/engine/__pycache__/trainer.cpython-39.pyc ADDED
Binary file (24 kB).
ultralytics/engine/__pycache__/validator.cpython-39.pyc ADDED
Binary file (11.1 kB).
ultralytics/engine/exporter.py ADDED
@@ -0,0 +1,969 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ """
3
+ Export a YOLOv8 PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit
4
+
5
+ Format | `format=argument` | Model
6
+ --- | --- | ---
7
+ PyTorch | - | yolov8n.pt
8
+ TorchScript | `torchscript` | yolov8n.torchscript
9
+ ONNX | `onnx` | yolov8n.onnx
10
+ OpenVINO | `openvino` | yolov8n_openvino_model/
11
+ TensorRT | `engine` | yolov8n.engine
12
+ CoreML | `coreml` | yolov8n.mlmodel
13
+ TensorFlow SavedModel | `saved_model` | yolov8n_saved_model/
14
+ TensorFlow GraphDef | `pb` | yolov8n.pb
15
+ TensorFlow Lite | `tflite` | yolov8n.tflite
16
+ TensorFlow Edge TPU | `edgetpu` | yolov8n_edgetpu.tflite
17
+ TensorFlow.js | `tfjs` | yolov8n_web_model/
18
+ PaddlePaddle | `paddle` | yolov8n_paddle_model/
19
+ ncnn | `ncnn` | yolov8n_ncnn_model/
20
+
21
+ Requirements:
22
+ $ pip install "ultralytics[export]"
23
+
24
+ Python:
25
+ from ultralytics import YOLO
26
+ model = YOLO('yolov8n.pt')
27
+ results = model.export(format='onnx')
28
+
29
+ CLI:
30
+ $ yolo mode=export model=yolov8n.pt format=onnx
31
+
32
+ Inference:
33
+ $ yolo predict model=yolov8n.pt # PyTorch
34
+ yolov8n.torchscript # TorchScript
35
+ yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
36
+ yolov8n_openvino_model # OpenVINO
37
+ yolov8n.engine # TensorRT
38
+ yolov8n.mlmodel # CoreML (macOS-only)
39
+ yolov8n_saved_model # TensorFlow SavedModel
40
+ yolov8n.pb # TensorFlow GraphDef
41
+ yolov8n.tflite # TensorFlow Lite
42
+ yolov8n_edgetpu.tflite # TensorFlow Edge TPU
43
+ yolov8n_paddle_model # PaddlePaddle
44
+
45
+ TensorFlow.js:
46
+ $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
47
+ $ npm install
48
+ $ ln -s ../../yolov5/yolov8n_web_model public/yolov8n_web_model
49
+ $ npm start
50
+ """
51
+ import json
52
+ import os
53
+ import shutil
54
+ import subprocess
55
+ import time
56
+ import warnings
57
+ from copy import deepcopy
58
+ from datetime import datetime
59
+ from pathlib import Path
60
+
61
+ import torch
62
+
63
+ from ultralytics.cfg import get_cfg
64
+ from ultralytics.nn.autobackend import check_class_names
65
+ from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
66
+ from ultralytics.nn.tasks import DetectionModel, SegmentationModel
67
+ from ultralytics.utils import (ARM64, DEFAULT_CFG, LINUX, LOGGER, MACOS, ROOT, WINDOWS, __version__, callbacks,
68
+ colorstr, get_default_args, yaml_save)
69
+ from ultralytics.utils.checks import check_imgsz, check_requirements, check_version
70
+ from ultralytics.utils.downloads import attempt_download_asset, get_github_assets
71
+ from ultralytics.utils.files import file_size, spaces_in_path
72
+ from ultralytics.utils.ops import Profile
73
+ from ultralytics.utils.torch_utils import get_latest_opset, select_device, smart_inference_mode
74
+
75
+
76
+ def export_formats():
77
+ """YOLOv8 export formats."""
78
+ import pandas
79
+ x = [
80
+ ['PyTorch', '-', '.pt', True, True],
81
+ ['TorchScript', 'torchscript', '.torchscript', True, True],
82
+ ['ONNX', 'onnx', '.onnx', True, True],
83
+ ['OpenVINO', 'openvino', '_openvino_model', True, False],
84
+ ['TensorRT', 'engine', '.engine', False, True],
85
+ ['CoreML', 'coreml', '.mlmodel', True, False],
86
+ ['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True],
87
+ ['TensorFlow GraphDef', 'pb', '.pb', True, True],
88
+ ['TensorFlow Lite', 'tflite', '.tflite', True, False],
89
+ ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', True, False],
90
+ ['TensorFlow.js', 'tfjs', '_web_model', True, False],
91
+ ['PaddlePaddle', 'paddle', '_paddle_model', True, True],
92
+ ['ncnn', 'ncnn', '_ncnn_model', True, True], ]
93
+ return pandas.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
94
+
95
+
96
+ def gd_outputs(gd):
97
+ """TensorFlow GraphDef model output node names."""
98
+ name_list, input_list = [], []
99
+ for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef
100
+ name_list.append(node.name)
101
+ input_list.extend(node.input)
102
+ return sorted(f'{x}:0' for x in list(set(name_list) - set(input_list)) if not x.startswith('NoOp'))
103
+
104
+
105
+ def try_export(inner_func):
106
+ """YOLOv8 export decorator, i..e @try_export."""
107
+ inner_args = get_default_args(inner_func)
108
+
109
+ def outer_func(*args, **kwargs):
110
+ """Export a model."""
111
+ prefix = inner_args['prefix']
112
+ try:
113
+ with Profile() as dt:
114
+ f, model = inner_func(*args, **kwargs)
115
+ LOGGER.info(f"{prefix} export success ✅ {dt.t:.1f}s, saved as '{f}' ({file_size(f):.1f} MB)")
116
+ return f, model
117
+ except Exception as e:
118
+ LOGGER.info(f'{prefix} export failure ❌ {dt.t:.1f}s: {e}')
119
+ raise e
120
+
121
+ return outer_func
122
+
123
+
124
+ class Exporter:
125
+ """
126
+ A class for exporting a model.
127
+
128
+ Attributes:
129
+ args (SimpleNamespace): Configuration for the exporter.
130
+ save_dir (Path): Directory to save results.
131
+ """
132
+
133
+ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
134
+ """
135
+ Initializes the Exporter class.
136
+
137
+ Args:
138
+ cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
139
+ overrides (dict, optional): Configuration overrides. Defaults to None.
140
+ _callbacks (list, optional): List of callback functions. Defaults to None.
141
+ """
142
+ self.args = get_cfg(cfg, overrides)
143
+ self.callbacks = _callbacks or callbacks.get_default_callbacks()
144
+ callbacks.add_integration_callbacks(self)
145
+
146
+ @smart_inference_mode()
147
+ def __call__(self, model=None):
148
+ """Returns list of exported files/dirs after running callbacks."""
149
+ self.run_callbacks('on_export_start')
150
+ t = time.time()
151
+ format = self.args.format.lower() # to lowercase
152
+ if format in ('tensorrt', 'trt'): # engine aliases
153
+ format = 'engine'
154
+ fmts = tuple(export_formats()['Argument'][1:]) # available export formats
155
+ flags = [x == format for x in fmts]
156
+ if sum(flags) != 1:
157
+ raise ValueError(f"Invalid export format='{format}'. Valid formats are {fmts}")
158
+ jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn = flags # export booleans
159
+
160
+ # Load PyTorch model
161
+ self.device = select_device('cpu' if self.args.device is None else self.args.device)
162
+
163
+ # Checks
164
+ model.names = check_class_names(model.names)
165
+ if self.args.half and onnx and self.device.type == 'cpu':
166
+ LOGGER.warning('WARNING ⚠️ half=True only compatible with GPU export, i.e. use device=0')
167
+ self.args.half = False
168
+ assert not self.args.dynamic, 'half=True not compatible with dynamic=True, i.e. use only one.'
169
+ self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size
170
+ if self.args.optimize:
171
+ assert not ncnn, "optimize=True not compatible with format='ncnn', i.e. use optimize=False"
172
+ assert self.device.type == 'cpu', "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
173
+ if edgetpu and not LINUX:
174
+ raise SystemError('Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/')
175
+
176
+ # Input
177
+ im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
178
+ file = Path(
179
+ getattr(model, 'pt_path', None) or getattr(model, 'yaml_file', None) or model.yaml.get('yaml_file', ''))
180
+ if file.suffix in ('.yaml', '.yml'):
181
+ file = Path(file.name)
182
+
183
+ # Update model
184
+ model = deepcopy(model).to(self.device)
185
+ for p in model.parameters():
186
+ p.requires_grad = False
187
+ model.eval()
188
+ model.float()
189
+ model = model.fuse()
190
+ for k, m in model.named_modules():
191
+ if isinstance(m, (Detect, RTDETRDecoder)): # Segment and Pose use Detect base class
192
+ m.dynamic = self.args.dynamic
193
+ m.export = True
194
+ m.format = self.args.format
195
+ elif isinstance(m, C2f) and not any((saved_model, pb, tflite, edgetpu, tfjs)):
196
+ # EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
197
+ m.forward = m.forward_split
198
+
199
+ y = None
200
+ for _ in range(2):
201
+ y = model(im) # dry runs
202
+ if self.args.half and (engine or onnx) and self.device.type != 'cpu':
203
+ im, model = im.half(), model.half() # to FP16
204
+
205
+ # Filter warnings
206
+ warnings.filterwarnings('ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
207
+ warnings.filterwarnings('ignore', category=UserWarning) # suppress shape prim::Constant missing ONNX warning
208
+ warnings.filterwarnings('ignore', category=DeprecationWarning) # suppress CoreML np.bool deprecation warning
209
+
210
+ # Assign
211
+ self.im = im
212
+ self.model = model
213
+ self.file = file
214
+ self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else \
215
+ tuple(tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y)
216
+ self.pretty_name = Path(self.model.yaml.get('yaml_file', self.file)).stem.replace('yolo', 'YOLO')
217
+ trained_on = f'trained on {Path(self.args.data).name}' if self.args.data else '(untrained)'
218
+ description = f'Ultralytics {self.pretty_name} model {trained_on}'
219
+ self.metadata = {
220
+ 'description': description,
221
+ 'author': 'Ultralytics',
222
+ 'license': 'AGPL-3.0 https://ultralytics.com/license',
223
+ 'date': datetime.now().isoformat(),
224
+ 'version': __version__,
225
+ 'stride': int(max(model.stride)),
226
+ 'task': model.task,
227
+ 'batch': self.args.batch,
228
+ 'imgsz': self.imgsz,
229
+ 'names': model.names} # model metadata
230
+ if model.task == 'pose':
231
+ self.metadata['kpt_shape'] = model.model[-1].kpt_shape
232
+
233
+ LOGGER.info(f"\n{colorstr('PyTorch:')} starting from '{file}' with input shape {tuple(im.shape)} BCHW and "
234
+ f'output shape(s) {self.output_shape} ({file_size(file):.1f} MB)')
235
+
236
+ # Exports
237
+ f = [''] * len(fmts) # exported filenames
238
+ if jit or ncnn: # TorchScript
239
+ f[0], _ = self.export_torchscript()
240
+ if engine: # TensorRT required before ONNX
241
+ f[1], _ = self.export_engine()
242
+ if onnx or xml: # OpenVINO requires ONNX
243
+ f[2], _ = self.export_onnx()
244
+ if xml: # OpenVINO
245
+ f[3], _ = self.export_openvino()
246
+ if coreml: # CoreML
247
+ f[4], _ = self.export_coreml()
248
+ if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
249
+ self.args.int8 |= edgetpu
250
+ f[5], s_model = self.export_saved_model()
251
+ if pb or tfjs: # pb prerequisite to tfjs
252
+ f[6], _ = self.export_pb(s_model)
253
+ if tflite:
254
+ f[7], _ = self.export_tflite(s_model, nms=False, agnostic_nms=self.args.agnostic_nms)
255
+ if edgetpu:
256
+ f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f'{self.file.stem}_full_integer_quant.tflite')
257
+ if tfjs:
258
+ f[9], _ = self.export_tfjs()
259
+ if paddle: # PaddlePaddle
260
+ f[10], _ = self.export_paddle()
261
+ if ncnn: # ncnn
262
+ f[11], _ = self.export_ncnn()
263
+
264
+ # Finish
265
+ f = [str(x) for x in f if x] # filter out '' and None
266
+ if any(f):
267
+ f = str(Path(f[-1]))
268
+ square = self.imgsz[0] == self.imgsz[1]
269
+ s = '' if square else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not " \
270
+ f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
271
+ imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(' ', '')
272
+ data = f'data={self.args.data}' if model.task == 'segment' and format == 'pb' else ''
273
+ LOGGER.info(
274
+ f'\nExport complete ({time.time() - t:.1f}s)'
275
+ f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
276
+ f'\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {data}'
277
+ f'\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={self.args.data} {s}'
278
+ f'\nVisualize: https://netron.app')
279
+
280
+ self.run_callbacks('on_export_end')
281
+ return f # return list of exported files/dirs
282
+
283
+ @try_export
284
+ def export_torchscript(self, prefix=colorstr('TorchScript:')):
285
+ """YOLOv8 TorchScript model export."""
286
+ LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
287
+ f = self.file.with_suffix('.torchscript')
288
+
289
+ ts = torch.jit.trace(self.model, self.im, strict=False)
290
+ extra_files = {'config.txt': json.dumps(self.metadata)} # torch._C.ExtraFilesMap()
291
+ if self.args.optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
292
+ LOGGER.info(f'{prefix} optimizing for mobile...')
293
+ from torch.utils.mobile_optimizer import optimize_for_mobile
294
+ optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
295
+ else:
296
+ ts.save(str(f), _extra_files=extra_files)
297
+ return f, None
298
+
299
+ @try_export
300
+ def export_onnx(self, prefix=colorstr('ONNX:')):
301
+ """YOLOv8 ONNX export."""
302
+ requirements = ['onnx>=1.12.0']
303
+ if self.args.simplify:
304
+ requirements += ['onnxsim>=0.4.17', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime']
305
+ check_requirements(requirements)
306
+ import onnx # noqa
307
+
308
+ opset_version = self.args.opset or get_latest_opset()
309
+ LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...')
310
+ f = str(self.file.with_suffix('.onnx'))
311
+
312
+ output_names = ['output0', 'output1'] if isinstance(self.model, SegmentationModel) else ['output0']
313
+ dynamic = self.args.dynamic
314
+ if dynamic:
315
+ dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}} # shape(1,3,640,640)
316
+ if isinstance(self.model, SegmentationModel):
317
+ dynamic['output0'] = {0: 'batch', 2: 'anchors'} # shape(1, 116, 8400)
318
+ dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'} # shape(1,32,160,160)
319
+ elif isinstance(self.model, DetectionModel):
320
+ dynamic['output0'] = {0: 'batch', 2: 'anchors'} # shape(1, 84, 8400)
321
+
322
+ torch.onnx.export(
323
+ self.model.cpu() if dynamic else self.model, # --dynamic only compatible with cpu
324
+ self.im.cpu() if dynamic else self.im,
325
+ f,
326
+ verbose=False,
327
+ opset_version=opset_version,
328
+ do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
329
+ input_names=['images'],
330
+ output_names=output_names,
331
+ dynamic_axes=dynamic or None)
332
+
333
+ # Checks
334
+ model_onnx = onnx.load(f) # load onnx model
335
+ # onnx.checker.check_model(model_onnx) # check onnx model
336
+
337
+ # Simplify
338
+ if self.args.simplify:
339
+ try:
340
+ import onnxsim
341
+
342
+ LOGGER.info(f'{prefix} simplifying with onnxsim {onnxsim.__version__}...')
343
+ # subprocess.run(f'onnxsim "{f}" "{f}"', shell=True)
344
+ model_onnx, check = onnxsim.simplify(model_onnx)
345
+ assert check, 'Simplified ONNX model could not be validated'
346
+ except Exception as e:
347
+ LOGGER.info(f'{prefix} simplifier failure: {e}')
348
+
349
+ # Metadata
350
+ for k, v in self.metadata.items():
351
+ meta = model_onnx.metadata_props.add()
352
+ meta.key, meta.value = k, str(v)
353
+
354
+ onnx.save(model_onnx, f)
355
+ return f, model_onnx
356
+
357
+ @try_export
358
+ def export_openvino(self, prefix=colorstr('OpenVINO:')):
359
+ """YOLOv8 OpenVINO export."""
360
+ check_requirements('openvino-dev>=2023.0') # requires openvino-dev: https://pypi.org/project/openvino-dev/
361
+ import openvino.runtime as ov # noqa
362
+ from openvino.tools import mo # noqa
363
+
364
+ LOGGER.info(f'\n{prefix} starting export with openvino {ov.__version__}...')
365
+ f = str(self.file).replace(self.file.suffix, f'_openvino_model{os.sep}')
366
+ f_onnx = self.file.with_suffix('.onnx')
367
+ f_ov = str(Path(f) / self.file.with_suffix('.xml').name)
368
+
369
+ ov_model = mo.convert_model(f_onnx,
370
+ model_name=self.pretty_name,
371
+ framework='onnx',
372
+ compress_to_fp16=self.args.half) # export
373
+
374
+ # Set RT info
375
+ ov_model.set_rt_info('YOLOv8', ['model_info', 'model_type'])
376
+ ov_model.set_rt_info(True, ['model_info', 'reverse_input_channels'])
377
+ ov_model.set_rt_info(114, ['model_info', 'pad_value'])
378
+ ov_model.set_rt_info([255.0], ['model_info', 'scale_values'])
379
+ ov_model.set_rt_info(self.args.iou, ['model_info', 'iou_threshold'])
380
+ ov_model.set_rt_info([v.replace(' ', '_') for k, v in sorted(self.model.names.items())],
381
+ ['model_info', 'labels'])
382
+ if self.model.task != 'classify':
383
+ ov_model.set_rt_info('fit_to_window_letterbox', ['model_info', 'resize_type'])
384
+
385
+ ov.serialize(ov_model, f_ov) # save
386
+ yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
387
+ return f, None
388
+
389
+ @try_export
390
+ def export_paddle(self, prefix=colorstr('PaddlePaddle:')):
391
+ """YOLOv8 Paddle export."""
392
+ check_requirements(('paddlepaddle', 'x2paddle'))
393
+ import x2paddle # noqa
394
+ from x2paddle.convert import pytorch2paddle # noqa
395
+
396
+ LOGGER.info(f'\n{prefix} starting export with X2Paddle {x2paddle.__version__}...')
397
+ f = str(self.file).replace(self.file.suffix, f'_paddle_model{os.sep}')
398
+
399
+ pytorch2paddle(module=self.model, save_dir=f, jit_type='trace', input_examples=[self.im]) # export
400
+ yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
401
+ return f, None
402
+
403
+ @try_export
404
+ def export_ncnn(self, prefix=colorstr('ncnn:')):
405
+ """
406
+ YOLOv8 ncnn export using PNNX https://github.com/pnnx/pnnx.
407
+ """
408
+ check_requirements('git+https://github.com/Tencent/ncnn.git' if ARM64 else 'ncnn') # requires ncnn
409
+ import ncnn # noqa
410
+
411
+ LOGGER.info(f'\n{prefix} starting export with ncnn {ncnn.__version__}...')
412
+ f = Path(str(self.file).replace(self.file.suffix, f'_ncnn_model{os.sep}'))
413
+ f_ts = self.file.with_suffix('.torchscript')
414
+
415
+ pnnx_filename = 'pnnx.exe' if WINDOWS else 'pnnx'
416
+ if Path(pnnx_filename).is_file():
417
+ pnnx = pnnx_filename
418
+ elif (ROOT / pnnx_filename).is_file():
419
+ pnnx = ROOT / pnnx_filename
420
+ else:
421
+ LOGGER.warning(
422
+ f'{prefix} WARNING ⚠️ PNNX not found. Attempting to download binary file from '
423
+ 'https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory '
424
+ f'or in {ROOT}. See PNNX repo for full installation instructions.')
425
+ _, assets = get_github_assets(repo='pnnx/pnnx', retry=True)
426
+ asset = [x for x in assets if ('macos' if MACOS else 'ubuntu' if LINUX else 'windows') in x][0]
427
+ attempt_download_asset(asset, repo='pnnx/pnnx', release='latest')
428
+ unzip_dir = Path(asset).with_suffix('')
429
+ pnnx = ROOT / pnnx_filename # new location
430
+ (unzip_dir / pnnx_filename).rename(pnnx) # move binary to ROOT
431
+ shutil.rmtree(unzip_dir) # delete unzip dir
432
+ Path(asset).unlink() # delete zip
433
+ pnnx.chmod(0o777) # set read, write, and execute permissions for everyone
434
+
435
+ use_ncnn = True
436
+ ncnn_args = [
437
+ f'ncnnparam={f / "model.ncnn.param"}',
438
+ f'ncnnbin={f / "model.ncnn.bin"}',
439
+ f'ncnnpy={f / "model_ncnn.py"}', ] if use_ncnn else []
440
+
441
+ use_pnnx = False
442
+ pnnx_args = [
443
+ f'pnnxparam={f / "model.pnnx.param"}',
444
+ f'pnnxbin={f / "model.pnnx.bin"}',
445
+ f'pnnxpy={f / "model_pnnx.py"}',
446
+ f'pnnxonnx={f / "model.pnnx.onnx"}', ] if use_pnnx else []
447
+
448
+ cmd = [
449
+ str(pnnx),
450
+ str(f_ts),
451
+ *ncnn_args,
452
+ *pnnx_args,
453
+ f'fp16={int(self.args.half)}',
454
+ f'device={self.device.type}',
455
+ f'inputshape="{[self.args.batch, 3, *self.imgsz]}"', ]
456
+ f.mkdir(exist_ok=True) # make ncnn_model directory
457
+ LOGGER.info(f"{prefix} running '{' '.join(cmd)}'")
458
+ subprocess.run(cmd, check=True)
459
+ for f_debug in 'debug.bin', 'debug.param', 'debug2.bin', 'debug2.param': # remove debug files
460
+ Path(f_debug).unlink(missing_ok=True)
461
+
462
+ yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml
463
+ return str(f), None
464
+
465
+ @try_export
466
+ def export_coreml(self, prefix=colorstr('CoreML:')):
467
+ """YOLOv8 CoreML export."""
468
+ check_requirements('coremltools>=6.0,<=6.2')
469
+ import coremltools as ct # noqa
470
+
471
+ LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
472
+ f = self.file.with_suffix('.mlmodel')
473
+
474
+ bias = [0.0, 0.0, 0.0]
475
+ scale = 1 / 255
476
+ classifier_config = None
477
+ if self.model.task == 'classify':
478
+ classifier_config = ct.ClassifierConfig(list(self.model.names.values())) if self.args.nms else None
479
+ model = self.model
480
+ elif self.model.task == 'detect':
481
+ model = iOSDetectModel(self.model, self.im) if self.args.nms else self.model
482
+ else:
483
+ # TODO CoreML Segment and Pose model pipelining
484
+ model = self.model
485
+
486
+ ts = torch.jit.trace(model.eval(), self.im, strict=False) # TorchScript model
487
+ ct_model = ct.convert(ts,
488
+ inputs=[ct.ImageType('image', shape=self.im.shape, scale=scale, bias=bias)],
489
+ classifier_config=classifier_config)
490
+ bits, mode = (8, 'kmeans_lut') if self.args.int8 else (16, 'linear') if self.args.half else (32, None)
491
+ if bits < 32:
492
+ if 'kmeans' in mode:
493
+ check_requirements('scikit-learn') # scikit-learn package required for k-means quantization
494
+ ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
495
+ if self.args.nms and self.model.task == 'detect':
496
+ ct_model = self._pipeline_coreml(ct_model)
497
+
498
+ m = self.metadata # metadata dict
499
+ ct_model.short_description = m.pop('description')
500
+ ct_model.author = m.pop('author')
501
+ ct_model.license = m.pop('license')
502
+ ct_model.version = m.pop('version')
503
+ ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items()})
504
+ ct_model.save(str(f))
505
+ return f, ct_model
506
+
507
+ @try_export
508
+ def export_engine(self, prefix=colorstr('TensorRT:')):
509
+ """YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt."""
510
+ assert self.im.device.type != 'cpu', "export running on CPU but must be on GPU, i.e. use 'device=0'"
511
+ try:
512
+ import tensorrt as trt # noqa
513
+ except ImportError:
514
+ if LINUX:
515
+ check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
516
+ import tensorrt as trt # noqa
517
+
518
+ check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
519
+ self.args.simplify = True
520
+ f_onnx, _ = self.export_onnx()
521
+
522
+ LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
523
+ assert Path(f_onnx).exists(), f'failed to export ONNX file: {f_onnx}'
524
+ f = self.file.with_suffix('.engine') # TensorRT engine file
525
+ logger = trt.Logger(trt.Logger.INFO)
526
+ if self.args.verbose:
527
+ logger.min_severity = trt.Logger.Severity.VERBOSE
528
+
529
+ builder = trt.Builder(logger)
530
+ config = builder.create_builder_config()
531
+ config.max_workspace_size = self.args.workspace * 1 << 30
532
+ # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice
533
+
534
+ flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
535
+ network = builder.create_network(flag)
536
+ parser = trt.OnnxParser(network, logger)
537
+ if not parser.parse_from_file(f_onnx):
538
+ raise RuntimeError(f'failed to load ONNX file: {f_onnx}')
539
+
540
+ inputs = [network.get_input(i) for i in range(network.num_inputs)]
541
+ outputs = [network.get_output(i) for i in range(network.num_outputs)]
542
+ for inp in inputs:
543
+ LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
544
+ for out in outputs:
545
+ LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')
546
+
547
+ if self.args.dynamic:
548
+ shape = self.im.shape
549
+ if shape[0] <= 1:
550
+ LOGGER.warning(f'{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument')
551
+ profile = builder.create_optimization_profile()
552
+ for inp in inputs:
553
+ profile.set_shape(inp.name, (1, *shape[1:]), (max(1, shape[0] // 2), *shape[1:]), shape)
554
+ config.add_optimization_profile(profile)
555
+
556
+ LOGGER.info(
557
+ f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and self.args.half else 32} engine as {f}')
558
+ if builder.platform_has_fast_fp16 and self.args.half:
559
+ config.set_flag(trt.BuilderFlag.FP16)
560
+
561
+ # Write file
562
+ with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
563
+ # Metadata
564
+ meta = json.dumps(self.metadata)
565
+ t.write(len(meta).to_bytes(4, byteorder='little', signed=True))
566
+ t.write(meta.encode())
567
+ # Model
568
+ t.write(engine.serialize())
569
+
570
+ return f, None
571
+
572
+ @try_export
573
+ def export_saved_model(self, prefix=colorstr('TensorFlow SavedModel:')):
574
+ """YOLOv8 TensorFlow SavedModel export."""
575
+ try:
576
+ import tensorflow as tf # noqa
577
+ except ImportError:
578
+ cuda = torch.cuda.is_available()
579
+ check_requirements(f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if cuda else '-cpu'}")
580
+ import tensorflow as tf # noqa
581
+ check_requirements(('onnx', 'onnx2tf>=1.9.1', 'sng4onnx>=1.0.1', 'onnxsim>=0.4.17', 'onnx_graphsurgeon>=0.3.26',
582
+ 'tflite_support', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'),
583
+ cmds='--extra-index-url https://pypi.ngc.nvidia.com')
584
+
585
+ LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
586
+ f = Path(str(self.file).replace(self.file.suffix, '_saved_model'))
587
+ if f.is_dir():
588
+ import shutil
589
+ shutil.rmtree(f) # delete output folder
590
+
591
+ # Export to ONNX
592
+ self.args.simplify = True
593
+ f_onnx, _ = self.export_onnx()
594
+
595
+ # Export to TF
596
+ tmp_file = f / 'tmp_tflite_int8_calibration_images.npy' # int8 calibration images file
597
+ if self.args.int8:
598
+ if self.args.data:
599
+ import numpy as np
600
+
601
+ from ultralytics.data.dataset import YOLODataset
602
+ from ultralytics.data.utils import check_det_dataset
603
+
604
+ # Generate calibration data for integer quantization
605
+ LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
606
+ dataset = YOLODataset(check_det_dataset(self.args.data)['val'], imgsz=self.imgsz[0], augment=False)
607
+ images = []
608
+ n_images = 100 # maximum number of images
609
+ for n, batch in enumerate(dataset):
610
+ if n >= n_images:
611
+ break
612
+ im = batch['img'].permute(1, 2, 0)[None] # CHW to BHWC, add batch dimension
613
+ images.append(im)
614
+ f.mkdir()
615
+ images = torch.cat(images, 0).float()
616
+ # mean = images.view(-1, 3).mean(0) # imagenet mean [123.675, 116.28, 103.53]
617
+ # std = images.view(-1, 3).std(0) # imagenet std [58.395, 57.12, 57.375]
618
+ np.save(str(tmp_file), images.numpy()) # BHWC
619
+ int8 = f'-oiqt -qt per-tensor -cind images "{tmp_file}" "[[[[0, 0, 0]]]]" "[[[[255, 255, 255]]]]"'
620
+ else:
621
+ int8 = '-oiqt -qt per-tensor'
622
+ else:
623
+ int8 = ''
624
+
625
+ cmd = f'onnx2tf -i "{f_onnx}" -o "{f}" -nuo --non_verbose {int8}'.strip()
626
+ LOGGER.info(f"{prefix} running '{cmd}'")
627
+ subprocess.run(cmd, shell=True)
628
+ yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml
629
+
630
+ # Remove/rename TFLite models
631
+ if self.args.int8:
632
+ tmp_file.unlink(missing_ok=True)
633
+ for file in f.rglob('*_dynamic_range_quant.tflite'):
634
+ file.rename(file.with_name(file.stem.replace('_dynamic_range_quant', '_int8') + file.suffix))
635
+ for file in f.rglob('*_integer_quant_with_int16_act.tflite'):
636
+ file.unlink() # delete extra fp16 activation TFLite files
637
+
638
+ # Add TFLite metadata
639
+ for file in f.rglob('*.tflite'):
640
+ file.unlink() if 'quant_with_int16_act.tflite' in str(file) else self._add_tflite_metadata(file)
641
+
642
+ # Load saved_model
643
+ keras_model = tf.saved_model.load(f, tags=None, options=None)
644
+
645
+ return str(f), keras_model
646
+
647
+ @try_export
648
+ def export_pb(self, keras_model, prefix=colorstr('TensorFlow GraphDef:')):
649
+ """YOLOv8 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow."""
650
+ import tensorflow as tf # noqa
651
+ from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa
652
+
653
+ LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
654
+ f = self.file.with_suffix('.pb')
655
+
656
+ m = tf.function(lambda x: keras_model(x)) # full model
657
+ m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
658
+ frozen_func = convert_variables_to_constants_v2(m)
659
+ frozen_func.graph.as_graph_def()
660
+ tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
661
+ return f, None
662
+
663
+ @try_export
664
+ def export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
665
+ """YOLOv8 TensorFlow Lite export."""
666
+ import tensorflow as tf # noqa
667
+
668
+ LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
669
+ saved_model = Path(str(self.file).replace(self.file.suffix, '_saved_model'))
670
+ if self.args.int8:
671
+ f = saved_model / f'{self.file.stem}_int8.tflite' # fp32 in/out
672
+ elif self.args.half:
673
+ f = saved_model / f'{self.file.stem}_float16.tflite' # fp32 in/out
674
+ else:
675
+ f = saved_model / f'{self.file.stem}_float32.tflite'
676
+ return str(f), None
677
+
678
+ @try_export
679
+ def export_edgetpu(self, tflite_model='', prefix=colorstr('Edge TPU:')):
680
+ """YOLOv8 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/."""
681
+ LOGGER.warning(f'{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185')
682
+
683
+ cmd = 'edgetpu_compiler --version'
684
+ help_url = 'https://coral.ai/docs/edgetpu/compiler/'
685
+ assert LINUX, f'export only supported on Linux. See {help_url}'
686
+ if subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True).returncode != 0:
687
+ LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
688
+ sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system
689
+ for c in (
690
+ 'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
691
+ 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
692
+ 'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'):
693
+ subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
694
+ ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]
695
+
696
+ LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
697
+ f = str(tflite_model).replace('.tflite', '_edgetpu.tflite') # Edge TPU model
698
+
699
+ cmd = f'edgetpu_compiler -s -d -k 10 --out_dir "{Path(f).parent}" "{tflite_model}"'
700
+ LOGGER.info(f"{prefix} running '{cmd}'")
701
+ subprocess.run(cmd, shell=True)
702
+ self._add_tflite_metadata(f)
703
+ return f, None
704
+
705
+ @try_export
706
+ def export_tfjs(self, prefix=colorstr('TensorFlow.js:')):
707
+ """YOLOv8 TensorFlow.js export."""
708
+ check_requirements('tensorflowjs')
709
+ import tensorflow as tf
710
+ import tensorflowjs as tfjs # noqa
711
+
712
+ LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
713
+ f = str(self.file).replace(self.file.suffix, '_web_model') # js dir
714
+ f_pb = str(self.file.with_suffix('.pb')) # *.pb path
715
+
716
+ gd = tf.Graph().as_graph_def() # TF GraphDef
717
+ with open(f_pb, 'rb') as file:
718
+ gd.ParseFromString(file.read())
719
+ outputs = ','.join(gd_outputs(gd))
720
+ LOGGER.info(f'\n{prefix} output node names: {outputs}')
721
+
722
+ with spaces_in_path(f_pb) as fpb_, spaces_in_path(f) as f_: # exporter can not handle spaces in path
723
+ cmd = f'tensorflowjs_converter --input_format=tf_frozen_model --output_node_names={outputs} "{fpb_}" "{f_}"'
724
+ LOGGER.info(f"{prefix} running '{cmd}'")
725
+ subprocess.run(cmd, shell=True)
726
+
727
+ if ' ' in str(f):
728
+ LOGGER.warning(f"{prefix} WARNING ⚠️ your model may not work correctly with spaces in path '{f}'.")
729
+
730
+ # f_json = Path(f) / 'model.json' # *.json path
731
+ # with open(f_json, 'w') as j: # sort JSON Identity_* in ascending order
732
+ # subst = re.sub(
733
+ # r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
734
+ # r'"Identity.?.?": {"name": "Identity.?.?"}, '
735
+ # r'"Identity.?.?": {"name": "Identity.?.?"}, '
736
+ # r'"Identity.?.?": {"name": "Identity.?.?"}}}',
737
+ # r'{"outputs": {"Identity": {"name": "Identity"}, '
738
+ # r'"Identity_1": {"name": "Identity_1"}, '
739
+ # r'"Identity_2": {"name": "Identity_2"}, '
740
+ # r'"Identity_3": {"name": "Identity_3"}}}',
741
+ # f_json.read_text(),
742
+ # )
743
+ # j.write(subst)
744
+ yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
745
+ return f, None
746
+
747
+ def _add_tflite_metadata(self, file):
748
+ """Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata."""
749
+ from tflite_support import flatbuffers # noqa
750
+ from tflite_support import metadata as _metadata # noqa
751
+ from tflite_support import metadata_schema_py_generated as _metadata_fb # noqa
752
+
753
+ # Create model info
754
+ model_meta = _metadata_fb.ModelMetadataT()
755
+ model_meta.name = self.metadata['description']
756
+ model_meta.version = self.metadata['version']
757
+ model_meta.author = self.metadata['author']
758
+ model_meta.license = self.metadata['license']
759
+
760
+ # Label file
761
+ tmp_file = Path(file).parent / 'temp_meta.txt'
762
+ with open(tmp_file, 'w') as f:
763
+ f.write(str(self.metadata))
764
+
765
+ label_file = _metadata_fb.AssociatedFileT()
766
+ label_file.name = tmp_file.name
767
+ label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS
768
+
769
+ # Create input info
770
+ input_meta = _metadata_fb.TensorMetadataT()
771
+ input_meta.name = 'image'
772
+ input_meta.description = 'Input image to be detected.'
773
+ input_meta.content = _metadata_fb.ContentT()
774
+ input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
775
+ input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB
776
+ input_meta.content.contentPropertiesType = _metadata_fb.ContentProperties.ImageProperties
777
+
778
+ # Create output info
779
+ output1 = _metadata_fb.TensorMetadataT()
780
+ output1.name = 'output'
781
+ output1.description = 'Coordinates of detected objects, class labels, and confidence score'
782
+ output1.associatedFiles = [label_file]
783
+ if self.model.task == 'segment':
784
+ output2 = _metadata_fb.TensorMetadataT()
785
+ output2.name = 'output'
786
+ output2.description = 'Mask protos'
787
+ output2.associatedFiles = [label_file]
788
+
789
+ # Create subgraph info
790
+ subgraph = _metadata_fb.SubGraphMetadataT()
791
+ subgraph.inputTensorMetadata = [input_meta]
792
+ subgraph.outputTensorMetadata = [output1, output2] if self.model.task == 'segment' else [output1]
793
+ model_meta.subgraphMetadata = [subgraph]
794
+
795
+ b = flatbuffers.Builder(0)
796
+ b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
797
+ metadata_buf = b.Output()
798
+
799
+ populator = _metadata.MetadataPopulator.with_model_file(str(file))
800
+ populator.load_metadata_buffer(metadata_buf)
801
+ populator.load_associated_files([str(tmp_file)])
802
+ populator.populate()
803
+ tmp_file.unlink()
804
+
805
+ def _pipeline_coreml(self, model, prefix=colorstr('CoreML Pipeline:')):
806
+ """YOLOv8 CoreML pipeline."""
807
+ import coremltools as ct # noqa
808
+
809
+ LOGGER.info(f'{prefix} starting pipeline with coremltools {ct.__version__}...')
810
+ batch_size, ch, h, w = list(self.im.shape) # BCHW
811
+
812
+ # Output shapes
813
+ spec = model.get_spec()
814
+ out0, out1 = iter(spec.description.output)
815
+ if MACOS:
816
+ from PIL import Image
817
+ img = Image.new('RGB', (w, h)) # img(192 width, 320 height)
818
+ # img = torch.zeros((*opt.img_size, 3)).numpy() # img size(320,192,3) iDetection
819
+ out = model.predict({'image': img})
820
+ out0_shape = out[out0.name].shape
821
+ out1_shape = out[out1.name].shape
822
+ else: # linux and windows can not run model.predict(), get sizes from pytorch output y
823
+ out0_shape = self.output_shape[2], self.output_shape[1] - 4 # (3780, 80)
824
+ out1_shape = self.output_shape[2], 4 # (3780, 4)
825
+
826
+ # Checks
827
+ names = self.metadata['names']
828
+ nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height
829
+ na, nc = out0_shape
830
+ # na, nc = out0.type.multiArrayType.shape # number anchors, classes
831
+ assert len(names) == nc, f'{len(names)} names found for nc={nc}' # check
832
+
833
+ # Define output shapes (missing)
834
+ out0.type.multiArrayType.shape[:] = out0_shape # (3780, 80)
835
+ out1.type.multiArrayType.shape[:] = out1_shape # (3780, 4)
836
+ # spec.neuralNetwork.preprocessing[0].featureName = '0'
837
+
838
+ # Flexible input shapes
839
+ # from coremltools.models.neural_network import flexible_shape_utils
840
+ # s = [] # shapes
841
+ # s.append(flexible_shape_utils.NeuralNetworkImageSize(320, 192))
842
+ # s.append(flexible_shape_utils.NeuralNetworkImageSize(640, 384)) # (height, width)
843
+ # flexible_shape_utils.add_enumerated_image_sizes(spec, feature_name='image', sizes=s)
844
+ # r = flexible_shape_utils.NeuralNetworkImageSizeRange() # shape ranges
845
+ # r.add_height_range((192, 640))
846
+ # r.add_width_range((192, 640))
847
+ # flexible_shape_utils.update_image_size_range(spec, feature_name='image', size_range=r)
848
+
849
+ # Print
850
+ # print(spec.description)
851
+
852
+ # Model from spec
853
+ model = ct.models.MLModel(spec)
854
+
855
+ # 3. Create NMS protobuf
856
+ nms_spec = ct.proto.Model_pb2.Model()
857
+ nms_spec.specificationVersion = 5
858
+ for i in range(2):
859
+ decoder_output = model._spec.description.output[i].SerializeToString()
860
+ nms_spec.description.input.add()
861
+ nms_spec.description.input[i].ParseFromString(decoder_output)
862
+ nms_spec.description.output.add()
863
+ nms_spec.description.output[i].ParseFromString(decoder_output)
864
+
865
+ nms_spec.description.output[0].name = 'confidence'
866
+ nms_spec.description.output[1].name = 'coordinates'
867
+
868
+ output_sizes = [nc, 4]
869
+ for i in range(2):
870
+ ma_type = nms_spec.description.output[i].type.multiArrayType
871
+ ma_type.shapeRange.sizeRanges.add()
872
+ ma_type.shapeRange.sizeRanges[0].lowerBound = 0
873
+ ma_type.shapeRange.sizeRanges[0].upperBound = -1
874
+ ma_type.shapeRange.sizeRanges.add()
875
+ ma_type.shapeRange.sizeRanges[1].lowerBound = output_sizes[i]
876
+ ma_type.shapeRange.sizeRanges[1].upperBound = output_sizes[i]
877
+ del ma_type.shape[:]
878
+
879
+ nms = nms_spec.nonMaximumSuppression
880
+ nms.confidenceInputFeatureName = out0.name # 1x507x80
881
+ nms.coordinatesInputFeatureName = out1.name # 1x507x4
882
+ nms.confidenceOutputFeatureName = 'confidence'
883
+ nms.coordinatesOutputFeatureName = 'coordinates'
884
+ nms.iouThresholdInputFeatureName = 'iouThreshold'
885
+ nms.confidenceThresholdInputFeatureName = 'confidenceThreshold'
886
+ nms.iouThreshold = 0.45
887
+ nms.confidenceThreshold = 0.25
888
+ nms.pickTop.perClass = True
889
+ nms.stringClassLabels.vector.extend(names.values())
890
+ nms_model = ct.models.MLModel(nms_spec)
891
+
892
+ # 4. Pipeline models together
893
+ pipeline = ct.models.pipeline.Pipeline(input_features=[('image', ct.models.datatypes.Array(3, ny, nx)),
894
+ ('iouThreshold', ct.models.datatypes.Double()),
895
+ ('confidenceThreshold', ct.models.datatypes.Double())],
896
+ output_features=['confidence', 'coordinates'])
897
+ pipeline.add_model(model)
898
+ pipeline.add_model(nms_model)
899
+
900
+ # Correct datatypes
901
+ pipeline.spec.description.input[0].ParseFromString(model._spec.description.input[0].SerializeToString())
902
+ pipeline.spec.description.output[0].ParseFromString(nms_model._spec.description.output[0].SerializeToString())
903
+ pipeline.spec.description.output[1].ParseFromString(nms_model._spec.description.output[1].SerializeToString())
904
+
905
+ # Update metadata
906
+ pipeline.spec.specificationVersion = 5
907
+ pipeline.spec.description.metadata.userDefined.update({
908
+ 'IoU threshold': str(nms.iouThreshold),
909
+ 'Confidence threshold': str(nms.confidenceThreshold)})
910
+
911
+ # Save the model
912
+ model = ct.models.MLModel(pipeline.spec)
913
+ model.input_description['image'] = 'Input image'
914
+ model.input_description['iouThreshold'] = f'(optional) IOU threshold override (default: {nms.iouThreshold})'
915
+ model.input_description['confidenceThreshold'] = \
916
+ f'(optional) Confidence threshold override (default: {nms.confidenceThreshold})'
917
+ model.output_description['confidence'] = 'Boxes × Class confidence (see user-defined metadata "classes")'
918
+ model.output_description['coordinates'] = 'Boxes × [x, y, width, height] (relative to image size)'
919
+ LOGGER.info(f'{prefix} pipeline success')
920
+ return model
921
+
922
+ def add_callback(self, event: str, callback):
923
+ """
924
+ Appends the given callback.
925
+ """
926
+ self.callbacks[event].append(callback)
927
+
928
+ def run_callbacks(self, event: str):
929
+ """Execute all callbacks for a given event."""
930
+ for callback in self.callbacks.get(event, []):
931
+ callback(self)
932
+
933
+
934
+ class iOSDetectModel(torch.nn.Module):
935
+ """Wrap an Ultralytics YOLO model for iOS export."""
936
+
937
+ def __init__(self, model, im):
938
+ """Initialize the iOSDetectModel class with a YOLO model and example image."""
939
+ super().__init__()
940
+ b, c, h, w = im.shape # batch, channel, height, width
941
+ self.model = model
942
+ self.nc = len(model.names) # number of classes
943
+ if w == h:
944
+ self.normalize = 1.0 / w # scalar
945
+ else:
946
+ self.normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h]) # broadcast (slower, smaller)
947
+
948
+ def forward(self, x):
949
+ """Normalize predictions of object detection model with input size-dependent factors."""
950
+ xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
951
+ return cls, xywh * self.normalize # confidence (3780, 80), coordinates (3780, 4)
952
+
953
+
954
+ def export(cfg=DEFAULT_CFG):
955
+ """Export a YOLOv model to a specific format."""
956
+ cfg.model = cfg.model or 'yolov8n.yaml'
957
+ cfg.format = cfg.format or 'torchscript'
958
+
959
+ from ultralytics import YOLO
960
+ model = YOLO(cfg.model)
961
+ model.export(**vars(cfg))
962
+
963
+
964
+ if __name__ == '__main__':
965
+ """
966
+ CLI:
967
+ yolo mode=export model=yolov8n.yaml format=onnx
968
+ """
969
+ export()
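Note: export_formats() returns a pandas DataFrame, so its 'Argument' column can drive batch exports programmatically; __call__ above uses the same column to validate the requested format. A minimal sketch, assuming yolov8n.pt is available locally as in the module docstring:

    # Sketch only: try every registered export format and report which ones succeed.
    from ultralytics import YOLO
    from ultralytics.engine.exporter import export_formats

    model = YOLO('yolov8n.pt')
    for fmt in export_formats()['Argument'][1:]:  # skip the native PyTorch '-' entry
        try:
            out = model.export(format=fmt)
            print(f'{fmt}: saved to {out}')
        except Exception as e:  # e.g. TensorRT on a CPU-only machine
            print(f'{fmt}: export failed ({e})')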
ultralytics/engine/model.py ADDED
@@ -0,0 +1,465 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import inspect
4
+ import sys
5
+ from pathlib import Path
6
+ from typing import Union
7
+
8
+ from ultralytics.cfg import get_cfg
9
+ from ultralytics.engine.exporter import Exporter
10
+ from ultralytics.hub.utils import HUB_WEB_ROOT
11
+ from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
12
+ from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
13
+ is_git_dir, yaml_load)
14
+ from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
15
+ from ultralytics.utils.downloads import GITHUB_ASSET_STEMS
16
+ from ultralytics.utils.torch_utils import smart_inference_mode
17
+
18
+
19
+ class Model:
20
+ """
21
+ A base model class to unify APIs for all the models.
22
+
23
+ Args:
24
+ model (str, Path): Path to the model file to load or create.
25
+ task (Any, optional): Task type for the YOLO model. Defaults to None.
26
+
27
+ Attributes:
28
+ predictor (Any): The predictor object.
29
+ model (Any): The model object.
30
+ trainer (Any): The trainer object.
31
+ task (str): The type of model task.
32
+ ckpt (Any): The checkpoint object if the model loaded from *.pt file.
33
+ cfg (str): The model configuration if loaded from *.yaml file.
34
+ ckpt_path (str): The checkpoint file path.
35
+ overrides (dict): Overrides for the trainer object.
36
+ metrics (Any): The data for metrics.
37
+
38
+ Methods:
39
+ __call__(source=None, stream=False, **kwargs):
40
+ Alias for the predict method.
41
+ _new(cfg:str, verbose:bool=True) -> None:
42
+ Initializes a new model and infers the task type from the model definitions.
43
+ _load(weights:str, task:str='') -> None:
44
+ Initializes a new model and infers the task type from the model head.
45
+ _check_is_pytorch_model() -> None:
46
+ Raises TypeError if the model is not a PyTorch model.
47
+ reset() -> None:
48
+ Resets the model modules.
49
+ info(verbose:bool=False) -> None:
50
+ Logs the model info.
51
+ fuse() -> None:
52
+ Fuses the model for faster inference.
53
+ predict(source=None, stream=False, **kwargs) -> List[ultralytics.engine.results.Results]:
54
+ Performs prediction using the YOLO model.
55
+
56
+ Returns:
57
+ list(ultralytics.engine.results.Results): The prediction results.
58
+ """
59
+
60
+ def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None) -> None:
61
+ """
62
+ Initializes the YOLO model.
63
+
64
+ Args:
65
+ model (Union[str, Path], optional): Path or name of the model to load or create. Defaults to 'yolov8n.pt'.
66
+ task (Any, optional): Task type for the YOLO model. Defaults to None.
67
+ """
68
+ self.callbacks = callbacks.get_default_callbacks()
69
+ self.predictor = None # reuse predictor
70
+ self.model = None # model object
71
+ self.trainer = None # trainer object
72
+ self.ckpt = None # if loaded from *.pt
73
+ self.cfg = None # if loaded from *.yaml
74
+ self.ckpt_path = None
75
+ self.overrides = {} # overrides for trainer object
76
+ self.metrics = None # validation/training metrics
77
+ self.session = None # HUB session
78
+ self.task = task # task type
79
+ model = str(model).strip() # strip spaces
80
+
81
+ # Check if Ultralytics HUB model from https://hub.ultralytics.com
82
+ if self.is_hub_model(model):
83
+ from ultralytics.hub.session import HUBTrainingSession
84
+ self.session = HUBTrainingSession(model)
85
+ model = self.session.model_file
86
+
87
+ # Load or create new YOLO model
88
+ suffix = Path(model).suffix
89
+ if not suffix and Path(model).stem in GITHUB_ASSET_STEMS:
90
+ model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt
91
+ if suffix in ('.yaml', '.yml'):
92
+ self._new(model, task)
93
+ else:
94
+ self._load(model, task)
95
+
96
+ def __call__(self, source=None, stream=False, **kwargs):
97
+ """Calls the 'predict' function with given arguments to perform object detection."""
98
+ return self.predict(source, stream, **kwargs)
99
+
100
+ @staticmethod
101
+ def is_hub_model(model):
102
+ """Check if the provided model is a HUB model."""
103
+ return any((
104
+ model.startswith(f'{HUB_WEB_ROOT}/models/'), # i.e. https://hub.ultralytics.com/models/MODEL_ID
105
+ [len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID
106
+ len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\'))) # MODELID
107
+
108
+ def _new(self, cfg: str, task=None, model=None, verbose=True):
109
+ """
110
+ Initializes a new model and infers the task type from the model definitions.
111
+
112
+ Args:
113
+ cfg (str): model configuration file
114
+ task (str | None): model task
115
+ model (BaseModel): Customized model.
116
+ verbose (bool): display model info on load
117
+ """
118
+ cfg_dict = yaml_model_load(cfg)
119
+ self.cfg = cfg
120
+ self.task = task or guess_model_task(cfg_dict)
121
+ model = model or self.smart_load('model')
122
+ self.model = model(cfg_dict, verbose=verbose and RANK == -1) # build model
123
+ self.overrides['model'] = self.cfg
124
+
125
+ # Below added to allow export from yamls
126
+ args = {**DEFAULT_CFG_DICT, **self.overrides} # combine model and default args, preferring model args
127
+ self.model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
128
+ self.model.task = self.task
129
+
130
+ def _load(self, weights: str, task=None):
131
+ """
132
+ Initializes a new model and infers the task type from the model head.
133
+
134
+ Args:
135
+ weights (str): model checkpoint to be loaded
136
+ task (str | None): model task
137
+ """
138
+ suffix = Path(weights).suffix
139
+ if suffix == '.pt':
140
+ self.model, self.ckpt = attempt_load_one_weight(weights)
141
+ self.task = self.model.args['task']
142
+ self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
143
+ self.ckpt_path = self.model.pt_path
144
+ else:
145
+ weights = check_file(weights)
146
+ self.model, self.ckpt = weights, None
147
+ self.task = task or guess_model_task(weights)
148
+ self.ckpt_path = weights
149
+ self.overrides['model'] = weights
150
+ self.overrides['task'] = self.task
151
+
152
+ def _check_is_pytorch_model(self):
153
+ """
154
+ Raises TypeError if the model is not a PyTorch model.
155
+ """
156
+ pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt'
157
+ pt_module = isinstance(self.model, nn.Module)
158
+ if not (pt_module or pt_str):
159
+ raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
160
+ f'PyTorch models can be used to train, val, predict and export, i.e. '
161
+ f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
162
+ f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
163
+
164
+ @smart_inference_mode()
165
+ def reset_weights(self):
166
+ """
167
+ Resets the model's module parameters to randomly initialized values, losing all training information.
168
+ """
169
+ self._check_is_pytorch_model()
170
+ for m in self.model.modules():
171
+ if hasattr(m, 'reset_parameters'):
172
+ m.reset_parameters()
173
+ for p in self.model.parameters():
174
+ p.requires_grad = True
175
+ return self
176
+
177
+ @smart_inference_mode()
178
+ def load(self, weights='yolov8n.pt'):
179
+ """
180
+ Transfers parameters with matching names and shapes from 'weights' to model.
181
+ """
182
+ self._check_is_pytorch_model()
183
+ if isinstance(weights, (str, Path)):
184
+ weights, self.ckpt = attempt_load_one_weight(weights)
185
+ self.model.load(weights)
186
+ return self
187
+
188
+ def info(self, detailed=False, verbose=True):
189
+ """
190
+ Logs model info.
191
+
192
+ Args:
193
+ detailed (bool): Show detailed information about model.
194
+ verbose (bool): Controls verbosity.
195
+ """
196
+ self._check_is_pytorch_model()
197
+ return self.model.info(detailed=detailed, verbose=verbose)
198
+
199
+ def fuse(self):
200
+ """Fuse PyTorch Conv2d and BatchNorm2d layers."""
201
+ self._check_is_pytorch_model()
202
+ self.model.fuse()
203
+
204
+ @smart_inference_mode()
205
+ def predict(self, source=None, stream=False, predictor=None, **kwargs):
206
+ """
207
+ Perform prediction using the YOLO model.
208
+
209
+ Args:
210
+ source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
211
+ Accepts all source types accepted by the YOLO model.
212
+ stream (bool): Whether to stream the predictions or not. Defaults to False.
213
+ predictor (BasePredictor): Customized predictor.
214
+ **kwargs : Additional keyword arguments passed to the predictor.
215
+ Check the 'configuration' section in the documentation for all available options.
216
+
217
+ Returns:
218
+ (List[ultralytics.engine.results.Results]): The prediction results.
219
+ """
220
+ if source is None:
221
+ source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
222
+ LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
223
+ is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any(
224
+ x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track'))
225
+ # Check prompts for SAM/FastSAM
226
+ prompts = kwargs.pop('prompts', None)
227
+ overrides = self.overrides.copy()
228
+ overrides['conf'] = 0.25
229
+ overrides.update(kwargs) # prefer kwargs
230
+ overrides['mode'] = kwargs.get('mode', 'predict')
231
+ assert overrides['mode'] in ['track', 'predict']
232
+ if not is_cli:
233
+ overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
234
+ if not self.predictor:
235
+ self.task = overrides.get('task') or self.task
236
+ predictor = predictor or self.smart_load('predictor')
237
+ self.predictor = predictor(overrides=overrides, _callbacks=self.callbacks)
238
+ self.predictor.setup_model(model=self.model, verbose=is_cli)
239
+ else: # only update args if predictor is already setup
240
+ self.predictor.args = get_cfg(self.predictor.args, overrides)
241
+ if 'project' in overrides or 'name' in overrides:
242
+ self.predictor.save_dir = self.predictor.get_save_dir()
243
+ # Set prompts for SAM/FastSAM
244
+ if prompts and hasattr(self.predictor, 'set_prompts'):
245
+ self.predictor.set_prompts(prompts)
246
+ return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
247
+
248
+ def track(self, source=None, stream=False, persist=False, **kwargs):
249
+ """
250
+ Perform object tracking on the input source using the registered trackers.
251
+
252
+ Args:
253
+ source (str, optional): The input source for object tracking. Can be a file path or a video stream.
254
+ stream (bool, optional): Whether the input source is a video stream. Defaults to False.
255
+ persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
256
+ **kwargs (optional): Additional keyword arguments for the tracking process.
257
+
258
+ Returns:
259
+ (List[ultralytics.engine.results.Results]): The tracking results.
260
+
261
+ """
262
+ if not hasattr(self.predictor, 'trackers'):
263
+ from ultralytics.trackers import register_tracker
264
+ register_tracker(self, persist)
265
+ # ByteTrack-based method needs low confidence predictions as input
266
+ conf = kwargs.get('conf') or 0.1
267
+ kwargs['conf'] = conf
268
+ kwargs['mode'] = 'track'
269
+ return self.predict(source=source, stream=stream, **kwargs)
270
+
271
+ @smart_inference_mode()
272
+ def val(self, data=None, validator=None, **kwargs):
273
+ """
274
+ Validate a model on a given dataset.
275
+
276
+ Args:
277
+ data (str): The dataset to validate on. Accepts all formats accepted by yolo
278
+ validator (BaseValidator): Customized validator.
279
+ **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
280
+ """
281
+ overrides = self.overrides.copy()
282
+ overrides['rect'] = True # rect batches as default
283
+ overrides.update(kwargs)
284
+ overrides['mode'] = 'val'
285
+ args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
286
+ args.data = data or args.data
287
+ if 'task' in overrides:
288
+ self.task = args.task
289
+ else:
290
+ args.task = self.task
291
+ validator = validator or self.smart_load('validator')
292
+ if args.imgsz == DEFAULT_CFG.imgsz and not isinstance(self.model, (str, Path)):
293
+ args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
294
+ args.imgsz = check_imgsz(args.imgsz, max_dim=1)
295
+
296
+ validator = validator(args=args, _callbacks=self.callbacks)
297
+ validator(model=self.model)
298
+ self.metrics = validator.metrics
299
+
300
+ return validator.metrics
301
+
302
+ @smart_inference_mode()
303
+ def benchmark(self, **kwargs):
304
+ """
305
+ Benchmark a model on all export formats.
306
+
307
+ Args:
308
+ **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
309
+ """
310
+ self._check_is_pytorch_model()
311
+ from ultralytics.utils.benchmarks import benchmark
312
+ overrides = self.model.args.copy()
313
+ overrides.update(kwargs)
314
+ overrides['mode'] = 'benchmark'
315
+ overrides = {**DEFAULT_CFG_DICT, **overrides} # fill in missing overrides keys with defaults
316
+ return benchmark(
317
+ model=self,
318
+ data=kwargs.get('data'), # if no 'data' argument passed set data=None for default datasets
319
+ imgsz=overrides['imgsz'],
320
+ half=overrides['half'],
321
+ int8=overrides['int8'],
322
+ device=overrides['device'],
323
+ verbose=overrides['verbose'])
324
+
325
+ def export(self, **kwargs):
326
+ """
327
+ Export model.
328
+
329
+ Args:
330
+ **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
331
+ """
332
+ self._check_is_pytorch_model()
333
+ overrides = self.overrides.copy()
334
+ overrides.update(kwargs)
335
+ overrides['mode'] = 'export'
336
+ if overrides.get('imgsz') is None:
337
+ overrides['imgsz'] = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
338
+ if 'batch' not in kwargs:
339
+ overrides['batch'] = 1 # default to 1 if not modified
340
+ if 'data' not in kwargs:
341
+ overrides['data'] = None # default to None if not modified (avoid int8 calibration with coco.yaml)
342
+ args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
343
+ args.task = self.task
344
+ return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
345
+
346
+ def train(self, trainer=None, **kwargs):
347
+ """
348
+ Trains the model on a given dataset.
349
+
350
+ Args:
351
+ trainer (BaseTrainer, optional): Customized trainer.
352
+ **kwargs (Any): Any number of arguments representing the training configuration.
353
+ """
354
+ self._check_is_pytorch_model()
355
+ if self.session: # Ultralytics HUB session
356
+ if any(kwargs):
357
+ LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
358
+ kwargs = self.session.train_args
359
+ check_pip_update_available()
360
+ overrides = self.overrides.copy()
361
+ if kwargs.get('cfg'):
362
+ LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
363
+ overrides = yaml_load(check_yaml(kwargs['cfg']))
364
+ overrides.update(kwargs)
365
+ overrides['mode'] = 'train'
366
+ if not overrides.get('data'):
367
+ raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
368
+ if overrides.get('resume'):
369
+ overrides['resume'] = self.ckpt_path
370
+ self.task = overrides.get('task') or self.task
371
+ trainer = trainer or self.smart_load('trainer')
372
+ self.trainer = trainer(overrides=overrides, _callbacks=self.callbacks)
373
+ if not overrides.get('resume'): # manually set model only if not resuming
374
+ self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
375
+ self.model = self.trainer.model
376
+ self.trainer.hub_session = self.session # attach optional HUB session
377
+ self.trainer.train()
378
+ # Update model and cfg after training
379
+ if RANK in (-1, 0):
380
+ self.model, _ = attempt_load_one_weight(str(self.trainer.best))
381
+ self.overrides = self.model.args
382
+ self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
383
+
384
+ def to(self, device):
385
+ """
386
+ Sends the model to the given device.
387
+
388
+ Args:
389
+ device (str): device
390
+ """
391
+ self._check_is_pytorch_model()
392
+ self.model.to(device)
393
+
394
+ def tune(self, *args, **kwargs):
395
+ """
396
+ Runs hyperparameter tuning using Ray Tune. See ultralytics.utils.tuner.run_ray_tune for Args.
397
+
398
+ Returns:
399
+ (dict): A dictionary containing the results of the hyperparameter search.
400
+
401
+ Raises:
402
+ ModuleNotFoundError: If Ray Tune is not installed.
403
+ """
404
+ self._check_is_pytorch_model()
405
+ from ultralytics.utils.tuner import run_ray_tune
406
+ return run_ray_tune(self, *args, **kwargs)
407
+
408
+ @property
409
+ def names(self):
410
+ """Returns class names of the loaded model."""
411
+ return self.model.names if hasattr(self.model, 'names') else None
412
+
413
+ @property
414
+ def device(self):
415
+ """Returns device if PyTorch model."""
416
+ return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None
417
+
418
+ @property
419
+ def transforms(self):
420
+ """Returns transform of the loaded model."""
421
+ return self.model.transforms if hasattr(self.model, 'transforms') else None
422
+
423
+ def add_callback(self, event: str, func):
424
+ """Add a callback."""
425
+ self.callbacks[event].append(func)
426
+
427
+ def clear_callback(self, event: str):
428
+ """Clear all event callbacks."""
429
+ self.callbacks[event] = []
430
+
431
+ @staticmethod
432
+ def _reset_ckpt_args(args):
433
+ """Reset arguments when loading a PyTorch model."""
434
+ include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
435
+ return {k: v for k, v in args.items() if k in include}
436
+
437
+ def _reset_callbacks(self):
438
+ """Reset all registered callbacks."""
439
+ for event in callbacks.default_callbacks.keys():
440
+ self.callbacks[event] = [callbacks.default_callbacks[event][0]]
441
+
442
+ def __getattr__(self, attr):
443
+ """Raises error if object has no requested attribute."""
444
+ name = self.__class__.__name__
445
+ raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
446
+
447
+ def smart_load(self, key):
448
+ """Load model/trainer/validator/predictor."""
449
+ try:
450
+ return self.task_map[self.task][key]
451
+ except Exception:
452
+ name = self.__class__.__name__
453
+ mode = inspect.stack()[1][3] # get the function name.
454
+ raise NotImplementedError(
455
+ f'WARNING ⚠️ `{name}` model does not support `{mode}` mode for `{self.task}` task yet.')
456
+
457
+ @property
458
+ def task_map(self):
459
+ """
460
+ Map head to model, trainer, validator, and predictor classes.
461
+
462
+ Returns:
463
+ task_map (dict): The map of model task to mode classes.
464
+ """
465
+ raise NotImplementedError('Please provide task map for your model!')
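The wrapper above ties the engine together: `_new`/`_load` resolve a YAML or checkpoint, and `predict`, `val`, `train` and `export` each build their mode-specific class through `smart_load`/`task_map`. A minimal usage sketch of that surface (assuming the stock `yolov8n.pt` weights, a local `bus.jpg` image and the `coco128.yaml` dataset referenced in the docstrings above are available):

```python
from ultralytics import YOLO

model = YOLO('yolov8n.pt')                    # _load() a *.pt checkpoint and infer the task
model.info()                                  # log model summary
results = model.predict('bus.jpg', conf=0.5)  # list of Results objects
metrics = model.val(data='coco128.yaml')      # validate and return validator.metrics
model.train(data='coco128.yaml', epochs=3)    # fine-tune; raises if 'data' is missing
model.export(format='onnx')                   # hand off to Exporter with the trained imgsz
```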
ultralytics/engine/predictor.py ADDED
@@ -0,0 +1,359 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ """
3
+ Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc.
4
+
5
+ Usage - sources:
6
+ $ yolo mode=predict model=yolov8n.pt source=0 # webcam
7
+ img.jpg # image
8
+ vid.mp4 # video
9
+ screen # screenshot
10
+ path/ # directory
11
+ list.txt # list of images
12
+ list.streams # list of streams
13
+ 'path/*.jpg' # glob
14
+ 'https://youtu.be/Zgi9g1ksQHc' # YouTube
15
+ 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream
16
+
17
+ Usage - formats:
18
+ $ yolo mode=predict model=yolov8n.pt # PyTorch
19
+ yolov8n.torchscript # TorchScript
20
+ yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
21
+ yolov8n_openvino_model # OpenVINO
22
+ yolov8n.engine # TensorRT
23
+ yolov8n.mlmodel # CoreML (macOS-only)
24
+ yolov8n_saved_model # TensorFlow SavedModel
25
+ yolov8n.pb # TensorFlow GraphDef
26
+ yolov8n.tflite # TensorFlow Lite
27
+ yolov8n_edgetpu.tflite # TensorFlow Edge TPU
28
+ yolov8n_paddle_model # PaddlePaddle
29
+ """
30
+ import platform
31
+ from pathlib import Path
32
+
33
+ import cv2
34
+ import numpy as np
35
+ import torch
36
+
37
+ from ultralytics.cfg import get_cfg
38
+ from ultralytics.data import load_inference_source
39
+ from ultralytics.data.augment import LetterBox, classify_transforms
40
+ from ultralytics.nn.autobackend import AutoBackend
41
+ from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, SETTINGS, WINDOWS, callbacks, colorstr, ops
42
+ from ultralytics.utils.checks import check_imgsz, check_imshow
43
+ from ultralytics.utils.files import increment_path
44
+ from ultralytics.utils.torch_utils import select_device, smart_inference_mode
45
+
46
+ STREAM_WARNING = """
47
+ WARNING ⚠️ stream/video/webcam/dir predict source will accumulate results in RAM unless `stream=True` is passed,
48
+ causing potential out-of-memory errors for large sources or long-running streams/videos.
49
+
50
+ Usage:
51
+ results = model(source=..., stream=True) # generator of Results objects
52
+ for r in results:
53
+ boxes = r.boxes # Boxes object for bbox outputs
54
+ masks = r.masks # Masks object for segment masks outputs
55
+ probs = r.probs # Class probabilities for classification outputs
56
+ """
57
+
58
+ inference_Time = 0
59
+ class BasePredictor:
60
+ """
61
+ BasePredictor
62
+
63
+ A base class for creating predictors.
64
+
65
+ Attributes:
66
+ args (SimpleNamespace): Configuration for the predictor.
67
+ save_dir (Path): Directory to save results.
68
+ done_warmup (bool): Whether the predictor has finished setup.
69
+ model (nn.Module): Model used for prediction.
70
+ data (dict): Data configuration.
71
+ device (torch.device): Device used for prediction.
72
+ dataset (Dataset): Dataset used for prediction.
73
+ vid_path (str): Path to video file.
74
+ vid_writer (cv2.VideoWriter): Video writer for saving video output.
75
+ data_path (str): Path to data.
76
+ """
77
+
78
+ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
79
+ """
80
+ Initializes the BasePredictor class.
81
+
82
+ Args:
83
+ cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
84
+ overrides (dict, optional): Configuration overrides. Defaults to None.
85
+ """
86
+ self.args = get_cfg(cfg, overrides)
87
+ self.save_dir = self.get_save_dir()
88
+ if self.args.conf is None:
89
+ self.args.conf = 0.25 # default conf=0.25
90
+ self.done_warmup = False
91
+ if self.args.show:
92
+ self.args.show = check_imshow(warn=True)
93
+
94
+ # Usable if setup is done
95
+ self.model = None
96
+ self.data = self.args.data # data_dict
97
+ self.imgsz = None
98
+ self.device = None
99
+ self.dataset = None
100
+ self.vid_path, self.vid_writer = None, None
101
+ self.plotted_img = None
102
+ self.data_path = None
103
+ self.source_type = None
104
+ self.batch = None
105
+ self.results = None
106
+ self.transforms = None
107
+ self.callbacks = _callbacks or callbacks.get_default_callbacks()
108
+ self.txt_path = None
109
+ callbacks.add_integration_callbacks(self)
110
+
111
+ def get_save_dir(self):
112
+ project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
113
+ name = self.args.name or f'{self.args.mode}'
114
+ return increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
115
+
116
+ def preprocess(self, im):
117
+ """Prepares input image before inference.
118
+
119
+ Args:
120
+ im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
121
+ """
122
+ not_tensor = not isinstance(im, torch.Tensor)
123
+ if not_tensor:
124
+ im = np.stack(self.pre_transform(im))
125
+ im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
126
+ im = np.ascontiguousarray(im) # contiguous
127
+ im = torch.from_numpy(im)
128
+
129
+ img = im.to(self.device)
130
+ img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
131
+ if not_tensor:
132
+ img /= 255 # 0 - 255 to 0.0 - 1.0
133
+ return img
134
+
135
+ def inference(self, im, *args, **kwargs):
136
+ visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
137
+ mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
138
+ return self.model(im, augment=self.args.augment, visualize=visualize)
139
+
140
+ def pre_transform(self, im):
141
+ """Pre-transform input image before inference.
142
+
143
+ Args:
144
+ im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
145
+
146
+ Returns: A list of transformed images.
147
+ """
148
+ same_shapes = all(x.shape == im[0].shape for x in im)
149
+ auto = same_shapes and self.model.pt
150
+ return [LetterBox(self.imgsz, auto=auto, stride=self.model.stride)(image=x) for x in im]
151
+
152
+ def write_results(self, idx, results, batch):
153
+ """Write inference results to a file or directory."""
154
+ p, im, _ = batch
155
+ log_string = ''
156
+ if len(im.shape) == 3:
157
+ im = im[None] # expand for batch dim
158
+ if self.source_type.webcam or self.source_type.from_img or self.source_type.tensor: # batch_size >= 1
159
+ log_string += f'{idx}: '
160
+ frame = self.dataset.count
161
+ else:
162
+ frame = getattr(self.dataset, 'frame', 0)
163
+ self.data_path = p
164
+ self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
165
+ log_string += '%gx%g ' % im.shape[2:] # print string
166
+ result = results[idx]
167
+ log_string += result.verbose()
168
+
169
+ if self.args.save or self.args.show: # Add bbox to image
170
+ plot_args = {
171
+ 'line_width': self.args.line_width,
172
+ 'boxes': self.args.boxes,
173
+ 'conf': self.args.show_conf,
174
+ 'labels': self.args.show_labels}
175
+ if not self.args.retina_masks:
176
+ plot_args['im_gpu'] = im[idx]
177
+ self.plotted_img = result.plot(**plot_args)
178
+ # Write
179
+ if self.args.save_txt:
180
+ result.save_txt(f'{self.txt_path}.txt', save_conf=self.args.save_conf)
181
+ if self.args.save_crop:
182
+ result.save_crop(save_dir=self.save_dir / 'crops',
183
+ file_name=self.data_path.stem + ('' if self.dataset.mode == 'image' else f'_{frame}'))
184
+
185
+ return log_string
186
+
187
+ def postprocess(self, preds, img, orig_imgs):
188
+ """Post-processes predictions for an image and returns them."""
189
+ return preds
190
+
191
+ def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
192
+ """Performs inference on an image or stream."""
193
+ self.stream = stream
194
+ if stream:
195
+ return self.stream_inference(source, model, *args, **kwargs)
196
+ else:
197
+ return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one
198
+
199
+ def predict_cli(self, source=None, model=None):
200
+ """Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode."""
201
+ gen = self.stream_inference(source, model)
202
+ for _ in gen: # running CLI inference without accumulating any outputs (do not modify)
203
+ pass
204
+
205
+ def setup_source(self, source):
206
+ """Sets up source and inference mode."""
207
+ self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
208
+ self.transforms = getattr(self.model.model, 'transforms', classify_transforms(
209
+ self.imgsz[0])) if self.args.task == 'classify' else None
210
+ self.dataset = load_inference_source(source=source, imgsz=self.imgsz, vid_stride=self.args.vid_stride)
211
+ self.source_type = self.dataset.source_type
212
+ if not getattr(self, 'stream', True) and (self.dataset.mode == 'stream' or # streams
213
+ len(self.dataset) > 1000 or # images
214
+ any(getattr(self.dataset, 'video_flag', [False]))): # videos
215
+ LOGGER.warning(STREAM_WARNING)
216
+ self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs
217
+
218
+ @smart_inference_mode()
219
+ def stream_inference(self, source=None, model=None, *args, **kwargs):
220
+ """Streams real-time inference on camera feed and saves results to file."""
221
+ if self.args.verbose:
222
+ LOGGER.info('')
223
+
224
+ # Setup model
225
+ if not self.model:
226
+ self.setup_model(model)
227
+
228
+ # Setup source every time predict is called
229
+ self.setup_source(source if source is not None else self.args.source)
230
+
231
+ # Check if save_dir/ label file exists
232
+ if self.args.save or self.args.save_txt:
233
+ (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
234
+
235
+ # Warmup model
236
+ if not self.done_warmup:
237
+ self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
238
+ self.done_warmup = True
239
+
240
+ self.seen, self.windows, self.batch, profilers = 0, [], None, (ops.Profile(), ops.Profile(), ops.Profile())
241
+ self.run_callbacks('on_predict_start')
242
+ for batch in self.dataset:
243
+ self.run_callbacks('on_predict_batch_start')
244
+ self.batch = batch
245
+ path, im0s, vid_cap, s = batch
246
+
247
+ # Preprocess
248
+ with profilers[0]:
249
+ im = self.preprocess(im0s)
250
+
251
+ # Inference
252
+ with profilers[1]:
253
+ preds = self.inference(im, *args, **kwargs)
254
+
255
+ # Postprocess
256
+ with profilers[2]:
257
+ self.results = self.postprocess(preds, im, im0s)
258
+ self.run_callbacks('on_predict_postprocess_end')
259
+
260
+ # Visualize, save, write results
261
+ n = len(im0s)
262
+ for i in range(n):
263
+ self.seen += 1
264
+ self.results[i].speed = {
265
+ 'preprocess': profilers[0].dt * 1E3 / n,
266
+ 'inference': profilers[1].dt * 1E3 / n,
267
+ 'postprocess': profilers[2].dt * 1E3 / n}
268
+ p, im0 = path[i], None if self.source_type.tensor else im0s[i].copy()
269
+ p = Path(p)
270
+
271
+ if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
272
+ s += self.write_results(i, self.results, (p, im, im0))
273
+ if self.args.save or self.args.save_txt:
274
+ self.results[i].save_dir = self.save_dir.__str__()
275
+ if self.args.show and self.plotted_img is not None:
276
+ self.show(p)
277
+ if self.args.save and self.plotted_img is not None:
278
+ self.save_preds(vid_cap, i, str(self.save_dir / p.name))
279
+
280
+ self.run_callbacks('on_predict_batch_end')
281
+ yield from self.results
282
+
283
+ # Print time (inference-only)
284
+ if self.args.verbose:
285
+ LOGGER.info(f'{s}{profilers[1].dt * 1E3:.1f}ms')
286
+
287
+ # Release assets
288
+ if isinstance(self.vid_writer[-1], cv2.VideoWriter):
289
+ self.vid_writer[-1].release() # release final video writer
290
+
291
+ # Print results
292
+ if self.args.verbose and self.seen:
293
+ t = tuple(x.t / self.seen * 1E3 for x in profilers) # speeds per image
294
+ LOGGER.info(f'Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape '
295
+ f'{(1, 3, *im.shape[2:])}' % t)
296
+ if self.args.save or self.args.save_txt or self.args.save_crop:
297
+ nl = len(list(self.save_dir.glob('labels/*.txt'))) # number of labels
298
+ s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
299
+ LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
300
+
301
+ self.run_callbacks('on_predict_end')
302
+
303
+ def setup_model(self, model, verbose=True):
304
+ """Initialize YOLO model with given parameters and set it to evaluation mode."""
305
+ self.model = AutoBackend(model or self.args.model,
306
+ device=select_device(self.args.device, verbose=verbose),
307
+ dnn=self.args.dnn,
308
+ data=self.args.data,
309
+ fp16=self.args.half,
310
+ fuse=True,
311
+ verbose=verbose)
312
+
313
+ self.device = self.model.device # update device
314
+ self.args.half = self.model.fp16 # update half
315
+ self.model.eval()
316
+
317
+ def show(self, p):
318
+ """Display an image in a window using OpenCV imshow()."""
319
+ im0 = self.plotted_img
320
+ if platform.system() == 'Linux' and p not in self.windows:
321
+ self.windows.append(p)
322
+ cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
323
+ cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
324
+ cv2.imshow(str(p), im0)
325
+ cv2.waitKey(500 if self.batch[3].startswith('image') else 1) # pause 500 ms for images, 1 ms for video/stream frames
326
+
327
+ def save_preds(self, vid_cap, idx, save_path):
328
+ """Save video predictions as mp4 at specified path."""
329
+ im0 = self.plotted_img
330
+ # Save imgs
331
+ if self.dataset.mode == 'image':
332
+ cv2.imwrite(save_path, im0)
333
+ else: # 'video' or 'stream'
334
+ if self.vid_path[idx] != save_path: # new video
335
+ self.vid_path[idx] = save_path
336
+ if isinstance(self.vid_writer[idx], cv2.VideoWriter):
337
+ self.vid_writer[idx].release() # release previous video writer
338
+ if vid_cap: # video
339
+ fps = int(vid_cap.get(cv2.CAP_PROP_FPS)) # integer required, floats produce error in MP4 codec
340
+ w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
341
+ h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
342
+ else: # stream
343
+ fps, w, h = 30, im0.shape[1], im0.shape[0]
344
+ suffix = '.mp4' if MACOS else '.avi' if WINDOWS else '.avi'
345
+ fourcc = 'avc1' if MACOS else 'WMV2' if WINDOWS else 'MJPG'
346
+ save_path = str(Path(save_path).with_suffix(suffix))
347
+ self.vid_writer[idx] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, h))
348
+ self.vid_writer[idx].write(im0)
349
+
350
+ def run_callbacks(self, event: str):
351
+ """Runs all registered callbacks for a specific event."""
352
+ for callback in self.callbacks.get(event, []):
353
+ callback(self)
354
+
355
+ def add_callback(self, event: str, func):
356
+ """
357
+ Add callback
358
+ """
359
+ self.callbacks[event].append(func)
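`BasePredictor.__call__` above is the hinge for memory use: with `stream=False` it materialises every `Results` object in a list, while `stream=True` returns the `stream_inference` generator directly, which is what `STREAM_WARNING` recommends for videos and webcams. A short sketch of the streaming path, reusing the RTSP source from the usage docstring (hypothetical URL):

```python
from ultralytics import YOLO

model = YOLO('yolov8n.pt')

# stream=True hands back a generator, so frames are processed one at a time
# and results are never accumulated in RAM
for r in model.predict(source='rtsp://example.com/media.mp4', stream=True):
    print(r.verbose(), r.speed)  # per-image counts and preprocess/inference/postprocess times (ms)
```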
ultralytics/engine/results.py ADDED
@@ -0,0 +1,604 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ """
3
+ Ultralytics Results, Boxes and Masks classes for handling inference results
4
+
5
+ Usage: See https://docs.ultralytics.com/modes/predict/
6
+ """
7
+
8
+ from copy import deepcopy
9
+ from functools import lru_cache
10
+ from pathlib import Path
11
+
12
+ import numpy as np
13
+ import torch
14
+
15
+ from ultralytics.data.augment import LetterBox
16
+ from ultralytics.utils import LOGGER, SimpleClass, deprecation_warn, ops
17
+ from ultralytics.utils.plotting import Annotator, colors, save_one_box
18
+
19
+
20
+ class BaseTensor(SimpleClass):
21
+ """
22
+ Base tensor class with additional methods for easy manipulation and device handling.
23
+ """
24
+
25
+ def __init__(self, data, orig_shape) -> None:
26
+ """Initialize BaseTensor with data and original shape.
27
+
28
+ Args:
29
+ data (torch.Tensor | np.ndarray): Predictions, such as bboxes, masks and keypoints.
30
+ orig_shape (tuple): Original shape of image.
31
+ """
32
+ assert isinstance(data, (torch.Tensor, np.ndarray))
33
+ self.data = data
34
+ self.orig_shape = orig_shape
35
+
36
+ @property
37
+ def shape(self):
38
+ """Return the shape of the data tensor."""
39
+ return self.data.shape
40
+
41
+ def cpu(self):
42
+ """Return a copy of the tensor on CPU memory."""
43
+ return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.cpu(), self.orig_shape)
44
+
45
+ def numpy(self):
46
+ """Return a copy of the tensor as a numpy array."""
47
+ return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.numpy(), self.orig_shape)
48
+
49
+ def cuda(self):
50
+ """Return a copy of the tensor on GPU memory."""
51
+ return self.__class__(torch.as_tensor(self.data).cuda(), self.orig_shape)
52
+
53
+ def to(self, *args, **kwargs):
54
+ """Return a copy of the tensor with the specified device and dtype."""
55
+ return self.__class__(torch.as_tensor(self.data).to(*args, **kwargs), self.orig_shape)
56
+
57
+ def __len__(self): # override len(results)
58
+ """Return the length of the data tensor."""
59
+ return len(self.data)
60
+
61
+ def __getitem__(self, idx):
62
+ """Return a BaseTensor with the specified index of the data tensor."""
63
+ return self.__class__(self.data[idx], self.orig_shape)
64
+
65
+
66
+ class Results(SimpleClass):
67
+ """
68
+ A class for storing and manipulating inference results.
69
+
70
+ Args:
71
+ orig_img (numpy.ndarray): The original image as a numpy array.
72
+ path (str): The path to the image file.
73
+ names (dict): A dictionary of class names.
74
+ boxes (torch.tensor, optional): A 2D tensor of bounding box coordinates for each detection.
75
+ masks (torch.tensor, optional): A 3D tensor of detection masks, where each mask is a binary image.
76
+ probs (torch.tensor, optional): A 1D tensor of probabilities of each class for classification task.
77
+ keypoints (List[List[float]], optional): A list of detected keypoints for each object.
78
+
79
+ Attributes:
80
+ orig_img (numpy.ndarray): The original image as a numpy array.
81
+ orig_shape (tuple): The original image shape in (height, width) format.
82
+ boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
83
+ masks (Masks, optional): A Masks object containing the detection masks.
84
+ probs (Probs, optional): A Probs object containing probabilities of each class for classification task.
85
+ keypoints (Keypoints, optional): A Keypoints object containing detected keypoints for each object.
86
+ speed (dict): A dictionary of preprocess, inference, and postprocess speeds in milliseconds per image.
87
+ names (dict): A dictionary of class names.
88
+ path (str): The path to the image file.
89
+ _keys (tuple): A tuple of attribute names for non-empty attributes.
90
+ """
91
+
92
+ def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None, keypoints=None) -> None:
93
+ """Initialize the Results class."""
94
+ self.orig_img = orig_img
95
+ self.orig_shape = orig_img.shape[:2]
96
+ self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None # native size boxes
97
+ self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks
98
+ self.probs = Probs(probs) if probs is not None else None
99
+ self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None
100
+ self.speed = {'preprocess': None, 'inference': None, 'postprocess': None} # milliseconds per image
101
+ self.names = names
102
+ self.path = path
103
+ self.save_dir = None
104
+ self._keys = ('boxes', 'masks', 'probs', 'keypoints')
105
+
106
+ def __getitem__(self, idx):
107
+ """Return a Results object for the specified index."""
108
+ r = self.new()
109
+ for k in self.keys:
110
+ setattr(r, k, getattr(self, k)[idx])
111
+ return r
112
+
113
+ def __len__(self):
114
+ """Return the number of detections in the Results object."""
115
+ for k in self.keys:
116
+ return len(getattr(self, k))
117
+
118
+ def update(self, boxes=None, masks=None, probs=None):
119
+ """Update the boxes, masks, and probs attributes of the Results object."""
120
+ if boxes is not None:
121
+ ops.clip_boxes(boxes, self.orig_shape) # clip boxes
122
+ self.boxes = Boxes(boxes, self.orig_shape)
123
+ if masks is not None:
124
+ self.masks = Masks(masks, self.orig_shape)
125
+ if probs is not None:
126
+ self.probs = probs
127
+
128
+ def cpu(self):
129
+ """Return a copy of the Results object with all tensors on CPU memory."""
130
+ r = self.new()
131
+ for k in self.keys:
132
+ setattr(r, k, getattr(self, k).cpu())
133
+ return r
134
+
135
+ def numpy(self):
136
+ """Return a copy of the Results object with all tensors as numpy arrays."""
137
+ r = self.new()
138
+ for k in self.keys:
139
+ setattr(r, k, getattr(self, k).numpy())
140
+ return r
141
+
142
+ def cuda(self):
143
+ """Return a copy of the Results object with all tensors on GPU memory."""
144
+ r = self.new()
145
+ for k in self.keys:
146
+ setattr(r, k, getattr(self, k).cuda())
147
+ return r
148
+
149
+ def to(self, *args, **kwargs):
150
+ """Return a copy of the Results object with tensors on the specified device and dtype."""
151
+ r = self.new()
152
+ for k in self.keys:
153
+ setattr(r, k, getattr(self, k).to(*args, **kwargs))
154
+ return r
155
+
156
+ def new(self):
157
+ """Return a new Results object with the same image, path, and names."""
158
+ return Results(orig_img=self.orig_img, path=self.path, names=self.names)
159
+
160
+ @property
161
+ def keys(self):
162
+ """Return a list of non-empty attribute names."""
163
+ return [k for k in self._keys if getattr(self, k) is not None]
164
+
165
+ def plot(
166
+ self,
167
+ conf=True,
168
+ line_width=None,
169
+ font_size=None,
170
+ font='Arial.ttf',
171
+ pil=False,
172
+ img=None,
173
+ im_gpu=None,
174
+ kpt_radius=5,
175
+ kpt_line=True,
176
+ labels=True,
177
+ boxes=True,
178
+ masks=True,
179
+ probs=True,
180
+ **kwargs # deprecated args TODO: remove support in 8.2
181
+ ):
182
+ """
183
+ Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image.
184
+
185
+ Args:
186
+ conf (bool): Whether to plot the detection confidence score.
187
+ line_width (float, optional): The line width of the bounding boxes. If None, it is scaled to the image size.
188
+ font_size (float, optional): The font size of the text. If None, it is scaled to the image size.
189
+ font (str): The font to use for the text.
190
+ pil (bool): Whether to return the image as a PIL Image.
191
+ img (numpy.ndarray): Plot on another image. If None, plot on the original image.
192
+ im_gpu (torch.Tensor): Normalized image in gpu with shape (1, 3, 640, 640), for faster mask plotting.
193
+ kpt_radius (int, optional): Radius of the drawn keypoints. Default is 5.
194
+ kpt_line (bool): Whether to draw lines connecting keypoints.
195
+ labels (bool): Whether to plot the label of bounding boxes.
196
+ boxes (bool): Whether to plot the bounding boxes.
197
+ masks (bool): Whether to plot the masks.
198
+ probs (bool): Whether to plot classification probability
199
+
200
+ Returns:
201
+ (numpy.ndarray): A numpy array of the annotated image.
202
+
203
+ Example:
204
+ ```python
205
+ from PIL import Image
206
+ from ultralytics import YOLO
207
+
208
+ model = YOLO('yolov8n.pt')
209
+ results = model('bus.jpg') # results list
210
+ for r in results:
211
+ im_array = r.plot() # plot a BGR numpy array of predictions
212
+ im = Image.fromarray(im_array[..., ::-1]) # RGB PIL image
213
+ im.show() # show image
214
+ im.save('results.jpg') # save image
215
+ ```
216
+ """
217
+ if img is None and isinstance(self.orig_img, torch.Tensor):
218
+ img = np.ascontiguousarray(self.orig_img[0].permute(1, 2, 0).cpu().detach().numpy()) * 255
219
+
220
+ # Deprecation warn TODO: remove in 8.2
221
+ if 'show_conf' in kwargs:
222
+ deprecation_warn('show_conf', 'conf')
223
+ conf = kwargs['show_conf']
224
+ assert type(conf) == bool, '`show_conf` should be of boolean type, i.e. show_conf=True/False'
225
+
226
+ if 'line_thickness' in kwargs:
227
+ deprecation_warn('line_thickness', 'line_width')
228
+ line_width = kwargs['line_thickness']
229
+ assert type(line_width) == int, '`line_width` should be of int type, i.e. line_width=3'
230
+
231
+ names = self.names
232
+ pred_boxes, show_boxes = self.boxes, boxes
233
+ pred_masks, show_masks = self.masks, masks
234
+ pred_probs, show_probs = self.probs, probs
235
+ annotator = Annotator(
236
+ deepcopy(self.orig_img if img is None else img),
237
+ line_width,
238
+ font_size,
239
+ font,
240
+ pil or (pred_probs is not None and show_probs), # Classify tasks default to pil=True
241
+ example=names)
242
+
243
+ # Plot Segment results
244
+ if pred_masks and show_masks:
245
+ if im_gpu is None:
246
+ img = LetterBox(pred_masks.shape[1:])(image=annotator.result())
247
+ im_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device).permute(
248
+ 2, 0, 1).flip(0).contiguous() / 255
249
+ idx = pred_boxes.cls if pred_boxes else range(len(pred_masks))
250
+ annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu)
251
+
252
+ # Plot Detect results
253
+ if pred_boxes and show_boxes:
254
+ for d in reversed(pred_boxes):
255
+ c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
256
+ name = ('' if id is None else f'id:{id} ') + names[c]
257
+ label = (f'{name} {conf:.2f}' if conf else name) if labels else None
258
+ annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
259
+
260
+ # Plot Classify results
261
+ if pred_probs is not None and show_probs:
262
+ text = ',\n'.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)
263
+ x = round(self.orig_shape[0] * 0.03)
264
+ annotator.text([x, x], text, txt_color=(255, 255, 255)) # TODO: allow setting colors
265
+
266
+ # Plot Pose results
267
+ if self.keypoints is not None:
268
+ for k in reversed(self.keypoints.data):
269
+ annotator.kpts(k, self.orig_shape, radius=kpt_radius, kpt_line=kpt_line)
270
+
271
+ return annotator.result()
272
+
273
+ def verbose(self):
274
+ """
275
+ Return log string for each task.
276
+ """
277
+ log_string = ''
278
+ probs = self.probs
279
+ boxes = self.boxes
280
+ if len(self) == 0:
281
+ return log_string if probs is not None else f'{log_string}(no detections), '
282
+ if probs is not None:
283
+ log_string += f"{', '.join(f'{self.names[j]} {probs.data[j]:.2f}' for j in probs.top5)}, "
284
+ if boxes:
285
+ for c in boxes.cls.unique():
286
+ n = (boxes.cls == c).sum() # detections per class
287
+ log_string += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "
288
+ return log_string
289
+
290
+ def save_txt(self, txt_file, save_conf=False):
291
+ """
292
+ Save predictions into txt file.
293
+
294
+ Args:
295
+ txt_file (str): txt file path.
296
+ save_conf (bool): save confidence score or not.
297
+ """
298
+ boxes = self.boxes
299
+ masks = self.masks
300
+ probs = self.probs
301
+ kpts = self.keypoints
302
+ texts = []
303
+ if probs is not None:
304
+ # Classify
305
+ [texts.append(f'{probs.data[j]:.2f} {self.names[j]}') for j in probs.top5]
306
+ elif boxes:
307
+ # Detect/segment/pose
308
+ for j, d in enumerate(boxes):
309
+ c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item())
310
+ line = (c, *d.xywhn.view(-1))
311
+ if masks:
312
+ seg = masks[j].xyn[0].copy().reshape(-1) # reversed mask.xyn, (n,2) to (n*2)
313
+ line = (c, *seg)
314
+ if kpts is not None:
315
+ kpt = torch.cat((kpts[j].xyn, kpts[j].conf[..., None]), 2) if kpts[j].has_visible else kpts[j].xyn
316
+ line += (*kpt.reshape(-1).tolist(), )
317
+ line += (conf, ) * save_conf + (() if id is None else (id, ))
318
+ texts.append(('%g ' * len(line)).rstrip() % line)
319
+
320
+ if texts:
321
+ with open(txt_file, 'a') as f:
322
+ f.writelines(text + '\n' for text in texts)
323
+
324
+ def save_crop(self, save_dir, file_name=Path('im.jpg')):
325
+ """
326
+ Save cropped predictions to `save_dir/cls/file_name.jpg`.
327
+
328
+ Args:
329
+ save_dir (str | pathlib.Path): Save path.
330
+ file_name (str | pathlib.Path): File name.
331
+ """
332
+ if self.probs is not None:
333
+ LOGGER.warning('WARNING ⚠️ Classify task does not support `save_crop`.')
334
+ return
335
+ if isinstance(save_dir, str):
336
+ save_dir = Path(save_dir)
337
+ if isinstance(file_name, str):
338
+ file_name = Path(file_name)
339
+ for d in self.boxes:
340
+ save_one_box(d.xyxy,
341
+ self.orig_img.copy(),
342
+ file=save_dir / self.names[int(d.cls)] / f'{file_name.stem}.jpg',
343
+ BGR=True)
344
+
345
+ def tojson(self, normalize=False):
346
+ """Convert the object to JSON format."""
347
+ if self.probs is not None:
348
+ LOGGER.warning('Warning: Classify task does not support `tojson` yet.')
349
+ return
350
+
351
+ import json
352
+
353
+ # Create list of detection dictionaries
354
+ results = []
355
+ data = self.boxes.data.cpu().tolist()
356
+ h, w = self.orig_shape if normalize else (1, 1)
357
+ for i, row in enumerate(data):
358
+ box = {'x1': row[0] / w, 'y1': row[1] / h, 'x2': row[2] / w, 'y2': row[3] / h}
359
+ conf = row[4]
360
+ id = int(row[5])
361
+ name = self.names[id]
362
+ result = {'name': name, 'class': id, 'confidence': conf, 'box': box}
363
+ if self.masks:
364
+ x, y = self.masks.xy[i][:, 0], self.masks.xy[i][:, 1] # numpy array
365
+ result['segments'] = {'x': (x / w).tolist(), 'y': (y / h).tolist()}
366
+ if self.keypoints is not None:
367
+ x, y, visible = self.keypoints[i].data[0].cpu().unbind(dim=1) # torch Tensor
368
+ result['keypoints'] = {'x': (x / w).tolist(), 'y': (y / h).tolist(), 'visible': visible.tolist()}
369
+ results.append(result)
370
+
371
+ # Convert detections to JSON
372
+ return json.dumps(results, indent=2)
373
+
374
+
375
+ class Boxes(BaseTensor):
376
+ """
377
+ A class for storing and manipulating detection boxes.
378
+
379
+ Args:
380
+ boxes (torch.Tensor | numpy.ndarray): A tensor or numpy array containing the detection boxes,
381
+ with shape (num_boxes, 6) or (num_boxes, 7). The last two columns contain confidence and class values.
382
+ If present, the third last column contains track IDs.
383
+ orig_shape (tuple): Original image size, in the format (height, width).
384
+
385
+ Attributes:
386
+ xyxy (torch.Tensor | numpy.ndarray): The boxes in xyxy format.
387
+ conf (torch.Tensor | numpy.ndarray): The confidence values of the boxes.
388
+ cls (torch.Tensor | numpy.ndarray): The class values of the boxes.
389
+ id (torch.Tensor | numpy.ndarray): The track IDs of the boxes (if available).
390
+ xywh (torch.Tensor | numpy.ndarray): The boxes in xywh format.
391
+ xyxyn (torch.Tensor | numpy.ndarray): The boxes in xyxy format normalized by original image size.
392
+ xywhn (torch.Tensor | numpy.ndarray): The boxes in xywh format normalized by original image size.
393
+ data (torch.Tensor): The raw bboxes tensor (alias for `boxes`).
394
+
395
+ Methods:
396
+ cpu(): Move the object to CPU memory.
397
+ numpy(): Convert the object to a numpy array.
398
+ cuda(): Move the object to CUDA memory.
399
+ to(*args, **kwargs): Move the object to the specified device.
400
+ """
401
+
402
+ def __init__(self, boxes, orig_shape) -> None:
403
+ """Initialize the Boxes class."""
404
+ if boxes.ndim == 1:
405
+ boxes = boxes[None, :]
406
+ n = boxes.shape[-1]
407
+ assert n in (6, 7), f'expected `n` in [6, 7], but got {n}' # xyxy, (track_id), conf, cls
408
+ super().__init__(boxes, orig_shape)
409
+ self.is_track = n == 7
410
+ self.orig_shape = orig_shape
411
+
412
+ @property
413
+ def xyxy(self):
414
+ """Return the boxes in xyxy format."""
415
+ return self.data[:, :4]
416
+
417
+ @property
418
+ def conf(self):
419
+ """Return the confidence values of the boxes."""
420
+ return self.data[:, -2]
421
+
422
+ @property
423
+ def cls(self):
424
+ """Return the class values of the boxes."""
425
+ return self.data[:, -1]
426
+
427
+ @property
428
+ def id(self):
429
+ """Return the track IDs of the boxes (if available)."""
430
+ return self.data[:, -3] if self.is_track else None
431
+
432
+ @property
433
+ @lru_cache(maxsize=2) # maxsize 1 should suffice
434
+ def xywh(self):
435
+ """Return the boxes in xywh format."""
436
+ return ops.xyxy2xywh(self.xyxy)
437
+
438
+ @property
439
+ @lru_cache(maxsize=2)
440
+ def xyxyn(self):
441
+ """Return the boxes in xyxy format normalized by original image size."""
442
+ xyxy = self.xyxy.clone() if isinstance(self.xyxy, torch.Tensor) else np.copy(self.xyxy)
443
+ xyxy[..., [0, 2]] /= self.orig_shape[1]
444
+ xyxy[..., [1, 3]] /= self.orig_shape[0]
445
+ return xyxy
446
+
447
+ @property
448
+ @lru_cache(maxsize=2)
449
+ def xywhn(self):
450
+ """Return the boxes in xywh format normalized by original image size."""
451
+ xywh = ops.xyxy2xywh(self.xyxy)
452
+ xywh[..., [0, 2]] /= self.orig_shape[1]
453
+ xywh[..., [1, 3]] /= self.orig_shape[0]
454
+ return xywh
455
+
456
+ @property
457
+ def boxes(self):
458
+ """Return the raw bboxes tensor (deprecated)."""
459
+ LOGGER.warning("WARNING ⚠️ 'Boxes.boxes' is deprecated. Use 'Boxes.data' instead.")
460
+ return self.data
461
+
462
+
463
+ class Masks(BaseTensor):
464
+ """
465
+ A class for storing and manipulating detection masks.
466
+
467
+ Attributes:
468
+ segments (list): Deprecated property for segments (normalized).
469
+ xy (list): A list of segments in pixel coordinates.
470
+ xyn (list): A list of normalized segments.
471
+
472
+ Methods:
473
+ cpu(): Returns the masks tensor on CPU memory.
474
+ numpy(): Returns the masks tensor as a numpy array.
475
+ cuda(): Returns the masks tensor on GPU memory.
476
+ to(device, dtype): Returns the masks tensor with the specified device and dtype.
477
+ """
478
+
479
+ def __init__(self, masks, orig_shape) -> None:
480
+ """Initialize the Masks class with the given masks tensor and original image shape."""
481
+ if masks.ndim == 2:
482
+ masks = masks[None, :]
483
+ super().__init__(masks, orig_shape)
484
+
485
+ @property
486
+ @lru_cache(maxsize=1)
487
+ def segments(self):
488
+ """Return segments (normalized). Deprecated; use xyn property instead."""
489
+ LOGGER.warning(
490
+ "WARNING ⚠️ 'Masks.segments' is deprecated. Use 'Masks.xyn' for segments (normalized) and 'Masks.xy' for segments (pixels) instead."
491
+ )
492
+ return self.xyn
493
+
494
+ @property
495
+ @lru_cache(maxsize=1)
496
+ def xyn(self):
497
+ """Return normalized segments."""
498
+ return [
499
+ ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=True)
500
+ for x in ops.masks2segments(self.data)]
501
+
502
+ @property
503
+ @lru_cache(maxsize=1)
504
+ def xy(self):
505
+ """Return segments in pixel coordinates."""
506
+ return [
507
+ ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=False)
508
+ for x in ops.masks2segments(self.data)]
509
+
510
+ @property
511
+ def masks(self):
512
+ """Return the raw masks tensor. Deprecated; use data attribute instead."""
513
+ LOGGER.warning("WARNING ⚠️ 'Masks.masks' is deprecated. Use 'Masks.data' instead.")
514
+ return self.data
515
+
516
+
517
+ class Keypoints(BaseTensor):
518
+ """
519
+ A class for storing and manipulating detection keypoints.
520
+
521
+ Attributes:
522
+ xy (torch.Tensor): A collection of keypoints containing x, y coordinates for each detection.
523
+ xyn (torch.Tensor): A normalized version of xy with coordinates in the range [0, 1].
524
+ conf (torch.Tensor): Confidence values associated with keypoints if available, otherwise None.
525
+
526
+ Methods:
527
+ cpu(): Returns a copy of the keypoints tensor on CPU memory.
528
+ numpy(): Returns a copy of the keypoints tensor as a numpy array.
529
+ cuda(): Returns a copy of the keypoints tensor on GPU memory.
530
+ to(device, dtype): Returns a copy of the keypoints tensor with the specified device and dtype.
531
+ """
532
+
533
+ def __init__(self, keypoints, orig_shape) -> None:
534
+ """Initializes the Keypoints object with detection keypoints and original image size."""
535
+ if keypoints.ndim == 2:
536
+ keypoints = keypoints[None, :]
537
+ super().__init__(keypoints, orig_shape)
538
+ self.has_visible = self.data.shape[-1] == 3
539
+
540
+ @property
541
+ @lru_cache(maxsize=1)
542
+ def xy(self):
543
+ """Returns x, y coordinates of keypoints."""
544
+ return self.data[..., :2]
545
+
546
+ @property
547
+ @lru_cache(maxsize=1)
548
+ def xyn(self):
549
+ """Returns normalized x, y coordinates of keypoints."""
550
+ xy = self.xy.clone() if isinstance(self.xy, torch.Tensor) else np.copy(self.xy)
551
+ xy[..., 0] /= self.orig_shape[1]
552
+ xy[..., 1] /= self.orig_shape[0]
553
+ return xy
554
+
555
+ @property
556
+ @lru_cache(maxsize=1)
557
+ def conf(self):
558
+ """Returns confidence values of keypoints if available, else None."""
559
+ return self.data[..., 2] if self.has_visible else None
560
+
561
+
562
+ class Probs(BaseTensor):
563
+ """
564
+ A class for storing and manipulating classification predictions.
565
+
566
+ Attributes:
567
+ top1 (int): Index of the top 1 class.
568
+ top5 (list[int]): Indices of the top 5 classes.
569
+ top1conf (torch.Tensor): Confidence of the top 1 class.
570
+ top5conf (torch.Tensor): Confidences of the top 5 classes.
571
+
572
+ Methods:
573
+ cpu(): Returns a copy of the probs tensor on CPU memory.
574
+ numpy(): Returns a copy of the probs tensor as a numpy array.
575
+ cuda(): Returns a copy of the probs tensor on GPU memory.
576
+ to(): Returns a copy of the probs tensor with the specified device and dtype.
577
+ """
578
+
579
+ def __init__(self, probs, orig_shape=None) -> None:
580
+ super().__init__(probs, orig_shape)
581
+
582
+ @property
583
+ @lru_cache(maxsize=1)
584
+ def top1(self):
585
+ """Return the index of top 1."""
586
+ return int(self.data.argmax())
587
+
588
+ @property
589
+ @lru_cache(maxsize=1)
590
+ def top5(self):
591
+ """Return the indices of top 5."""
592
+ return (-self.data).argsort(0)[:5].tolist() # this way works with both torch and numpy.
593
+
594
+ @property
595
+ @lru_cache(maxsize=1)
596
+ def top1conf(self):
597
+ """Return the confidence of top 1."""
598
+ return self.data[self.top1]
599
+
600
+ @property
601
+ @lru_cache(maxsize=1)
602
+ def top5conf(self):
603
+ """Return the confidences of top 5."""
604
+ return self.data[self.top5]
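`Results` and its `Boxes`/`Masks`/`Keypoints`/`Probs` members are what every predict call ultimately hands back, and all of them share the `BaseTensor` device helpers. A minimal consumption sketch (assuming a local `bus.jpg`; only attributes and methods defined in this file are used):

```python
from ultralytics import YOLO

model = YOLO('yolov8n.pt')
for r in model('bus.jpg'):                    # Model.__call__ -> predict -> list of Results
    print(r.verbose())                        # e.g. '4 persons, 1 bus, '
    for box in r.boxes:                       # each item is a single-row Boxes object
        print(int(box.cls), float(box.conf), box.xyxy.squeeze().tolist())
    r.save_txt('predictions.txt')             # one line per detection: class + normalized xywh
    print(r.tojson(normalize=True))           # JSON string; masks/keypoints included when present
```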
ultralytics/engine/trainer.py ADDED
@@ -0,0 +1,664 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ """
3
+ Train a model on a dataset
4
+
5
+ Usage:
6
+ $ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16
7
+ """
8
+ import math
9
+ import os
10
+ import subprocess
11
+ import time
12
+ from copy import deepcopy
13
+ from datetime import datetime, timedelta
14
+ from pathlib import Path
15
+
16
+ import numpy as np
17
+ import torch
18
+ from torch import distributed as dist
19
+ from torch import nn, optim
20
+ from torch.cuda import amp
21
+ from torch.nn.parallel import DistributedDataParallel as DDP
22
+ from tqdm import tqdm
23
+
24
+ from ultralytics.cfg import get_cfg
25
+ from ultralytics.data.utils import check_cls_dataset, check_det_dataset
26
+ from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
27
+ from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, __version__, callbacks, clean_url,
28
+ colorstr, emojis, yaml_save)
29
+ from ultralytics.utils.autobatch import check_train_batch_size
30
+ from ultralytics.utils.checks import check_amp, check_file, check_imgsz, print_args
31
+ from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
32
+ from ultralytics.utils.files import get_latest_run, increment_path
33
+ from ultralytics.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, select_device,
34
+ strip_optimizer)
35
+
36
+
37
+ class BaseTrainer:
38
+ """
39
+ BaseTrainer
40
+
41
+ A base class for creating trainers.
42
+
43
+ Attributes:
44
+ args (SimpleNamespace): Configuration for the trainer.
45
+ check_resume (method): Method to check if training should be resumed from a saved checkpoint.
46
+ validator (BaseValidator): Validator instance.
47
+ model (nn.Module): Model instance.
48
+ callbacks (defaultdict): Dictionary of callbacks.
49
+ save_dir (Path): Directory to save results.
50
+ wdir (Path): Directory to save weights.
51
+ last (Path): Path to last checkpoint.
52
+ best (Path): Path to best checkpoint.
53
+ save_period (int): Save checkpoint every x epochs (disabled if < 1).
54
+ batch_size (int): Batch size for training.
55
+ epochs (int): Number of epochs to train for.
56
+ start_epoch (int): Starting epoch for training.
57
+ device (torch.device): Device to use for training.
58
+ amp (bool): Flag to enable AMP (Automatic Mixed Precision).
59
+ scaler (amp.GradScaler): Gradient scaler for AMP.
60
+ data (str): Path to data.
61
+ trainset (torch.utils.data.Dataset): Training dataset.
62
+ testset (torch.utils.data.Dataset): Testing dataset.
63
+ ema (nn.Module): EMA (Exponential Moving Average) of the model.
64
+ lf (nn.Module): Loss function.
65
+ scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
66
+ best_fitness (float): The best fitness value achieved.
67
+ fitness (float): Current fitness value.
68
+ loss (float): Current loss value.
69
+ tloss (float): Total loss value.
70
+ loss_names (list): List of loss names.
71
+ csv (Path): Path to results CSV file.
72
+ """
73
+
74
+ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
75
+ """
76
+ Initializes the BaseTrainer class.
77
+
78
+ Args:
79
+ cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
80
+ overrides (dict, optional): Configuration overrides. Defaults to None.
81
+ """
82
+ self.args = get_cfg(cfg, overrides)
83
+ self.device = select_device(self.args.device, self.args.batch)
84
+ self.check_resume()
85
+ self.validator = None
86
+ self.model = None
87
+ self.metrics = None
88
+ self.plots = {}
89
+ init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
90
+
91
+ # Dirs
92
+ project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
93
+ name = self.args.name or f'{self.args.mode}'
94
+ if hasattr(self.args, 'save_dir'):
95
+ self.save_dir = Path(self.args.save_dir)
96
+ else:
97
+ self.save_dir = Path(
98
+ increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in (-1, 0) else True))
99
+ self.wdir = self.save_dir / 'weights' # weights dir
100
+ if RANK in (-1, 0):
101
+ self.wdir.mkdir(parents=True, exist_ok=True) # make dir
102
+ self.args.save_dir = str(self.save_dir)
103
+ yaml_save(self.save_dir / 'args.yaml', vars(self.args)) # save run args
104
+ self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
105
+ self.save_period = self.args.save_period
106
+
107
+ self.batch_size = self.args.batch
108
+ self.epochs = self.args.epochs
109
+ self.start_epoch = 0
110
+ if RANK == -1:
111
+ print_args(vars(self.args))
112
+
113
+ # Device
114
+ if self.device.type == 'cpu':
115
+ self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
116
+
117
+ # Model and Dataset
118
+ self.model = self.args.model
119
+ try:
120
+ if self.args.task == 'classify':
121
+ self.data = check_cls_dataset(self.args.data)
122
+ elif self.args.data.split('.')[-1] in ('yaml', 'yml') or self.args.task in ('detect', 'segment'):
123
+ self.data = check_det_dataset(self.args.data)
124
+ if 'yaml_file' in self.data:
125
+ self.args.data = self.data['yaml_file'] # for validating 'yolo train data=url.zip' usage
126
+ except Exception as e:
127
+ raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
128
+
129
+ self.trainset, self.testset = self.get_dataset(self.data)
130
+ self.ema = None
131
+
132
+ # Optimization utils init
133
+ self.lf = None
134
+ self.scheduler = None
135
+
136
+ # Epoch level metrics
137
+ self.best_fitness = None
138
+ self.fitness = None
139
+ self.loss = None
140
+ self.tloss = None
141
+ self.loss_names = ['Loss']
142
+ self.csv = self.save_dir / 'results.csv'
143
+ self.plot_idx = [0, 1, 2]
144
+
145
+ # Callbacks
146
+ self.callbacks = _callbacks or callbacks.get_default_callbacks()
147
+ if RANK in (-1, 0):
148
+ callbacks.add_integration_callbacks(self)
149
+
150
+ def add_callback(self, event: str, callback):
151
+ """
152
+ Appends the given callback.
153
+ """
154
+ self.callbacks[event].append(callback)
155
+
156
+ def set_callback(self, event: str, callback):
157
+ """
158
+ Overrides the existing callbacks with the given callback.
159
+ """
160
+ self.callbacks[event] = [callback]
161
+
162
+ def run_callbacks(self, event: str):
163
+ """Run all existing callbacks associated with a particular event."""
164
+ for callback in self.callbacks.get(event, []):
165
+ callback(self)
166
+
167
+ def train(self):
168
+ """Allow device='', device=None on Multi-GPU systems to default to device=0."""
169
+ if isinstance(self.args.device, int) or self.args.device: # i.e. device=0 or device=[0,1,2,3]
170
+ world_size = torch.cuda.device_count()
171
+ elif torch.cuda.is_available(): # i.e. device=None or device=''
172
+ world_size = 1 # default to device 0
173
+ else: # i.e. device='cpu' or 'mps'
174
+ world_size = 0
175
+
176
+ # Run subprocess if DDP training, else train normally
177
+ if world_size > 1 and 'LOCAL_RANK' not in os.environ:
178
+ # Argument checks
179
+ if self.args.rect:
180
+ LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting rect=False")
181
+ self.args.rect = False
182
+ # Command
183
+ cmd, file = generate_ddp_command(world_size, self)
184
+ try:
185
+ LOGGER.info(f'DDP command: {cmd}')
186
+ subprocess.run(cmd, check=True)
187
+ except Exception as e:
188
+ raise e
189
+ finally:
190
+ ddp_cleanup(self, str(file))
191
+ else:
192
+ self._do_train(world_size)
193
+
194
+ def _setup_ddp(self, world_size):
195
+ """Initializes and sets the DistributedDataParallel parameters for training."""
196
+ torch.cuda.set_device(RANK)
197
+ self.device = torch.device('cuda', RANK)
198
+ LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
199
+ os.environ['NCCL_BLOCKING_WAIT'] = '1' # set to enforce timeout
200
+ dist.init_process_group(
201
+ 'nccl' if dist.is_nccl_available() else 'gloo',
202
+ timeout=timedelta(seconds=10800), # 3 hours
203
+ rank=RANK,
204
+ world_size=world_size)
205
+
206
+ def _setup_train(self, world_size):
207
+ """
208
+ Builds dataloaders and optimizer on correct rank process.
209
+ """
210
+ # Model
211
+ self.run_callbacks('on_pretrain_routine_start')
212
+ ckpt = self.setup_model()
213
+ self.model = self.model.to(self.device)
214
+ self.set_model_attributes()
215
+ # Check AMP
216
+ self.amp = torch.tensor(self.args.amp).to(self.device) # True or False
217
+ if self.amp and RANK in (-1, 0): # Single-GPU and DDP
218
+ callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them
219
+ self.amp = torch.tensor(check_amp(self.model), device=self.device)
220
+ callbacks.default_callbacks = callbacks_backup # restore callbacks
221
+ if RANK > -1 and world_size > 1: # DDP
222
+ dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None)
223
+ self.amp = bool(self.amp) # as boolean
224
+ self.scaler = amp.GradScaler(enabled=self.amp)
225
+ if world_size > 1:
226
+ self.model = DDP(self.model, device_ids=[RANK])
227
+ # Check imgsz
228
+ gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32) # grid size (max stride)
229
+ self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
230
+ # Batch size
231
+ if self.batch_size == -1:
232
+ if RANK == -1: # single-GPU only, estimate best batch size
233
+ self.args.batch = self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)
234
+ else:
235
+ raise SyntaxError('batch=-1 to use AutoBatch is only available in Single-GPU training. '
236
+ 'Please pass a valid batch size value for Multi-GPU DDP training, i.e. batch=16')
237
+
238
+ # Dataloaders
239
+ batch_size = self.batch_size // max(world_size, 1)
240
+ self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
241
+ if RANK in (-1, 0):
242
+ self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
243
+ self.validator = self.get_validator()
244
+ metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
245
+ self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
246
+ self.ema = ModelEMA(self.model)
247
+ if self.args.plots:
248
+ self.plot_training_labels()
249
+
250
+ # Optimizer
251
+ self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing
252
+ weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
253
+ iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
254
+ self.optimizer = self.build_optimizer(model=self.model,
255
+ name=self.args.optimizer,
256
+ lr=self.args.lr0,
257
+ momentum=self.args.momentum,
258
+ decay=weight_decay,
259
+ iterations=iterations)
260
+ # Scheduler
261
+ if self.args.cos_lr:
262
+ self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf']
263
+ else:
264
+ self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf # linear
265
+ self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
266
+ self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
267
+ self.resume_training(ckpt)
268
+ self.scheduler.last_epoch = self.start_epoch - 1 # do not move
269
+ self.run_callbacks('on_pretrain_routine_end')
270
+
271
+ def _do_train(self, world_size=1):
272
+ """Run the training loop; on completion, evaluate and plot if specified by arguments."""
273
+ if world_size > 1:
274
+ self._setup_ddp(world_size)
275
+
276
+ self._setup_train(world_size)
277
+
278
+ self.epoch_time = None
279
+ self.epoch_time_start = time.time()
280
+ self.train_time_start = time.time()
281
+ nb = len(self.train_loader) # number of batches
282
+ nw = max(round(self.args.warmup_epochs *
283
+ nb), 100) if self.args.warmup_epochs > 0 else -1 # number of warmup iterations
284
+ last_opt_step = -1
285
+ self.run_callbacks('on_train_start')
286
+ LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
287
+ f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
288
+ f"Logging results to {colorstr('bold', self.save_dir)}\n"
289
+ f'Starting training for {self.epochs} epochs...')
290
+ if self.args.close_mosaic:
291
+ base_idx = (self.epochs - self.args.close_mosaic) * nb
292
+ self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
293
+ epoch = self.epochs # predefine for resume fully trained model edge cases
294
+ for epoch in range(self.start_epoch, self.epochs):
295
+ self.epoch = epoch
296
+ self.run_callbacks('on_train_epoch_start')
297
+ self.model.train()
298
+ if RANK != -1:
299
+ self.train_loader.sampler.set_epoch(epoch)
300
+ pbar = enumerate(self.train_loader)
301
+ # Update dataloader attributes (optional)
302
+ if epoch == (self.epochs - self.args.close_mosaic):
303
+ LOGGER.info('Closing dataloader mosaic')
304
+ if hasattr(self.train_loader.dataset, 'mosaic'):
305
+ self.train_loader.dataset.mosaic = False
306
+ if hasattr(self.train_loader.dataset, 'close_mosaic'):
307
+ self.train_loader.dataset.close_mosaic(hyp=self.args)
308
+ self.train_loader.reset()
309
+
310
+ if RANK in (-1, 0):
311
+ LOGGER.info(self.progress_string())
312
+ pbar = tqdm(enumerate(self.train_loader), total=nb, bar_format=TQDM_BAR_FORMAT)
313
+ self.tloss = None
314
+ self.optimizer.zero_grad()
315
+ for i, batch in pbar:
316
+ self.run_callbacks('on_train_batch_start')
317
+ # Warmup
318
+ ni = i + nb * epoch
319
+ if ni <= nw:
320
+ xi = [0, nw] # x interp
321
+ self.accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round())
322
+ for j, x in enumerate(self.optimizer.param_groups):
323
+ # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
324
+ x['lr'] = np.interp(
325
+ ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x['initial_lr'] * self.lf(epoch)])
326
+ if 'momentum' in x:
327
+ x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
328
+
329
+ # Forward
330
+ with torch.cuda.amp.autocast(self.amp):
331
+ batch = self.preprocess_batch(batch)
332
+ self.loss, self.loss_items = self.model(batch)
333
+ if RANK != -1:
334
+ self.loss *= world_size
335
+ self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
336
+ else self.loss_items
337
+
338
+ # Backward
339
+ self.scaler.scale(self.loss).backward()
340
+
341
+ # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
342
+ if ni - last_opt_step >= self.accumulate:
343
+ self.optimizer_step()
344
+ last_opt_step = ni
345
+
346
+ # Log
347
+ mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
348
+ loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
349
+ losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
350
+ if RANK in (-1, 0):
351
+ pbar.set_description(
352
+ ('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
353
+ (f'{epoch + 1}/{self.epochs}', mem, *losses, batch['cls'].shape[0], batch['img'].shape[-1]))
354
+ self.run_callbacks('on_batch_end')
355
+ if self.args.plots and ni in self.plot_idx:
356
+ self.plot_training_samples(batch, ni)
357
+
358
+ self.run_callbacks('on_train_batch_end')
359
+
360
+ self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
361
+
362
+ self.scheduler.step()
363
+ self.run_callbacks('on_train_epoch_end')
364
+
365
+ if RANK in (-1, 0):
366
+
367
+ # Validation
368
+ self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
369
+ final_epoch = (epoch + 1 == self.epochs) or self.stopper.possible_stop
370
+
371
+ if self.args.val or final_epoch:
372
+ self.metrics, self.fitness = self.validate()
373
+ self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
374
+ self.stop = self.stopper(epoch + 1, self.fitness)
375
+
376
+ # Save model
377
+ if self.args.save or (epoch + 1 == self.epochs):
378
+ self.save_model()
379
+ self.run_callbacks('on_model_save')
380
+
381
+ tnow = time.time()
382
+ self.epoch_time = tnow - self.epoch_time_start
383
+ self.epoch_time_start = tnow
384
+ self.run_callbacks('on_fit_epoch_end')
385
+ torch.cuda.empty_cache() # clears GPU vRAM at end of epoch, can help with out of memory errors
386
+
387
+ # Early Stopping
388
+ if RANK != -1: # if DDP training
389
+ broadcast_list = [self.stop if RANK == 0 else None]
390
+ dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks
391
+ if RANK != 0:
392
+ self.stop = broadcast_list[0]
393
+ if self.stop:
394
+ break # must break all DDP ranks
395
+
396
+ if RANK in (-1, 0):
397
+ # Do final val with best.pt
398
+ LOGGER.info(f'\n{epoch - self.start_epoch + 1} epochs completed in '
399
+ f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
400
+ self.final_eval()
401
+ if self.args.plots:
402
+ self.plot_metrics()
403
+ self.run_callbacks('on_train_end')
404
+ torch.cuda.empty_cache()
405
+ self.run_callbacks('teardown')
406
+
407
+ def save_model(self):
408
+ """Save model checkpoints based on various conditions."""
409
+ ckpt = {
410
+ 'epoch': self.epoch,
411
+ 'best_fitness': self.best_fitness,
412
+ 'model': deepcopy(de_parallel(self.model)).half(),
413
+ 'ema': deepcopy(self.ema.ema).half(),
414
+ 'updates': self.ema.updates,
415
+ 'optimizer': self.optimizer.state_dict(),
416
+ 'train_args': vars(self.args), # save as dict
417
+ 'date': datetime.now().isoformat(),
418
+ 'version': __version__}
419
+
420
+ # Use dill (if exists) to serialize the lambda functions where pickle does not do this
421
+ try:
422
+ import dill as pickle
423
+ except ImportError:
424
+ import pickle
425
+
426
+ # Save last, best and delete
427
+ torch.save(ckpt, self.last, pickle_module=pickle)
428
+ if self.best_fitness == self.fitness:
429
+ torch.save(ckpt, self.best, pickle_module=pickle)
430
+ if (self.epoch > 0) and (self.save_period > 0) and (self.epoch % self.save_period == 0):
431
+ torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt', pickle_module=pickle)
432
+ del ckpt
433
+
434
+ @staticmethod
435
+ def get_dataset(data):
436
+ """
437
+ Get train and val paths from the data dict; the val entry falls back to the test split, or None if neither exists.
438
+ """
439
+ return data['train'], data.get('val') or data.get('test')
440
+
441
+ def setup_model(self):
442
+ """
443
+ Load, create, or download the model for any task.
444
+ """
445
+ if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
446
+ return
447
+
448
+ model, weights = self.model, None
449
+ ckpt = None
450
+ if str(model).endswith('.pt'):
451
+ weights, ckpt = attempt_load_one_weight(model)
452
+ cfg = ckpt['model'].yaml
453
+ else:
454
+ cfg = model
455
+ self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) # calls Model(cfg, weights)
456
+ return ckpt
457
+
458
+ def optimizer_step(self):
459
+ """Perform a single step of the training optimizer with gradient clipping and EMA update."""
460
+ self.scaler.unscale_(self.optimizer) # unscale gradients
461
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0) # clip gradients
462
+ self.scaler.step(self.optimizer)
463
+ self.scaler.update()
464
+ self.optimizer.zero_grad()
465
+ if self.ema:
466
+ self.ema.update(self.model)
467
+
468
+ def preprocess_batch(self, batch):
469
+ """
470
+ Allows custom preprocessing of model inputs and ground truths depending on the task type.
471
+ """
472
+ return batch
473
+
474
+ def validate(self):
475
+ """
476
+ Runs validation on the test set using self.validator. The returned metrics dict is expected to contain the "fitness" key.
477
+ """
478
+ metrics = self.validator(self)
479
+ fitness = metrics.pop('fitness', -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
480
+ if not self.best_fitness or self.best_fitness < fitness:
481
+ self.best_fitness = fitness
482
+ return metrics, fitness
483
+
484
+ def get_model(self, cfg=None, weights=None, verbose=True):
485
+ """Get model and raise NotImplementedError for loading cfg files."""
486
+ raise NotImplementedError("This task trainer doesn't support loading cfg files")
487
+
488
+ def get_validator(self):
489
+ """Returns a NotImplementedError when the get_validator function is called."""
490
+ raise NotImplementedError('get_validator function not implemented in trainer')
491
+
492
+ def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
493
+ """
494
+ Returns dataloader derived from torch.data.Dataloader.
495
+ """
496
+ raise NotImplementedError('get_dataloader function not implemented in trainer')
497
+
498
+ def build_dataset(self, img_path, mode='train', batch=None):
499
+ """Build dataset"""
500
+ raise NotImplementedError('build_dataset function not implemented in trainer')
501
+
502
+ def label_loss_items(self, loss_items=None, prefix='train'):
503
+ """
504
+ Returns a loss dict with labelled training loss items tensor
505
+ """
506
+ # Not needed for classification but necessary for segmentation & detection
507
+ return {'loss': loss_items} if loss_items is not None else ['loss']
508
+
509
+ def set_model_attributes(self):
510
+ """
511
+ To set or update model parameters before training.
512
+ """
513
+ self.model.names = self.data['names']
514
+
515
+ def build_targets(self, preds, targets):
516
+ """Builds target tensors for training YOLO model."""
517
+ pass
518
+
519
+ def progress_string(self):
520
+ """Returns a string describing training progress."""
521
+ return ''
522
+
523
+ # TODO: may need to put these following functions into callback
524
+ def plot_training_samples(self, batch, ni):
525
+ """Plots training samples during YOLO training."""
526
+ pass
527
+
528
+ def plot_training_labels(self):
529
+ """Plots training labels for YOLO model."""
530
+ pass
531
+
532
+ def save_metrics(self, metrics):
533
+ """Saves training metrics to a CSV file."""
534
+ keys, vals = list(metrics.keys()), list(metrics.values())
535
+ n = len(metrics) + 1 # number of cols
536
+ s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # header
537
+ with open(self.csv, 'a') as f:
538
+ f.write(s + ('%23.5g,' * n % tuple([self.epoch] + vals)).rstrip(',') + '\n')
539
+
540
+ def plot_metrics(self):
541
+ """Plot and display metrics visually."""
542
+ pass
543
+
544
+ def on_plot(self, name, data=None):
545
+ """Registers plots (e.g. to be consumed in callbacks)"""
546
+ self.plots[name] = {'data': data, 'timestamp': time.time()}
547
+
548
+ def final_eval(self):
549
+ """Performs final evaluation and validation for object detection YOLO model."""
550
+ for f in self.last, self.best:
551
+ if f.exists():
552
+ strip_optimizer(f) # strip optimizers
553
+ if f is self.best:
554
+ LOGGER.info(f'\nValidating {f}...')
555
+ self.metrics = self.validator(model=f)
556
+ self.metrics.pop('fitness', None)
557
+ self.run_callbacks('on_fit_epoch_end')
558
+
559
+ def check_resume(self):
560
+ """Check if resume checkpoint exists and update arguments accordingly."""
561
+ resume = self.args.resume
562
+ if resume:
563
+ try:
564
+ exists = isinstance(resume, (str, Path)) and Path(resume).exists()
565
+ last = Path(check_file(resume) if exists else get_latest_run())
566
+
567
+ # Check that resume data YAML exists, otherwise strip to force re-download of dataset
568
+ ckpt_args = attempt_load_weights(last).args
569
+ if not Path(ckpt_args['data']).exists():
570
+ ckpt_args['data'] = self.args.data
571
+
572
+ self.args = get_cfg(ckpt_args)
573
+ self.args.model, resume = str(last), True # reinstate
574
+ except Exception as e:
575
+ raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, '
576
+ "i.e. 'yolo train resume model=path/to/last.pt'") from e
577
+ self.resume = resume
578
+
579
+ def resume_training(self, ckpt):
580
+ """Resume YOLO training from given epoch and best fitness."""
581
+ if ckpt is None:
582
+ return
583
+ best_fitness = 0.0
584
+ start_epoch = ckpt['epoch'] + 1
585
+ if ckpt['optimizer'] is not None:
586
+ self.optimizer.load_state_dict(ckpt['optimizer']) # optimizer
587
+ best_fitness = ckpt['best_fitness']
588
+ if self.ema and ckpt.get('ema'):
589
+ self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) # EMA
590
+ self.ema.updates = ckpt['updates']
591
+ if self.resume:
592
+ assert start_epoch > 0, \
593
+ f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
594
+ f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
595
+ LOGGER.info(
596
+ f'Resuming training from {self.args.model} at epoch {start_epoch + 1}, running to {self.epochs} total epochs')
597
+ if self.epochs < start_epoch:
598
+ LOGGER.info(
599
+ f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs.")
600
+ self.epochs += ckpt['epoch'] # finetune additional epochs
601
+ self.best_fitness = best_fitness
602
+ self.start_epoch = start_epoch
603
+ if start_epoch > (self.epochs - self.args.close_mosaic):
604
+ LOGGER.info('Closing dataloader mosaic')
605
+ if hasattr(self.train_loader.dataset, 'mosaic'):
606
+ self.train_loader.dataset.mosaic = False
607
+ if hasattr(self.train_loader.dataset, 'close_mosaic'):
608
+ self.train_loader.dataset.close_mosaic(hyp=self.args)
609
+
610
+ def build_optimizer(self, model, name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
611
+ """
612
+ Constructs an optimizer for the given model, based on the specified optimizer name, learning rate,
613
+ momentum, weight decay, and number of iterations.
614
+
615
+ Args:
616
+ model (torch.nn.Module): The model for which to build an optimizer.
617
+ name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
618
+ based on the number of iterations. Default: 'auto'.
619
+ lr (float, optional): The learning rate for the optimizer. Default: 0.001.
620
+ momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
621
+ decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
622
+ iterations (float, optional): The number of iterations, which determines the optimizer if
623
+ name is 'auto'. Default: 1e5.
624
+
625
+ Returns:
626
+ (torch.optim.Optimizer): The constructed optimizer.
627
+ """
628
+
629
+ g = [], [], [] # optimizer parameter groups
630
+ bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
631
+ if name == 'auto':
632
+ nc = getattr(model, 'nc', 10) # number of classes
633
+ lr_fit = round(0.002 * 5 / (4 + nc), 6) # lr0 fit equation to 6 decimal places
634
+ name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 10000 else ('AdamW', lr_fit, 0.9)
635
+ self.args.warmup_bias_lr = 0.0 # no higher than 0.01 for Adam
636
+
637
+ for module_name, module in model.named_modules():
638
+ for param_name, param in module.named_parameters(recurse=False):
639
+ fullname = f'{module_name}.{param_name}' if module_name else param_name
640
+ if 'bias' in fullname: # bias (no decay)
641
+ g[2].append(param)
642
+ elif isinstance(module, bn): # weight (no decay)
643
+ g[1].append(param)
644
+ else: # weight (with decay)
645
+ g[0].append(param)
646
+
647
+ if name in ('Adam', 'Adamax', 'AdamW', 'NAdam', 'RAdam'):
648
+ optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
649
+ elif name == 'RMSProp':
650
+ optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
651
+ elif name == 'SGD':
652
+ optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
653
+ else:
654
+ raise NotImplementedError(
655
+ f"Optimizer '{name}' not found in list of available optimizers "
656
+ f'[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto].'
657
+ ' To request support for additional optimizers please visit https://github.com/ultralytics/ultralytics.')
658
+
659
+ optimizer.add_param_group({'params': g[0], 'weight_decay': decay}) # add g0 with weight_decay
660
+ optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights)
661
+ LOGGER.info(
662
+ f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
663
+ f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)')
664
+ return optimizer
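
BaseTrainer above deliberately leaves get_model, get_dataloader and get_validator raising NotImplementedError, so each task supplies its own trainer subclass. The sketch below only illustrates that contract; the ToyTrainer name and the placeholder bodies are hypothetical and not part of this repository.

```python
# Minimal sketch of a task-specific trainer plugging into BaseTrainer (hypothetical names).
from ultralytics.engine.trainer import BaseTrainer


class ToyTrainer(BaseTrainer):
    """Illustrative subclass showing the hooks BaseTrainer expects to be overridden."""

    def get_model(self, cfg=None, weights=None, verbose=True):
        # Build and return the task-specific nn.Module from `cfg`, loading `weights` if given.
        ...

    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
        # Return a torch DataLoader over `dataset_path` for the 'train' or 'val' mode.
        ...

    def get_validator(self):
        # Return a BaseValidator subclass instance; _setup_train() uses it for per-epoch validation.
        ...
```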
ultralytics/engine/validator.py ADDED
@@ -0,0 +1,279 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ """
3
+ Check a model's accuracy on a test or val split of a dataset
4
+
5
+ Usage:
6
+ $ yolo mode=val model=yolov8n.pt data=coco128.yaml imgsz=640
7
+
8
+ Usage - formats:
9
+ $ yolo mode=val model=yolov8n.pt # PyTorch
10
+ yolov8n.torchscript # TorchScript
11
+ yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
12
+ yolov8n_openvino_model # OpenVINO
13
+ yolov8n.engine # TensorRT
14
+ yolov8n.mlmodel # CoreML (macOS-only)
15
+ yolov8n_saved_model # TensorFlow SavedModel
16
+ yolov8n.pb # TensorFlow GraphDef
17
+ yolov8n.tflite # TensorFlow Lite
18
+ yolov8n_edgetpu.tflite # TensorFlow Edge TPU
19
+ yolov8n_paddle_model # PaddlePaddle
20
+ """
21
+ import json
22
+ import time
23
+ from pathlib import Path
24
+
25
+ import torch
26
+ from tqdm import tqdm
27
+
28
+ from ultralytics.cfg import get_cfg
29
+ from ultralytics.data.utils import check_cls_dataset, check_det_dataset
30
+ from ultralytics.nn.autobackend import AutoBackend
31
+ from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
32
+ from ultralytics.utils.checks import check_imgsz
33
+ from ultralytics.utils.files import increment_path
34
+ from ultralytics.utils.ops import Profile
35
+ from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode
36
+
37
+
38
+ class BaseValidator:
39
+ """
40
+ BaseValidator
41
+
42
+ A base class for creating validators.
43
+
44
+ Attributes:
45
+ dataloader (DataLoader): Dataloader to use for validation.
46
+ pbar (tqdm): Progress bar to update during validation.
47
+ args (SimpleNamespace): Configuration for the validator.
48
+ model (nn.Module): Model to validate.
49
+ data (dict): Data dictionary.
50
+ device (torch.device): Device to use for validation.
51
+ batch_i (int): Current batch index.
52
+ training (bool): Whether the model is in training mode.
53
+ speed (dict): Per-image preprocess, inference, loss and postprocess times in milliseconds.
54
+ jdict (dict): Dictionary to store validation results.
55
+ save_dir (Path): Directory to save results.
56
+ """
57
+
58
+ def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
59
+ """
60
+ Initializes a BaseValidator instance.
61
+
62
+ Args:
63
+ dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
64
+ save_dir (Path): Directory to save results.
65
+ pbar (tqdm.tqdm): Progress bar for displaying progress.
66
+ args (SimpleNamespace): Configuration for the validator.
67
+ """
68
+ self.dataloader = dataloader
69
+ self.pbar = pbar
70
+ self.args = args or get_cfg(DEFAULT_CFG)
71
+ self.model = None
72
+ self.data = None
73
+ self.device = None
74
+ self.batch_i = None
75
+ self.training = True
76
+ self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
77
+ self.jdict = None
78
+
79
+ project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
80
+ name = self.args.name or f'{self.args.mode}'
81
+ self.save_dir = save_dir or increment_path(Path(project) / name,
82
+ exist_ok=self.args.exist_ok if RANK in (-1, 0) else True)
83
+ (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
84
+
85
+ if self.args.conf is None:
86
+ self.args.conf = 0.001 # default conf=0.001
87
+
88
+ self.plots = {}
89
+ self.callbacks = _callbacks or callbacks.get_default_callbacks()
90
+
91
+ @smart_inference_mode()
92
+ def __call__(self, trainer=None, model=None):
93
+ """
94
+ Supports validation of a pre-trained model if passed or a model being trained
95
+ if trainer is passed (trainer gets priority).
96
+ """
97
+ self.training = trainer is not None
98
+ augment = self.args.augment and (not self.training)
99
+ if self.training:
100
+ self.device = trainer.device
101
+ self.data = trainer.data
102
+ model = trainer.ema.ema or trainer.model
103
+ self.args.half = self.device.type != 'cpu' # force FP16 val during training
104
+ model = model.half() if self.args.half else model.float()
105
+ self.model = model
106
+ self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
107
+ self.args.plots = trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
108
+ model.eval()
109
+ else:
110
+ callbacks.add_integration_callbacks(self)
111
+ self.run_callbacks('on_val_start')
112
+ assert model is not None, 'Either trainer or model is needed for validation'
113
+ model = AutoBackend(model,
114
+ device=select_device(self.args.device, self.args.batch),
115
+ dnn=self.args.dnn,
116
+ data=self.args.data,
117
+ fp16=self.args.half)
118
+ self.model = model
119
+ self.device = model.device # update device
120
+ self.args.half = model.fp16 # update half
121
+ stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
122
+ imgsz = check_imgsz(self.args.imgsz, stride=stride)
123
+ if engine:
124
+ self.args.batch = model.batch_size
125
+ elif not pt and not jit:
126
+ self.args.batch = 1 # export.py models default to batch-size 1
127
+ LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
128
+
129
+ if isinstance(self.args.data, str) and self.args.data.split('.')[-1] in ('yaml', 'yml'):
130
+ self.data = check_det_dataset(self.args.data)
131
+ elif self.args.task == 'classify':
132
+ self.data = check_cls_dataset(self.args.data, split=self.args.split)
133
+ else:
134
+ raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
135
+
136
+ if self.device.type == 'cpu':
137
+ self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
138
+ if not pt:
139
+ self.args.rect = False
140
+ self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)
141
+
142
+ model.eval()
143
+ model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz)) # warmup
144
+
145
+ dt = Profile(), Profile(), Profile(), Profile()
146
+ n_batches = len(self.dataloader)
147
+ desc = self.get_desc()
148
+ # NOTE: keeping `not self.training` in tqdm will eliminate pbar after segmentation evaluation during training,
149
+ # which may affect classification task since this arg is in yolov5/classify/val.py.
150
+ # bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT)
151
+ bar = tqdm(self.dataloader, desc, n_batches, bar_format=TQDM_BAR_FORMAT)
152
+ self.init_metrics(de_parallel(model))
153
+ self.jdict = [] # empty before each val
154
+ for batch_i, batch in enumerate(bar):
155
+ self.run_callbacks('on_val_batch_start')
156
+ self.batch_i = batch_i
157
+ # Preprocess
158
+ with dt[0]:
159
+ batch = self.preprocess(batch)
160
+
161
+ # Inference
162
+ with dt[1]:
163
+ preds = model(batch['img'], augment=augment)
164
+
165
+ # Loss
166
+ with dt[2]:
167
+ if self.training:
168
+ self.loss += model.loss(batch, preds)[1]
169
+
170
+ # Postprocess
171
+ with dt[3]:
172
+ preds = self.postprocess(preds)
173
+
174
+ self.update_metrics(preds, batch)
175
+ if self.args.plots and batch_i < 3:
176
+ self.plot_val_samples(batch, batch_i)
177
+ self.plot_predictions(batch, preds, batch_i)
178
+
179
+ self.run_callbacks('on_val_batch_end')
180
+ stats = self.get_stats()
181
+ self.check_stats(stats)
182
+ self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1E3 for x in dt)))
183
+ self.finalize_metrics()
184
+ self.print_results()
185
+ self.run_callbacks('on_val_end')
186
+ if self.training:
187
+ model.float()
188
+ results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')}
189
+ return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats
190
+ else:
191
+ LOGGER.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
192
+ tuple(self.speed.values()))
193
+ if self.args.save_json and self.jdict:
194
+ with open(str(self.save_dir / 'predictions.json'), 'w') as f:
195
+ LOGGER.info(f'Saving {f.name}...')
196
+ json.dump(self.jdict, f) # flatten and save
197
+ stats = self.eval_json(stats) # update stats
198
+ if self.args.plots or self.args.save_json:
199
+ LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
200
+ return stats
201
+
202
+ def add_callback(self, event: str, callback):
203
+ """Appends the given callback."""
204
+ self.callbacks[event].append(callback)
205
+
206
+ def run_callbacks(self, event: str):
207
+ """Runs all callbacks associated with a specified event."""
208
+ for callback in self.callbacks.get(event, []):
209
+ callback(self)
210
+
211
+ def get_dataloader(self, dataset_path, batch_size):
212
+ """Get data loader from dataset path and batch size."""
213
+ raise NotImplementedError('get_dataloader function not implemented for this validator')
214
+
215
+ def build_dataset(self, img_path):
216
+ """Build dataset"""
217
+ raise NotImplementedError('build_dataset function not implemented in validator')
218
+
219
+ def preprocess(self, batch):
220
+ """Preprocesses an input batch."""
221
+ return batch
222
+
223
+ def postprocess(self, preds):
224
+ """Postprocesses the predictions; the base implementation returns them unchanged."""
225
+ return preds
226
+
227
+ def init_metrics(self, model):
228
+ """Initialize performance metrics for the YOLO model."""
229
+ pass
230
+
231
+ def update_metrics(self, preds, batch):
232
+ """Updates metrics based on predictions and batch."""
233
+ pass
234
+
235
+ def finalize_metrics(self, *args, **kwargs):
236
+ """Finalizes and returns all metrics."""
237
+ pass
238
+
239
+ def get_stats(self):
240
+ """Returns statistics about the model's performance."""
241
+ return {}
242
+
243
+ def check_stats(self, stats):
244
+ """Checks statistics."""
245
+ pass
246
+
247
+ def print_results(self):
248
+ """Prints the results of the model's predictions."""
249
+ pass
250
+
251
+ def get_desc(self):
252
+ """Returns a description string used as the validation progress-bar header."""
253
+ pass
254
+
255
+ @property
256
+ def metric_keys(self):
257
+ """Returns the metric keys used in YOLO training/validation."""
258
+ return []
259
+
260
+ def on_plot(self, name, data=None):
261
+ """Registers plots (e.g. to be consumed in callbacks)"""
262
+ self.plots[name] = {'data': data, 'timestamp': time.time()}
263
+
264
+ # TODO: may need to put these following functions into callback
265
+ def plot_val_samples(self, batch, ni):
266
+ """Plots validation samples during training."""
267
+ pass
268
+
269
+ def plot_predictions(self, batch, preds, ni):
270
+ """Plots YOLO model predictions on batch images."""
271
+ pass
272
+
273
+ def pred_to_json(self, preds, batch):
274
+ """Convert predictions to JSON format."""
275
+ pass
276
+
277
+ def eval_json(self, stats):
278
+ """Evaluate and return JSON format of prediction statistics."""
279
+ pass
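
BaseValidator runs either inside a trainer (FP16 on the EMA model) or standalone through AutoBackend, as the __call__ method above shows. The module docstring gives the CLI form; a rough Python counterpart via the package's YOLO wrapper is sketched below, assuming the standard Ultralytics API and that the weights and dataset YAML are available or downloadable.

```python
# Hedged Python counterpart of `yolo mode=val model=yolov8n.pt data=coco128.yaml imgsz=640`.
from ultralytics import YOLO

model = YOLO('yolov8n.pt')                            # loads weights through the engine Model wrapper
metrics = model.val(data='coco128.yaml', imgsz=640)   # dispatches to a task-specific BaseValidator
print(metrics.results_dict if hasattr(metrics, 'results_dict') else metrics)
```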
ultralytics/hub/__init__.py ADDED
@@ -0,0 +1,121 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import requests
4
+
5
+ from ultralytics.data.utils import HUBDatasetStats
6
+ from ultralytics.hub.auth import Auth
7
+ from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX
8
+ from ultralytics.utils import LOGGER, SETTINGS, USER_CONFIG_DIR, yaml_save
9
+
10
+
11
+ def login(api_key=''):
12
+ """
13
+ Log in to the Ultralytics HUB API using the provided API key.
14
+
15
+ Args:
16
+ api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
17
+
18
+ Example:
19
+ ```python
20
+ from ultralytics import hub
21
+ hub.login('API_KEY')
22
+ ```
23
+ """
24
+ Auth(api_key, verbose=True)
25
+
26
+
27
+ def logout():
28
+ """
29
+ Log out of Ultralytics HUB by removing the API key from the settings file. To log in again, use 'yolo hub login'.
30
+
31
+ Example:
32
+ ```python
33
+ from ultralytics import hub
34
+ hub.logout()
35
+ ```
36
+ """
37
+ SETTINGS['api_key'] = ''
38
+ yaml_save(USER_CONFIG_DIR / 'settings.yaml', SETTINGS)
39
+ LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo hub login'.")
40
+
41
+
42
+ def start(key=''):
43
+ """
44
+ Start training models with Ultralytics HUB (DEPRECATED).
45
+
46
+ Args:
47
+ key (str, optional): A string containing either the API key and model ID combination (apikey_modelid),
48
+ or the full model URL (https://hub.ultralytics.com/models/apikey_modelid).
49
+ """
50
+ api_key, model_id = key.split('_')
51
+ LOGGER.warning(f"""
52
+ WARNING ⚠️ ultralytics.start() is deprecated after 8.0.60. Updated usage to train Ultralytics HUB models is:
53
+
54
+ from ultralytics import YOLO, hub
55
+
56
+ hub.login('{api_key}')
57
+ model = YOLO('{HUB_WEB_ROOT}/models/{model_id}')
58
+ model.train()""")
59
+
60
+
61
+ def reset_model(model_id=''):
62
+ """Reset a trained model to an untrained state."""
63
+ r = requests.post(f'{HUB_API_ROOT}/model-reset', json={'apiKey': Auth().api_key, 'modelId': model_id})
64
+ if r.status_code == 200:
65
+ LOGGER.info(f'{PREFIX}Model reset successfully')
66
+ return
67
+ LOGGER.warning(f'{PREFIX}Model reset failure {r.status_code} {r.reason}')
68
+
69
+
70
+ def export_fmts_hub():
71
+ """Returns a list of HUB-supported export formats."""
72
+ from ultralytics.engine.exporter import export_formats
73
+ return list(export_formats()['Argument'][1:]) + ['ultralytics_tflite', 'ultralytics_coreml']
74
+
75
+
76
+ def export_model(model_id='', format='torchscript'):
77
+ """Export a model to all formats."""
78
+ assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
79
+ r = requests.post(f'{HUB_API_ROOT}/v1/models/{model_id}/export',
80
+ json={'format': format},
81
+ headers={'x-api-key': Auth().api_key})
82
+ assert r.status_code == 200, f'{PREFIX}{format} export failure {r.status_code} {r.reason}'
83
+ LOGGER.info(f'{PREFIX}{format} export started ✅')
84
+
85
+
86
+ def get_export(model_id='', format='torchscript'):
87
+ """Get an exported model dictionary with download URL."""
88
+ assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
89
+ r = requests.post(f'{HUB_API_ROOT}/get-export',
90
+ json={
91
+ 'apiKey': Auth().api_key,
92
+ 'modelId': model_id,
93
+ 'format': format})
94
+ assert r.status_code == 200, f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}'
95
+ return r.json()
96
+
97
+
98
+ def check_dataset(path='', task='detect'):
99
+ """
100
+ Function for error-checking HUB dataset Zip file before upload. It checks a dataset for errors before it is
101
+ uploaded to the HUB. Usage examples are given below.
102
+
103
+ Args:
104
+ path (str, optional): Path to data.zip (with data.yaml inside data.zip). Defaults to ''.
105
+ task (str, optional): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Defaults to 'detect'.
106
+
107
+ Example:
108
+ ```python
109
+ from ultralytics.hub import check_dataset
110
+
111
+ check_dataset('path/to/coco8.zip', task='detect') # detect dataset
112
+ check_dataset('path/to/coco8-seg.zip', task='segment') # segment dataset
113
+ check_dataset('path/to/coco8-pose.zip', task='pose') # pose dataset
114
+ ```
115
+ """
116
+ HUBDatasetStats(path=path, task=task).get_json()
117
+ LOGGER.info(f'Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.')
118
+
119
+
120
+ if __name__ == '__main__':
121
+ start()
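
The helpers defined above (login, check_dataset, export_model, get_export) cover a typical HUB workflow. A short illustrative sequence is sketched below; the API key, dataset path and model id are placeholders to substitute with real values.

```python
# Illustrative HUB workflow using only the functions defined in this module.
from ultralytics import hub

hub.login('YOUR_API_KEY')                                # store and verify the HUB API key
hub.check_dataset('path/to/coco8.zip', task='detect')    # error-check a dataset zip before upload
hub.export_model(model_id='MODEL_ID', format='onnx')     # request an export of a trained HUB model
```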
ultralytics/hub/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (4.91 kB)
 
ultralytics/hub/__pycache__/auth.cpython-39.pyc ADDED
Binary file (4.2 kB)