|
|
|
|
|
import contextlib |
|
import inspect |
|
import logging.config |
|
import os |
|
import platform |
|
import re |
|
import subprocess |
|
import sys |
|
import threading |
|
import urllib |
|
import uuid |
|
from pathlib import Path |
|
from types import SimpleNamespace |
|
from typing import Union |
|
|
|
import cv2 |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import torch |
|
import yaml |
|
|
|
from ultralytics import __version__ |
|
|
|
|
|
# Distributed-process indices read from the environment (-1 when unset); RANK is passed to
# torch_distributed_zero_first further down in this file
RANK = int(os.getenv('RANK', -1))
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))

# Paths and global flags
FILE = Path(__file__).resolve()
ROOT = FILE.parents[2]  # grandparent of this file's directory
DEFAULT_CFG_PATH = ROOT / 'yolo/cfg/default.yaml'
# os.cpu_count() returns None when the CPU count cannot be determined; treat that as a single CPU so the
# expression cannot raise TypeError at import. Result is clamped to the range [1, 8].
NUM_THREADS = min(8, max(1, (os.cpu_count() or 1) - 1))
# Boolean switches parsed from env vars: only the case-insensitive string 'true' enables them (default True)
AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true'
VERBOSE = str(os.getenv('YOLO_VERBOSE', True)).lower() == 'true'
TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}'
LOGGING_NAME = 'ultralytics'
# At most one of these is True, depending on platform.system()
MACOS, LINUX, WINDOWS = (platform.system() == x for x in ['Darwin', 'Linux', 'Windows'])
|
# Multi-line usage text with Python SDK and 'yolo' CLI examples; the string content is emitted verbatim.
# NOTE(review): not referenced in this part of the file; presumably printed by the CLI 'yolo help' command — confirm.
HELP_MSG = \
    """
    Usage examples for running YOLOv8:

    1. Install the ultralytics package:

        pip install ultralytics

    2. Use the Python SDK:

        from ultralytics import YOLO

        # Load a model
        model = YOLO('yolov8n.yaml') # build a new model from scratch
        model = YOLO("yolov8n.pt") # load a pretrained model (recommended for training)

        # Use the model
        results = model.train(data="coco128.yaml", epochs=3) # train the model
        results = model.val() # evaluate model performance on the validation set
        results = model('https://ultralytics.com/images/bus.jpg') # predict on an image
        success = model.export(format='onnx') # export the model to ONNX format

    3. Use the command line interface (CLI):

        YOLOv8 'yolo' CLI commands use the following syntax:

            yolo TASK MODE ARGS

            Where TASK (optional) is one of [detect, segment, classify]
                    MODE (required) is one of [train, val, predict, export]
                    ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
                        See all ARGS at https://docs.ultralytics.com/usage/cfg or with 'yolo cfg'

        - Train a detection model for 10 epochs with an initial learning_rate of 0.01
            yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01

        - Predict a YouTube video using a pretrained segmentation model at image size 320:
            yolo segment predict model=yolov8n-seg.pt source='https://youtu.be/Zgi9g1ksQHc' imgsz=320

        - Val a pretrained detection model at batch-size 1 and image size 640:
            yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640

        - Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
            yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128

        - Run special commands:
            yolo help
            yolo checks
            yolo version
            yolo settings
            yolo copy-cfg
            yolo cfg

    Docs: https://docs.ultralytics.com
    Community: https://community.ultralytics.com
    GitHub: https://github.com/ultralytics/ultralytics
    """
|
|
|
|
|
# Process-wide library settings applied once at import time
torch.set_printoptions(linewidth=320, precision=4, profile='default')  # wide, 4-decimal tensor printing
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # floats formatted as '{:11.5g}'
cv2.setNumThreads(0)  # NOTE(review): 0 is documented by OpenCV as disabling its internal threading; confirm intent
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS)  # cap NumExpr at NUM_THREADS
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # NOTE(review): presumably for deterministic cuBLAS (PyTorch reproducibility)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # NOTE(review): presumably silences TensorFlow C++ INFO/WARNING logs
|
|
|
|
|
class SimpleClass:
    """
    Ultralytics SimpleClass is a base class providing helpful string representation, error reporting, and attribute
    access methods for easier debugging and usage.
    """

    def __str__(self):
        """
        Describe the object as a header line followed by one 'name: value' line per public attribute.

        Callable attributes and names starting with '_' are skipped. Attributes that are themselves SimpleClass
        instances are summarised by module and class name instead of being formatted recursively.
        """
        lines = []
        for attr_name in dir(self):
            value = getattr(self, attr_name)
            if callable(value) or attr_name.startswith('_'):
                continue
            if isinstance(value, SimpleClass):
                lines.append(f'{attr_name}: {value.__module__}.{value.__class__.__name__} object')
            else:
                lines.append(f'{attr_name}: {value!r}')
        header = f'{self.__module__}.{self.__class__.__name__} object with attributes:\n\n'
        return header + '\n'.join(lines)

    def __repr__(self):
        """Return the same text as __str__."""
        return self.__str__()

    def __getattr__(self, attr):
        """Raise an AttributeError naming the missing attribute, followed by the class docstring."""
        cls_name = self.__class__.__name__
        message = f"'{cls_name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}"
        raise AttributeError(message)
|
|
|
|
|
class IterableSimpleNamespace(SimpleNamespace):
    """
    Ultralytics IterableSimpleNamespace is an extension class of SimpleNamespace that adds iterable functionality and
    enables usage with dict() and for loops.
    """

    def __iter__(self):
        """Iterate over (name, value) pairs of the namespace's attributes."""
        attributes = vars(self)
        return iter(attributes.items())

    def __str__(self):
        """Render one 'name=value' line per attribute."""
        lines = [f'{key}={value}' for key, value in vars(self).items()]
        return '\n'.join(lines)

    def __getattr__(self, attr):
        """Raise an AttributeError pointing to a possibly stale or modified default.yaml."""
        name = self.__class__.__name__
        raise AttributeError(f"""
            '{name}' object has no attribute '{attr}'. This may be caused by a modified or out of date ultralytics
            'default.yaml' file.\nPlease update your code with 'pip install -U ultralytics' and if necessary replace
            {DEFAULT_CFG_PATH} with the latest version from
            https://github.com/ultralytics/ultralytics/blob/main/ultralytics/yolo/cfg/default.yaml
            """)

    def get(self, key, default=None):
        """Return the attribute named `key`, or `default` when accessing it raises AttributeError."""
        try:
            return getattr(self, key)
        except AttributeError:
            return default
|
|
|
|
|
def plt_settings(rcparams=None, backend='Agg'):
    """
    Decorator to temporarily set rc parameters and the backend for a plotting function.

    Usage:
        @plt_settings({"font.size": 12})
        def plot(...): ...

    Args:
        rcparams (dict, optional): rc parameters applied while the function runs. Defaults to {'font.size': 11}.
        backend (str, optional): Name of the backend to use while the function runs. Defaults to 'Agg'.

    Returns:
        callable: A decorator. The wrapped function runs under `backend` and `rcparams`, and the previous backend
            is restored afterwards, even if the function raises.
    """
    import functools

    if rcparams is None:
        rcparams = {'font.size': 11}

    def decorator(func):
        """Decorator to apply temporary rc parameters and backend to a function."""

        @functools.wraps(func)  # keep the wrapped function's name/docstring for introspection and logging
        def wrapper(*args, **kwargs):
            """Sets rc parameters and backend, calls the original function, and restores the settings."""
            original_backend = plt.get_backend()
            plt.switch_backend(backend)
            try:
                with plt.rc_context(rcparams):
                    return func(*args, **kwargs)
            finally:
                # Restore the caller's backend on both normal return and exception; previously an exception
                # inside func left the process stuck on `backend`.
                plt.switch_backend(original_backend)

        return wrapper

    return decorator
|
|
|
|
|
def set_logging(name=LOGGING_NAME, verbose=True):
    """
    Configure the logger `name` with one message-only StreamHandler via logging.config.dictConfig.

    Args:
        name (str): Logger name; also used as the id of its formatter and handler. Defaults to LOGGING_NAME.
        verbose (bool): When True and the RANK env var is -1 or 0 (or unset), log at INFO; otherwise at ERROR.
    """
    is_main_process = int(os.getenv('RANK', -1)) in {-1, 0}
    level = logging.INFO if (verbose and is_main_process) else logging.ERROR

    formatters = {name: {'format': '%(message)s'}}
    handlers = {name: {'class': 'logging.StreamHandler', 'formatter': name, 'level': level}}
    loggers = {name: {'level': level, 'handlers': [name], 'propagate': False}}
    logging.config.dictConfig({
        'version': 1,
        'disable_existing_loggers': False,
        'formatters': formatters,
        'handlers': handlers,
        'loggers': loggers})
|
|
|
|
|
class EmojiFilter(logging.Filter):
    """
    A custom logging filter class for removing emojis in log messages.

    This filter is particularly useful for ensuring compatibility with Windows terminals
    that may not support the display of emojis in log messages.
    """

    def filter(self, record):
        """
        Pass string log messages through emojis() before the base Filter decides whether to emit the record.

        Args:
            record (logging.LogRecord): The record being logged; `record.msg` is modified in place when it is a str.

        Returns:
            The result of logging.Filter.filter(record).
        """
        # record.msg may be any object (e.g. LOGGER.info(some_dict)). emojis() calls str.encode, which would raise
        # for non-str messages, so only strings are rewritten.
        if isinstance(record.msg, str):
            record.msg = emojis(record.msg)
        return super().filter(record)
|
|
|
|
|
|
|
# Configure the package logger once at import time and expose it as the module-wide LOGGER
set_logging(LOGGING_NAME, verbose=VERBOSE)
LOGGER = logging.getLogger(LOGGING_NAME)
if WINDOWS:
    # On Windows, log messages are passed through EmojiFilter (which applies emojis())
    LOGGER.addFilter(EmojiFilter())
|
|
|
|
|
def yaml_save(file='data.yaml', data=None):
    """
    Save YAML data to a file.

    Parent directories are created if needed. Top-level Path values are written as strings. The caller's dict is
    not modified.

    Args:
        file (str, optional): File name. Default is 'data.yaml'.
        data (dict): Data to save in YAML format. None is treated as an empty dict.

    Returns:
        None: Data is saved to the specified file.
    """
    file = Path(file)
    # Create parent directories if they don't exist (no-op when they do)
    file.parent.mkdir(parents=True, exist_ok=True)

    # Convert Path values to str in a shallow copy, so callers (e.g. the global SETTINGS dict) are not mutated
    data = {k: str(v) if isinstance(v, Path) else v for k, v in (data or {}).items()}

    # Write UTF-8 explicitly: allow_unicode=True can emit non-ASCII text, which the platform default encoding
    # (e.g. cp1252 on Windows) may not encode. yaml_load reads files back as UTF-8.
    with open(file, 'w', encoding='utf-8') as f:
        yaml.safe_dump(data, f, sort_keys=False, allow_unicode=True)
|
|
|
|
|
def yaml_load(file='data.yaml', append_filename=False):
    """
    Load YAML data from a file.

    The file is read as UTF-8, and undecodable bytes are ignored. If the text is not entirely printable, any
    characters outside the allowed set in the regex below are removed before parsing with yaml.safe_load.
    str.isprintable() is False for '\\n', so this applies to almost every multi-line file.

    Args:
        file (str, optional): File name. Default is 'data.yaml'.
        append_filename (bool): Add the YAML filename to the YAML dictionary. Default is False.

    Returns:
        The parsed YAML, which is None for an empty document. With append_filename=True, a dict holding the
            parsed mapping plus 'yaml_file': str(file); an empty document yields {'yaml_file': str(file)}.
    """
    with open(file, errors='ignore', encoding='utf-8') as f:
        s = f.read()

    # Remove special characters
    if not s.isprintable():
        s = re.sub(r'[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+', '', s)

    data = yaml.safe_load(s)
    if append_filename:
        # An empty document loads as None; `{**None}` would raise TypeError, so treat it as an empty mapping
        return {**(data or {}), 'yaml_file': str(file)}
    return data
|
|
|
|
|
def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
    """
    Log a YAML file, or an already-loaded dict, as YAML text through LOGGER.info.

    Args:
        yaml_file: Path to a YAML file (str or Path, loaded with yaml_load), or a dict to dump directly.

    Returns:
        None
    """
    if isinstance(yaml_file, (str, Path)):
        yaml_dict = yaml_load(yaml_file)
    else:
        yaml_dict = yaml_file
    dump = yaml.dump(yaml_dict, sort_keys=False, allow_unicode=True)
    title = colorstr('bold', 'black', yaml_file)
    LOGGER.info(f"Printing '{title}'\n\n{dump}")
|
|
|
|
|
|
|
# Load the default configuration and turn literal 'none' strings (any case) into Python None
DEFAULT_CFG_DICT = yaml_load(DEFAULT_CFG_PATH)
for k, v in DEFAULT_CFG_DICT.items():
    if isinstance(v, str) and v.lower() == 'none':
        DEFAULT_CFG_DICT[k] = None  # reassigning an existing key does not change dict size, so iteration is safe
DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()  # live keys view of DEFAULT_CFG_DICT
DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)
|
|
|
|
|
def is_colab():
    """
    Check if the current script is running inside a Google Colab notebook.

    Detection is based solely on the presence of the COLAB_RELEASE_TAG or COLAB_BACKEND_VERSION env variables.

    Returns:
        bool: True if running inside a Colab notebook, False otherwise.
    """
    markers = ('COLAB_RELEASE_TAG', 'COLAB_BACKEND_VERSION')
    return any(marker in os.environ for marker in markers)
|
|
|
|
|
def is_kaggle():
    """
    Check if the current script is running inside a Kaggle kernel.

    Requires PWD == '/kaggle/working' and KAGGLE_URL_BASE == 'https://www.kaggle.com' in the environment.

    Returns:
        bool: True if running inside a Kaggle kernel, False otherwise.
    """
    env = os.environ
    in_working_dir = env.get('PWD') == '/kaggle/working'
    return in_working_dir and env.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
|
|
|
|
|
def is_jupyter():
    """
    Check if the current script is running inside a Jupyter Notebook.
    Verified on Colab, Jupyterlab, Kaggle, Paperspace.

    Returns:
        bool: True when IPython's get_ipython() returns a shell, False otherwise. Any exception, including
            IPython not being installed, yields False.
    """
    try:
        from IPython import get_ipython
        return get_ipython() is not None
    except Exception:
        return False
|
|
|
|
|
def is_docker() -> bool:
    """
    Determine if the script is running inside a Docker container.

    The check looks for the substring 'docker' in /proc/self/cgroup. A missing or unreadable file (e.g. non-Linux
    systems, or a PermissionError in a restricted sandbox) yields False instead of raising. This matters because the
    check runs while the module is being imported.

    Returns:
        bool: True if the script is running inside a Docker container, False otherwise.
    """
    # NOTE(review): under cgroup v2 this file may not mention 'docker' even inside a container
    with contextlib.suppress(OSError):
        with open('/proc/self/cgroup') as f:
            return 'docker' in f.read()
    return False
|
|
|
|
|
def is_online() -> bool:
    """
    Check internet connectivity by opening a TCP connection to port 53 of well-known public hosts.

    Hosts '1.1.1.1', '8.8.8.8' and '223.5.5.5' are tried in order, with a 2 second timeout each. The first
    successful connection is closed immediately and ends the check.

    Returns:
        bool: True if any connection succeeds, False otherwise.
    """
    import socket

    hosts = ('1.1.1.1', '8.8.8.8', '223.5.5.5')
    for host in hosts:
        try:
            conn = socket.create_connection(address=(host, 53), timeout=2)
        except OSError:  # socket.timeout and socket.gaierror are both OSError subclasses
            continue
        conn.close()
        return True
    return False
|
|
|
|
|
ONLINE = is_online()  # evaluated once at import; may attempt several network connections (2 s timeout each)
|
|
|
|
|
def is_pip_package(filepath: str = __name__) -> bool:
    """
    Determines if the given module name resolves to an importable module with a known origin.

    Args:
        filepath (str): Dotted module name passed to importlib.util.find_spec. Defaults to this module's __name__.

    Returns:
        bool: True if find_spec returns a spec whose `origin` is not None, False otherwise.
    """
    import importlib.util

    module_spec = importlib.util.find_spec(filepath)
    if module_spec is None:
        return False
    return module_spec.origin is not None
|
|
|
|
|
def is_dir_writeable(dir_path: Union[str, Path]) -> bool:
    """
    Check if a directory is writeable, using os.access with os.W_OK.

    Args:
        dir_path (str) or (Path): The path to the directory.

    Returns:
        bool: True if the directory is writeable, False otherwise (including when it does not exist).
    """
    path_str = str(dir_path)
    return os.access(path_str, os.W_OK)
|
|
|
|
|
def is_pytest_running():
    """
    Determines whether pytest is currently running or not.

    Any of these counts: PYTEST_CURRENT_TEST in the environment, 'pytest' in sys.modules, or 'pytest' in the stem of
    sys.argv[0].

    Returns:
        (bool): True if pytest is running, False otherwise.
    """
    if 'PYTEST_CURRENT_TEST' in os.environ or 'pytest' in sys.modules:
        return True
    return 'pytest' in Path(sys.argv[0]).stem
|
|
|
|
|
def is_github_actions_ci() -> bool:
    """
    Determine if the current environment is a GitHub Actions CI Python runner.

    All of GITHUB_ACTIONS, RUNNER_OS and RUNNER_TOOL_CACHE must be present in the environment.

    Returns:
        (bool): True if the current environment is a GitHub Actions CI Python runner, False otherwise.
    """
    required_vars = ('GITHUB_ACTIONS', 'RUNNER_OS', 'RUNNER_TOOL_CACHE')
    return all(var in os.environ for var in required_vars)
|
|
|
|
|
def is_git_dir():
    """
    Determines whether the current file is part of a git repository.

    Returns:
        (bool): True if some parent directory of the current file contains a '.git' directory, False otherwise.
    """
    return get_git_dir() is not None
|
|
|
|
|
def get_git_dir():
    """
    Find the git repository root that contains the current file.

    Walks up this file's parent directories and returns the first one that has a '.git' subdirectory.

    Returns:
        (Path) or (None): Git root directory if found or None if not found.
    """
    candidates = (parent for parent in Path(__file__).parents if (parent / '.git').is_dir())
    return next(candidates, None)
|
|
|
|
|
def get_git_origin_url():
    """
    Retrieves the origin URL of the git repository containing this file.

    git is run inside that repository rather than the process's current working directory. A failing git command
    or a missing/unrunnable git executable yields None.

    Returns:
        (str) or (None): The origin URL of the git repository.
    """
    git_dir = get_git_dir()
    if git_dir is not None:
        with contextlib.suppress(subprocess.CalledProcessError, OSError):
            origin = subprocess.check_output(['git', 'config', '--get', 'remote.origin.url'], cwd=git_dir)
            return origin.decode().strip()
    return None
|
|
|
|
|
def get_git_branch():
    """
    Returns the current branch name of the git repository containing this file.

    git is run inside that repository rather than the process's current working directory. Returns None when not in
    a git repository, when the git command fails, or when git cannot be executed.

    Returns:
        (str) or (None): The current git branch name.
    """
    git_dir = get_git_dir()
    if git_dir is not None:
        with contextlib.suppress(subprocess.CalledProcessError, OSError):
            branch = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD'], cwd=git_dir)
            return branch.decode().strip()
    return None
|
|
|
|
|
def get_default_args(func):
    """Returns a dictionary of default arguments for a function.

    Args:
        func (callable): The function to inspect.

    Returns:
        dict: Maps each parameter name that has a default value to that default, in signature order.
    """
    parameters = inspect.signature(func).parameters
    defaults = {}
    for param_name, param in parameters.items():
        if param.default is inspect.Parameter.empty:
            continue
        defaults[param_name] = param.default
    return defaults
|
|
|
|
|
def get_user_config_dir(sub_dir='Ultralytics'):
    """
    Get the user config directory, creating it if it does not exist.

    Args:
        sub_dir (str): The name of the subdirectory to create.

    Returns:
        Path: The path to the user config directory.

    Raises:
        ValueError: If the operating system is not Windows, macOS or Linux.
    """
    import tempfile

    # Return the appropriate config directory for each operating system
    if WINDOWS:
        path = Path.home() / 'AppData' / 'Roaming' / sub_dir
    elif MACOS:
        path = Path.home() / 'Library' / 'Application Support' / sub_dir
    elif LINUX:
        path = Path.home() / '.config' / sub_dir
    else:
        raise ValueError(f'Unsupported operating system: {platform.system()}')

    # Fall back to the platform temp directory when the parent is not writeable. tempfile.gettempdir() is valid on
    # every OS, whereas a hard-coded '/tmp' does not exist on Windows.
    if not is_dir_writeable(str(path.parent)):
        path = Path(tempfile.gettempdir()) / sub_dir

    # Create the subdirectory if it does not exist
    path.mkdir(parents=True, exist_ok=True)

    return path
|
|
|
|
|
# Use YOLO_CONFIG_DIR when it is set. Only then fall back to get_user_config_dir(), rather than evaluating it
# eagerly as a getenv() default, which created (or raised for) the platform directory even when it was overridden.
USER_CONFIG_DIR = Path(os.environ['YOLO_CONFIG_DIR']) if 'YOLO_CONFIG_DIR' in os.environ else get_user_config_dir()
SETTINGS_YAML = USER_CONFIG_DIR / 'settings.yaml'  # global Ultralytics settings file
|
|
|
|
|
def emojis(string=''):
    """
    Return platform-dependent emoji-safe version of string.

    On Windows, every non-ASCII character (emojis included) is dropped. Elsewhere the string is returned unchanged.
    """
    if not WINDOWS:
        return string
    return string.encode().decode('ascii', 'ignore')
|
|
|
|
|
def colorstr(*parts):
    """
    Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world').

    The last positional argument is the text; all preceding ones are style names applied in order. With a single
    argument, the styles default to ('blue', 'bold'). The result always ends with the ANSI reset code.
    """
    palette = {
        'black': '\033[30m',
        'red': '\033[31m',
        'green': '\033[32m',
        'yellow': '\033[33m',
        'blue': '\033[34m',
        'magenta': '\033[35m',
        'cyan': '\033[36m',
        'white': '\033[37m',
        'bright_black': '\033[90m',
        'bright_red': '\033[91m',
        'bright_green': '\033[92m',
        'bright_yellow': '\033[93m',
        'bright_blue': '\033[94m',
        'bright_magenta': '\033[95m',
        'bright_cyan': '\033[96m',
        'bright_white': '\033[97m',
        'end': '\033[0m',
        'bold': '\033[1m',
        'underline': '\033[4m'}
    if len(parts) > 1:
        *styles, text = parts
    else:
        styles, text = ('blue', 'bold'), parts[0]
    prefix = ''.join(palette[style] for style in styles)
    return f"{prefix}{text}{palette['end']}"
|
|
|
|
|
class TryExcept(contextlib.ContextDecorator):
    """YOLOv8 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager."""

    def __init__(self, msg='', verbose=True):
        """
        Args:
            msg (str): Prefix printed before the exception text (followed by ': ' when non-empty).
            verbose (bool): Print the suppressed exception when True.
        """
        self.msg = msg
        self.verbose = verbose

    def __enter__(self):
        """Enter the context; nothing to set up."""
        return None

    def __exit__(self, exc_type, value, traceback):
        """Optionally print the exception (via emojis()), then always suppress it by returning True."""
        if self.verbose and value:
            separator = ': ' if self.msg else ''
            print(emojis(f'{self.msg}{separator}{value}'))
        return True
|
|
|
|
|
def threaded(func):
    """
    Multi-threads a target function and returns thread. Usage: @threaded decorator.

    Each call to the decorated function starts func(*args, **kwargs) in a new daemon thread and immediately
    returns the threading.Thread object (not func's return value).
    """
    import functools

    @functools.wraps(func)  # preserve func's name/docstring on the wrapper
    def wrapper(*args, **kwargs):
        """Multi-threads a given function and returns the thread."""
        thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
        thread.start()
        return thread

    return wrapper
|
|
|
|
|
def set_sentry():
    """
    Initialize the Sentry SDK for error tracking and reporting, but only when every opt-in condition holds.

    Conditions, checked in this order (ALL must hold, otherwise nothing is initialized):
        - sync=True in YOLO settings (run 'yolo settings' to see and update the settings YAML file)
        - running with rank -1 or 0
        - the entry point sys.argv[0] is named 'yolo'
        - tests are not running (TESTS_RUNNING is False)
        - online environment (ONLINE is True)
        - running in a pip package installation
        - not running from a git checkout
        - the sentry_sdk package can be imported

    Events raised from KeyboardInterrupt or FileNotFoundError, or whose exception text contains 'out of memory',
    are dropped. Other events are tagged with argv, install type and ENVIRONMENT, and the Sentry user id is set to
    SETTINGS['uuid']. The 'sentry_sdk' and 'sentry_sdk.errors' loggers are raised to CRITICAL.
    """

    def before_send(event, hint):
        """
        Drop or tag an event before it is sent to Sentry.

        Args:
            event (dict): The event dictionary containing information about the error.
            hint (dict): Additional information about the error; may contain 'exc_info'.

        Returns:
            dict: The tagged event, or None if the event should not be sent to Sentry.
        """
        if 'exc_info' in hint:
            exc_type, exc_value, _ = hint['exc_info']
            if exc_type in (KeyboardInterrupt, FileNotFoundError) or 'out of memory' in str(exc_value):
                return None

        argv0 = sys.argv[0]
        if is_git_dir():
            install = 'git'
        elif is_pip_package():
            install = 'pip'
        else:
            install = 'other'
        event['tags'] = {
            'sys_argv': argv0,
            'sys_argv_name': Path(argv0).name,
            'install': install,
            'os': ENVIRONMENT}
        return event

    enabled = (SETTINGS['sync'] and RANK in (-1, 0) and Path(sys.argv[0]).name == 'yolo' and not TESTS_RUNNING
               and ONLINE and is_pip_package() and not is_git_dir())
    if not enabled:
        return

    try:
        import sentry_sdk
    except ImportError:
        return

    sentry_sdk.init(
        dsn='https://5ff1556b71594bfea135ff0203a0d290@o4504521589325824.ingest.sentry.io/4504521592406016',
        debug=False,
        traces_sample_rate=1.0,
        release=__version__,
        environment='production',
        before_send=before_send,
        ignore_errors=[KeyboardInterrupt, FileNotFoundError])
    sentry_sdk.set_user({'id': SETTINGS['uuid']})

    # Keep Sentry's own logging quiet
    for sentry_logger in ('sentry_sdk', 'sentry_sdk.errors'):
        logging.getLogger(sentry_logger).setLevel(logging.CRITICAL)
|
|
|
|
|
def get_settings(file=SETTINGS_YAML, version='0.0.3'):
    """
    Loads a global Ultralytics settings YAML file or creates one with default values if it does not exist.

    The loaded settings are replaced by (and re-saved as) the defaults in any of these cases:
        - the file does not contain a mapping
        - its keys differ from the default keys
        - any value's type differs from the corresponding default's type
        - its 'settings_version' fails check_version against `version`

    Args:
        file (Path): Path to the Ultralytics settings YAML file. Defaults to 'settings.yaml' in the USER_CONFIG_DIR.
        version (str): Settings version. If min settings version not met, new default settings will be saved.

    Returns:
        dict: Dictionary of settings key-value pairs.
    """
    import hashlib

    from ultralytics.yolo.utils.checks import check_version
    from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first

    git_dir = get_git_dir()
    root = git_dir or Path()
    # Inside a git checkout, place datasets next to the repo when its parent is writeable; otherwise under root
    datasets_root = (root.parent if git_dir and is_dir_writeable(root.parent) else root).resolve()
    defaults = {
        'datasets_dir': str(datasets_root / 'datasets'),
        'weights_dir': str(root / 'weights'),
        'runs_dir': str(root / 'runs'),
        'uuid': hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(),
        'sync': True,
        'api_key': '',
        'settings_version': version}

    with torch_distributed_zero_first(RANK):
        if not file.exists():
            yaml_save(file, defaults)
        settings = yaml_load(file)

        # Validate key by key. dict keys views compare as sets (order-insensitive), so zipping the two dicts'
        # .values() would pair values of different keys whenever the file lists keys in another order.
        correct = isinstance(settings, dict) \
            and settings.keys() == defaults.keys() \
            and all(type(settings[k]) is type(v) for k, v in defaults.items()) \
            and check_version(settings['settings_version'], version)
        if not correct:
            LOGGER.warning('WARNING ⚠️ Ultralytics settings reset to defaults. This is normal and may be due to a '
                           'recent ultralytics package update, but may have overwritten previous settings. '
                           f"\nView and update settings with 'yolo settings' or at '{file}'")
            settings = defaults
            yaml_save(file, settings)

    return settings
|
|
|
|
|
def set_settings(kwargs, file=SETTINGS_YAML):
    """
    Update the global SETTINGS dict in place and write it to the settings YAML file.

    Args:
        kwargs (dict): Setting key-value pairs merged into SETTINGS with dict.update (no key or type validation).
        file (Path): YAML file to write. Defaults to SETTINGS_YAML.
    """
    SETTINGS.update(kwargs)
    yaml_save(file, SETTINGS)
|
|
|
|
|
def deprecation_warn(arg, new_arg, version=None):
    """
    Issue a deprecation warning when a deprecated argument is used, suggesting an updated argument.

    Args:
        arg (str): The deprecated argument name.
        new_arg (str): The argument to use instead.
        version (str, optional): Release in which `arg` will be removed. Defaults to the current major version with
            the minor version increased by 2, e.g. '8.2' for ultralytics 8.0.x.
    """
    if not version:
        # Bump the minor number as an integer. float(__version__[:3]) + 0.2 gives '8.299999999999999' for 8.1.x,
        # and the 3-character slice misreads two-digit components such as '8.10.0'.
        m = re.match(r'(\d+)\.(\d+)', __version__)
        version = f'{m[1]}.{int(m[2]) + 2}' if m else __version__
    LOGGER.warning(f"WARNING ⚠️ '{arg}' is deprecated and will be removed in 'ultralytics {version}' in the future. "
                   f"Please use '{new_arg}' instead.")
|
|
|
|
|
def clean_url(url):
    """
    Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt.

    The input is normalised through pathlib, which collapses '://' to ':/'; that is restored afterwards. as_posix()
    keeps forward slashes on Windows, where str(Path(...)) would produce backslashes and the ':/' repair would miss.
    The result is percent-decoded, and everything from the first '?' onward is removed.
    """
    import urllib.parse  # 'import urllib' alone does not guarantee the parse submodule is loaded

    url = Path(url).as_posix().replace(':/', '://')
    return urllib.parse.unquote(url).split('?')[0]
|
|
|
|
|
def url2file(url):
    """Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt."""
    cleaned = clean_url(url)  # drop the query string and percent-decode first
    return Path(cleaned).name
|
|
|
|
|
|
|
|
|
|
|
# Module-level globals computed at import time
PREFIX = colorstr('Ultralytics: ')  # single-argument colorstr -> default 'blue', 'bold' styling
SETTINGS = get_settings()  # load (creating or resetting as needed) the global settings YAML
DATASETS_DIR = Path(SETTINGS['datasets_dir'])
ENVIRONMENT = 'Colab' if is_colab() else 'Kaggle' if is_kaggle() else 'Jupyter' if is_jupyter() else \
    'Docker' if is_docker() else platform.system()  # first matching environment wins
TESTS_RUNNING = is_pytest_running() or is_github_actions_ci()
set_sentry()  # reads SETTINGS and TESTS_RUNNING, so it must run after they are defined
|
|
|
|
|
from .patches import imread, imshow, imwrite |
|
|
|
|
|
# Replace cv2.imread/imwrite/imshow with the implementations imported from .patches, but only when the outermost
# stack frame's filename (the program entry point) contains the path of this file's grandparent directory.
# The stack is captured once with context=0, which skips reading source lines for every frame. The frame records
# are deleted afterwards so that they do not keep frames alive as a module global.
_stack = inspect.stack(0)
if Path(_stack[0].filename).parent.parent.as_posix() in _stack[-1].filename:
    cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow
del _stack
|
|