# croppie_coffee_ug / scripts/custom_YOLO.py
# Last commit: 09deda7 "improving documentation" (rgautroncgiar)
from pathlib import Path
import numpy as np
# Seed NumPy's global RNG (np.random.*) at import time for reproducibility.
# Other RNGs (Python's `random`, torch) are not seeded anywhere in this file.
np.random.seed(123)
import ultralytics
# Runs Ultralytics' environment checks as an import-time side effect, i.e. on
# every import of this module.
ultralytics.checks()
from ultralytics import YOLO
# imports for the YOLO custom class
# NOTE(review): the `ultralytics.yolo.*` module paths below appear to match the
# ultralytics 8.0.x package layout (later releases moved these modules) —
# confirm the installed/pinned ultralytics version.
from typing import Union
from ultralytics import yolo # noqa
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, PoseModel, SegmentationModel,
                                  attempt_load_one_weight)
from ultralytics.yolo.utils import (LOGGER, RANK, yaml_load)
from ultralytics.yolo.utils.checks import check_pip_update_available, check_yaml
# Task name -> the four classes Ultralytics uses for that task, always in the
# order [model class, trainer, validator, predictor].
# YOLO_custom.train looks up index 1 (the trainer) for the selected task.
TASK_MAP = {
    'classify': [
        ClassificationModel,
        yolo.v8.classify.ClassificationTrainer,
        yolo.v8.classify.ClassificationValidator,
        yolo.v8.classify.ClassificationPredictor,
    ],
    'detect': [
        DetectionModel,
        yolo.v8.detect.DetectionTrainer,
        yolo.v8.detect.DetectionValidator,
        yolo.v8.detect.DetectionPredictor,
    ],
    'segment': [
        SegmentationModel,
        yolo.v8.segment.SegmentationTrainer,
        yolo.v8.segment.SegmentationValidator,
        yolo.v8.segment.SegmentationPredictor,
    ],
    'pose': [
        PoseModel,
        yolo.v8.pose.PoseTrainer,
        yolo.v8.pose.PoseValidator,
        yolo.v8.pose.PosePredictor,
    ],
}
# /imports for the YOLO custom class
class YOLO_custom(YOLO):
    """Ultralytics ``YOLO`` model whose ``train`` also accepts a ``hyp`` dict.

    ``__init__`` simply forwards to the parent class; ``train`` is a copy of the
    parent's training entry point extended with the ``hyp`` argument. All other
    behaviour is inherited from ``ultralytics.YOLO``.
    """

    def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None) -> None:
        """Create the model by delegating to ``ultralytics.YOLO``.

        Args:
            model: Model weights or config, forwarded unchanged to the parent.
            task: Optional task name, forwarded unchanged to the parent.
        """
        super().__init__(model, task)

    def train(self, hyp: Union[dict, None] = None, **kwargs):
        """
        CAUTION: OVERWRITES THE ORIGINAL METHOD TO ACCEPT HYPERPARAMETERS
        Trains the model on a given dataset.

        Precedence of training settings (lowest to highest):
        ``self.overrides`` < ``kwargs`` (or the ``cfg`` file, which replaces
        both) < non-``None`` entries of ``hyp``.

        Args:
            hyp: Optional dict of training hyperparameters. Entries whose value
                is ``None`` are skipped; the rest override every other source,
                including a ``cfg`` file and HUB session arguments. The caller's
                dict is never modified. A truthy non-dict value is ignored with
                a warning.
            **kwargs (Any): Any number of arguments representing the training
                configuration. Replaced by the HUB session's arguments when a
                HUB session is attached.

        Raises:
            AttributeError: If the final overrides contain no ``data`` entry.
        """
        self._check_is_pytorch_model()
        if self.session:  # Ultralytics HUB session
            if kwargs:
                LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
            kwargs = self.session.train_args
            self.session.check_disk_space()
        check_pip_update_available()
        overrides = self.overrides.copy()
        overrides.update(kwargs)
        if kwargs.get('cfg'):
            LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
            # A cfg file replaces *all* previously collected overrides (including kwargs).
            overrides = yaml_load(check_yaml(kwargs['cfg']))
        #********************** update hyp start **********************
        if hyp:
            if isinstance(hyp, dict):
                LOGGER.info("'hyp' dict passed -> overriding the hyperparameters found in 'hyp'.")
                # Filter into a new dict rather than deleting keys from the caller's
                # dict, so a `hyp` reused across runs keeps all of its entries.
                overrides.update({k: v for k, v in hyp.items() if v is not None})
            else:
                LOGGER.warning("WARNING the 'hyp' variable MUST be a dict")
        #********************** update hyp end **********************
        overrides['mode'] = 'train'
        if not overrides.get('data'):
            raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
        if overrides.get('resume'):
            overrides['resume'] = self.ckpt_path
        self.task = overrides.get('task') or self.task
        # TASK_MAP[task][1] is the trainer class for that task.
        self.trainer = TASK_MAP[self.task][1](overrides=overrides, _callbacks=self.callbacks)
        if not overrides.get('resume'):  # manually set model only if not resuming
            self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
            self.model = self.trainer.model
        self.trainer.hub_session = self.session  # attach optional HUB session
        self.trainer.train()
        # update model and cfg after training
        if RANK in (-1, 0):
            self.model, _ = attempt_load_one_weight(str(self.trainer.best))
            self.overrides = self.model.args
            self.metrics = getattr(self.trainer.validator, 'metrics', None)  # TODO: no metrics returned by DDP
if __name__ == "__main__":
    # Import-only module: executing it directly does nothing beyond the
    # module-level setup performed at import time.
    ...