"""
A main training script.
"""
# Copyright (c) Facebook, Inc. and its affiliates.
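# Example launch (hypothetical config path and DATA_PATH; any --init_method
# other than 'slurm'/'pytorch' falls through to the detectron2-style spawner):
#   DATA_PATH=/mnt/data python train_net.py --config-file configs/example.yaml \
#       --num-gpus 8 --init_method d2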
import warnings
warnings.filterwarnings('ignore')  # suppress all warnings, including those raised by the imports below
import logging
import os
from collections import OrderedDict
import torch
import uniperceiver.utils.comm as comm
from uniperceiver.config import get_cfg, CfgNode
from uniperceiver.engine import (
    DefaultTrainer,
    default_argument_parser,
    default_setup,
    launch,
    build_engine,
    add_moe_arguments,
)
# TODO: re-implement hooks
from uniperceiver.engine import hooks
from uniperceiver.modeling import add_config
from uniperceiver.utils.env import init_distributed_mode, check_dist_portfile
try:
    import deepspeed
    DEEPSPEED_INSTALLED = True
except ImportError:
    DEEPSPEED_INSTALLED = False

import copy

# Model names that must never be treated as file paths.
_PREFIX_WHITELIST = ["BERT", "CLIP", "CLIP_CAPTION"]

# Prefixes marking a path as already resolved (relative, absolute, local
# working dirs, cluster storage, or remote object storage).
_RESOLVED_PREFIXES = ('.', '/', 'work_dirs', 'cluster', 's3://')


def _needs_data_prefix(path):
    """Return True for bare relative paths that should get the DATA_PATH prefix."""
    return (path != ''
            and not path.startswith(_RESOLVED_PREFIXES)
            and path not in _PREFIX_WHITELIST)


def add_data_prefix(cfg):
    # TODO: more flexible method
    data_dir = os.getenv("DATA_PATH", None)
    mapping_list = [
        [cfg.DATALOADER, 'FEATS_FOLDER', ['DATALOADER']],
        [cfg.DATALOADER, 'ANNO_FOLDER', ['DATALOADER']],
        [cfg.DATALOADER, 'CLASS_NAME_FILE', ['DATALOADER']],
        [cfg.INFERENCE, 'VOCAB', ['INFERENCE']],
        [cfg.INFERENCE, 'VAL_ANNFILE', ['INFERENCE']],
        [cfg.INFERENCE, 'TEST_ANNFILE', ['INFERENCE']],
        [cfg.MODEL, 'WEIGHTS', ['MODEL']],
    ]
    if data_dir:
        # top-level config entries
        for node, attr, _ in mapping_list:
            if _needs_data_prefix(node[attr]):
                setattr(node, attr, os.path.join(data_dir, node[attr]))
        # per-task overrides
        for task in cfg.TASKS:
            for _, item, key_list in mapping_list:
                config_tmp = task
                for key in key_list:
                    if key in config_tmp:
                        config_tmp = config_tmp[key]
                if item in config_tmp and _needs_data_prefix(config_tmp[item]):
                    config_tmp[item] = os.path.join(data_dir, config_tmp[item])
        # shared target files
        mapping_list = [
            ['', 'FILE_PATH', ['SHARED_TARGETS_CFG']],
        ]
        if cfg.SHARED_TARGETS is None:
            cfg.SHARED_TARGETS = []
        for shared_targets in cfg.SHARED_TARGETS:
            for _, item, key_list in mapping_list:
                config_tmp = shared_targets
                for key in key_list:
                    config_tmp = config_tmp[key]
                if item in config_tmp and _needs_data_prefix(config_tmp[item]):
                    config_tmp[item] = os.path.join(data_dir, config_tmp[item])
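
# Sketch of the intended prefixing, assuming DATA_PATH=/mnt/data (hypothetical
# value; the entries below are illustrative, not from any real config):
#   'coco/annotations.json' -> '/mnt/data/coco/annotations.json' (bare relative path)
#   '/abs/annotations.json' -> unchanged (absolute path)
#   's3://bucket/anno.json' -> unchanged (remote storage)
#   'CLIP'                  -> unchanged (whitelisted model name)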

def add_default_setting_for_multitask_config(cfg):
    # Merge the global defaults (a CfgNode, see uniperceiver/config/defaults.py)
    # into each per-task config (a dict): every task becomes a deep copy of the
    # global config with its own task dict merged on top.
    tasks_config_temp = cfg.TASKS
    num_tasks = len(tasks_config_temp)
    cfg.pop('TASKS', None)
    cfg.TASKS = [copy.deepcopy(cfg) for _ in range(num_tasks)]
    for i, task_config in enumerate(tasks_config_temp):
        cfg.TASKS[i].merge_from_other_cfg(CfgNode(task_config))
        cfg.TASKS[i] = cfg.TASKS[i].to_dict_object()
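
# Illustration with hypothetical values: given
#   cfg.TASKS = [{'NAME': 'caption'}, {'NAME': 'vqa'}]
# cfg.TASKS becomes two full deep copies of the global config, each converted
# to a plain dict with its task entries merged on top, so per-task keys
# override the shared defaults.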

def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()
    tmp_cfg = cfg.load_from_file_tmp(args.config_file)
    add_config(cfg, tmp_cfg)
    cfg.merge_from_file(args.config_file)
    add_data_prefix(cfg)
    cfg.merge_from_list(args.opts)
    add_default_setting_for_multitask_config(cfg)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg
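
# Resulting precedence (later steps win): built-in defaults -> model-specific
# additions registered from the config file -> the YAML file itself ->
# DATA_PATH prefixing -> detectron2-style command-line overrides, e.g.
# (hypothetical key and value):
#   python train_net.py --config-file cfg.yaml SOLVER.BASE_LR 0.0001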

def _evaluate(trainer, data_loader, evaluator, args):
    # Evaluate with the EMA weights when they exist and were requested;
    # otherwise fall back to the master model.
    if trainer.model_ema is not None and args.eval_ema:
        if comm.is_main_process():
            print('using ema model for evaluation')
        model = trainer.model_ema.ema
    else:
        if args.eval_ema and comm.is_main_process():
            print('no ema model exists! using master model for evaluation')
        model = trainer.model
    res = trainer.test(trainer.cfg, model, data_loader, evaluator, epoch=-1)
    if comm.is_main_process():
        print(res)
    return res


def main(args):
    cfg = setup(args)
    # If you'd like to do anything fancier than the standard training logic,
    # consider writing your own training loop (see plain_train_net.py) or
    # subclassing the trainer.
    trainer = build_engine(cfg)
    trainer.resume_or_load(resume=args.resume)
    trainer.cast_layers()
    if args.eval_only:
        print('---------------------------')
        print('eval model only')
        print('---------------------------\n')
        # evaluate on the validation and/or test split; the last result wins
        res = None
        if trainer.val_data_loader is not None:
            res = _evaluate(trainer, trainer.val_data_loader, trainer.val_evaluator, args)
        if trainer.test_data_loader is not None:
            res = _evaluate(trainer, trainer.test_data_loader, trainer.test_evaluator, args)
        return res
    return trainer.train()

def get_args_parser():
    parser = default_argument_parser()
    if DEEPSPEED_INSTALLED:
        # DeepSpeed config flags are only available when the package is installed
        parser = deepspeed.add_config_arguments(parser)
    parser = add_moe_arguments(parser)
    parser.add_argument('--init_method', default='slurm', type=str)
    parser.add_argument('--local_rank', default=0, type=int)
    parser.add_argument("--eval-ema", action="store_true", help="perform evaluation using ema")
    args = parser.parse_args()
    return args
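
# Hypothetical eval-only invocation that prefers the EMA weights (--eval-only,
# --resume and the trailing key-value overrides come from detectron2's
# default_argument_parser; the checkpoint path is illustrative):
#   python train_net.py --config-file cfg.yaml --eval-only --eval-ema \
#       MODEL.WEIGHTS work_dirs/checkpoint.pth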

if __name__ == "__main__":
    args = get_args_parser()
    print("Command Line Args:", args)
    if args.init_method == 'slurm':
        # slurm init
        check_dist_portfile()
        init_distributed_mode(args)
        main(args)
    elif args.init_method == 'pytorch':
        # assume the torch.distributed launcher already set up the process env
        main(args)
    else:
        # follow 'd2': use the default `mp.spawn` to init dist training
        print("using 'mp.spawn' for dist init!")
        launch(
            main,
            args.num_gpus,
            num_machines=args.num_machines,
            machine_rank=args.machine_rank,
            dist_url=args.dist_url,
            args=(args,),
        )
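
# Hypothetical launch recipes for each --init_method (exact commands depend on
# your cluster; the config path is illustrative):
#   slurm:   srun python train_net.py --config-file cfg.yaml --init_method slurm
#   pytorch: python -m torch.distributed.launch --nproc_per_node=8 \
#                train_net.py --config-file cfg.yaml --init_method pytorch
#   other:   python train_net.py --config-file cfg.yaml --num-gpus 8 --init_method d2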