yolov6 / tools /train.py
Theivaprakasham's picture
adding app
be49b0b
raw
history blame contribute delete
No virus
3.49 kB
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import argparse
import os
import os.path as osp
import torch
import torch.distributed as dist
import sys
# Make modules under the current working directory importable. Presumably the
# script is launched from the repository root so the `yolov6.*` imports that
# follow resolve without installing the package — confirm with launch docs.
ROOT = os.getcwd()
if ROOT not in sys.path:
    sys.path.append(ROOT)
from yolov6.core.engine import Trainer
from yolov6.utils.config import Config
from yolov6.utils.events import LOGGER, save_yaml
from yolov6.utils.envs import get_envs, select_device, set_random_seed
def get_args_parser(add_help=True):
    '''Build the command-line parser for YOLOv6 training.

    Args:
        add_help: forwarded to ``argparse.ArgumentParser``; when False the
            parser does not register ``-h/--help``.

    Returns:
        An ``argparse.ArgumentParser`` with every training option registered
        (not yet parsed).
    '''
    parser = argparse.ArgumentParser(description='YOLOv6 PyTorch Training', add_help=add_help)

    # (flag, add_argument keyword options) in the order they appear in --help.
    option_specs = (
        ('--data-path', dict(default='./data/coco.yaml', type=str, help='dataset path')),
        ('--conf-file', dict(default='./configs/yolov6s.py', type=str, help='experiment description file')),
        ('--img-size', dict(type=int, default=640, help='train, val image size (pixels)')),
        ('--batch-size', dict(default=32, type=int, help='total batch size for all GPUs')),
        ('--epochs', dict(default=400, type=int, help='number of total epochs to run')),
        ('--workers', dict(default=8, type=int, help='number of data loading workers (default: 8)')),
        ('--device', dict(default='0', type=str, help='cuda device, i.e. 0 or 0,1,2,3 or cpu')),
        ('--noval', dict(action='store_true', help='only evaluate in final epoch')),
        ('--check-images', dict(action='store_true', help='check images when initializing datasets')),
        ('--check-labels', dict(action='store_true', help='check label files when initializing datasets')),
        ('--output-dir', dict(default='./runs/train', type=str, help='path to save outputs')),
        ('--name', dict(default='exp', type=str, help='experiment name, save to output_dir/name')),
        # Distributed-launch options (the underscore spelling is part of the CLI).
        ('--dist_url', dict(type=str, default="tcp://127.0.0.1:8888")),
        ('--gpu_count', dict(type=int, default=0)),
        ('--local_rank', dict(type=int, default=-1, help='DDP parameter, do not modify')),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser
def check_and_init(args):
    '''Prepare the output directory, load the config, select the device and seed RNGs.

    Args:
        args: parsed command-line namespace. Besides the options from
            ``get_args_parser`` it must already carry ``args.rank`` (``main``
            sets it from ``get_envs()`` before calling this). ``args.save_dir``
            is added here as ``output_dir/name``.

    Returns:
        ``(cfg, device)``: the ``Config`` loaded from ``args.conf_file`` and the
        value returned by ``select_device(args.device)``.
    '''
    # Output directory; exist_ok=True so concurrently running ranks do not fail
    # when another rank has already created it.
    args.save_dir = osp.join(args.output_dir, args.name)
    os.makedirs(args.save_dir, exist_ok=True)
    cfg = Config.fromfile(args.conf_file)
    device = select_device(args.device)
    # Seed differs per rank; deterministic mode only when rank == -1
    # (presumably a non-distributed run — confirm against get_envs()).
    set_random_seed(1 + args.rank, deterministic=(args.rank == -1))
    # Only the main process writes args.yaml (rank -1 single-process, rank 0
    # under DDP). Previously every rank wrote the same file concurrently, which
    # can interleave writes and leave a corrupted YAML behind.
    if args.rank in (-1, 0):
        save_yaml(vars(args), osp.join(args.save_dir, 'args.yaml'))
    return cfg, device
def main(args):
    '''Main function of training: optional DDP setup, training, then teardown.

    Args:
        args: namespace from ``get_args_parser().parse_args()``. Its ``rank``,
            ``local_rank`` and ``world_size`` fields are overwritten with the
            values returned by ``get_envs()``.
    '''
    # Setup
    # NOTE(review): this unpack order must match get_envs()'s return order,
    # which is defined elsewhere — confirm it is (rank, local_rank, world_size).
    args.rank, args.local_rank, args.world_size = get_envs()
    LOGGER.info(f'training args are: {args}\n')
    cfg, device = check_and_init(args)
    if args.local_rank != -1:  # if DDP mode
        # Bind this process to its GPU before the process group is created.
        torch.cuda.set_device(args.local_rank)
        device = torch.device('cuda', args.local_rank)
        LOGGER.info('Initializing process group... ')
        # NOTE(review): init_process_group's `rank` is the *global* rank; passing
        # local_rank is only correct on a single node (unless get_envs' order
        # makes args.local_rank hold the global rank) — confirm for multi-node.
        dist.init_process_group(
            backend="nccl" if dist.is_nccl_available() else "gloo",
            init_method=args.dist_url,
            rank=args.local_rank,
            world_size=args.world_size,
        )
    # Start
    trainer = Trainer(args, cfg, device)
    trainer.train()
    # End: every process that joined a process group tears down its own
    # membership. The previous `world_size > 1 and rank == 0` check skipped all
    # non-zero ranks, and also skipped single-process DDP launches (world_size == 1).
    if dist.is_available() and dist.is_initialized():
        LOGGER.info('Destroying process group... ')
        dist.destroy_process_group()
if __name__ == '__main__':
    # Script entry point: parse the command line and launch training with it.
    main(get_args_parser().parse_args())