# unit_test/main.py
# Uploaded by herrius in commit 32b542e ("Upload 259 files").
"""
A main training script.
"""
# Copyright (c) Facebook, Inc. and its affiliates.
import warnings
warnings.filterwarnings('ignore') # never print matching warnings
import logging
import os
from collections import OrderedDict
import torch
import uniperceiver.utils.comm as comm
from uniperceiver.config import get_cfg, CfgNode
from uniperceiver.engine import DefaultTrainer, default_argument_parser, default_setup, launch, build_engine, add_moe_arguments
#!TODO re-implement hooks
from uniperceiver.engine import hooks
from uniperceiver.modeling import add_config
from uniperceiver.utils.env import init_distributed_mode, check_dist_portfile
# deepspeed is an optional dependency: record whether it can be imported so
# that later code can skip the deepspeed-specific command-line arguments.
try:
    import deepspeed
    DEEPSPEED_INSTALLED = True
except Exception:
    # Stay best-effort for any import-time failure, but (unlike a bare
    # ``except:``) let KeyboardInterrupt / SystemExit propagate.
    DEEPSPEED_INSTALLED = False
import copy
# Values starting with any of these prefixes are left untouched by
# add_data_prefix (they are never re-rooted under DATA_PATH).
_RESOLVED_PATH_PREFIXES = ('.', '/', 'work_dirs', 'cluster', 's3://')


def _needs_data_prefix(value, whitelist):
    """Return True when ``value`` should be joined onto the DATA_PATH directory.

    A value qualifies when it is a non-empty string, does not start with one
    of ``_RESOLVED_PATH_PREFIXES`` and is not one of the ``whitelist`` names.
    Non-string values (e.g. ``None``) never qualify.
    """
    return (
        isinstance(value, str)
        and value != ''
        and not value.startswith(_RESOLVED_PATH_PREFIXES)
        and value not in whitelist
    )


def add_data_prefix(cfg):
    """Prefix data-related paths in ``cfg`` with the ``DATA_PATH`` env variable.

    The config is modified in place. Three places are visited:

    * the top-level ``DATALOADER`` / ``INFERENCE`` / ``MODEL`` nodes,
    * the same keys inside every entry of ``cfg.TASKS``,
    * ``SHARED_TARGETS_CFG.FILE_PATH`` of every entry of ``cfg.SHARED_TARGETS``.

    A value is rewritten to ``os.path.join(DATA_PATH, value)`` only when
    ``_needs_data_prefix`` accepts it. When ``DATA_PATH`` is unset or empty no
    path is touched.

    Args:
        cfg: config object exposing ``DATALOADER``, ``INFERENCE``, ``MODEL``,
            ``TASKS`` and ``SHARED_TARGETS``; nodes must support both
            attribute and item access.

    Raises:
        KeyError: if a ``SHARED_TARGETS`` entry lacks ``SHARED_TARGETS_CFG``
            (only reachable when ``DATA_PATH`` is set).
    """
    # TODO: more flexible method
    data_dir = os.getenv("DATA_PATH", None)
    whitelist = ["BERT", "CLIP", "CLIP_CAPTION"]

    # NOTE(review): normalised regardless of DATA_PATH so SHARED_TARGETS is
    # always iterable afterwards -- confirm no consumer relies on ``None``.
    if cfg.SHARED_TARGETS is None:
        cfg.SHARED_TARGETS = []

    if not data_dir:
        # Nothing to prepend; also avoids os.path.join(None, ...).
        return

    # (top-level node, attribute name, key path of the same node inside a task dict)
    mapping_list = [
        [cfg.DATALOADER, 'FEATS_FOLDER', ['DATALOADER', ]],
        [cfg.DATALOADER, 'ANNO_FOLDER', ['DATALOADER', ]],
        [cfg.DATALOADER, 'CLASS_NAME_FILE', ['DATALOADER', ]],
        [cfg.INFERENCE, 'VOCAB', ['INFERENCE', ]],
        [cfg.INFERENCE, 'VAL_ANNFILE', ['INFERENCE', ]],
        [cfg.INFERENCE, 'TEST_ANNFILE', ['INFERENCE', ]],
        [cfg.MODEL, 'WEIGHTS', ['MODEL', ]],
    ]

    # 1) top-level config nodes
    for node, attr, _ in mapping_list:
        if _needs_data_prefix(node[attr], whitelist):
            setattr(node, attr, os.path.join(data_dir, node[attr]))

    # 2) per-task overrides
    for task in cfg.TASKS:
        for _, item, key_list in mapping_list:
            config_tmp = task
            for key in key_list:
                # A task may omit a section; in that case the lookup of
                # ``item`` below happens on the level reached so far.
                if key in config_tmp:
                    config_tmp = config_tmp[key]
            if item in config_tmp and _needs_data_prefix(config_tmp[item], whitelist):
                config_tmp[item] = os.path.join(data_dir, config_tmp[item])

    # 3) shared target files
    shared_mapping_list = [
        ['', 'FILE_PATH', ['SHARED_TARGETS_CFG', ]],
    ]
    for share_targets in cfg.SHARED_TARGETS:
        for _, item, key_list in shared_mapping_list:
            config_tmp = share_targets
            for key in key_list:
                config_tmp = config_tmp[key]
            if item in config_tmp and _needs_data_prefix(config_tmp[item], whitelist):
                config_tmp[item] = os.path.join(data_dir, config_tmp[item])
def add_default_setting_for_multitask_config(cfg):
    """Expand every entry of ``cfg.TASKS`` into a complete per-task config.

    Each task entry is merged on top of a deep copy of the global config
    (taken with ``TASKS`` removed), and the result is stored back into
    ``cfg.TASKS`` as the object returned by ``to_dict_object()``.
    ``cfg`` is modified in place.

    Args:
        cfg: global config node whose ``TASKS`` attribute lists the per-task
            overrides.
    """
    # merge some default config in (CfgNode) uniperceiver/config/defaults.py to each task config (dict)
    tasks_config_temp = cfg.TASKS
    # Remove TASKS before copying so the per-task copies do not embed the
    # task list themselves.
    cfg.pop('TASKS', None)
    cfg.TASKS = [copy.deepcopy(cfg) for _ in tasks_config_temp]
    for i, task_config in enumerate(tasks_config_temp):
        # Task-specific values override the copied global defaults.
        cfg.TASKS[i].merge_from_other_cfg(CfgNode(task_config))
        cfg.TASKS[i] = cfg.TASKS[i].to_dict_object()
def setup(args):
    """
    Create configs and perform basic setups.

    Steps, in order: build the default config, register the extra keys
    derived from the config file, merge the config file, prefix data paths
    with ``DATA_PATH``, merge command-line overrides, expand the per-task
    configs, freeze the config and run ``default_setup``.

    Args:
        args: parsed command-line arguments; ``config_file`` and ``opts``
            are read here and the whole object is passed to ``default_setup``.

    Returns:
        The frozen config object.
    """
    cfg = get_cfg()
    tmp_cfg = cfg.load_from_file_tmp(args.config_file)
    add_config(cfg, tmp_cfg)
    cfg.merge_from_file(args.config_file)
    add_data_prefix(cfg)
    # Command-line overrides are merged after the DATA_PATH prefixing, so
    # paths supplied through ``args.opts`` are taken as given.
    cfg.merge_from_list(args.opts)
    add_default_setting_for_multitask_config(cfg)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg
def _evaluate_split(trainer, args, data_loader, evaluator):
    """Evaluate one data split, choosing the EMA or the master model, and return the result."""
    if trainer.model_ema is not None and args.eval_ema:
        if comm.is_main_process():
            print('using ema model for evaluation')
        model = trainer.model_ema.ema
    else:
        if args.eval_ema and comm.is_main_process():
            print('no ema model exists! using master model for evaluation')
        model = trainer.model
    # Every process runs the evaluation; only the main process prints.
    res = trainer.test(trainer.cfg, model, data_loader, evaluator, epoch=-1)
    if comm.is_main_process():
        print(res)
    return res


def main(args):
    """Build the engine from the config and either evaluate or train.

    Args:
        args: parsed command-line arguments; ``resume``, ``eval_only`` and
            ``eval_ema`` are read here, the rest is consumed by ``setup``.

    Returns:
        With ``args.eval_only``: the result of the last evaluated split (the
        test split when a test loader exists, otherwise the validation
        split, or ``None`` when neither loader exists). Otherwise the return
        value of ``trainer.train()``.
    """
    cfg = setup(args)

    # If you'd like to do anything fancier than the standard training logic,
    # consider writing your own training loop (see plain_train_net.py) or
    # subclassing the trainer.
    trainer = build_engine(cfg)
    trainer.resume_or_load(resume=args.resume)
    trainer.cast_layers()

    if args.eval_only:
        print('---------------------------')
        print('eval model only')
        print('---------------------------\n')
        res = None
        if trainer.val_data_loader is not None:
            res = _evaluate_split(trainer, args, trainer.val_data_loader, trainer.val_evaluator)
        if trainer.test_data_loader is not None:
            res = _evaluate_split(trainer, args, trainer.test_data_loader, trainer.test_evaluator)
        return res

    return trainer.train()
def get_args_parser():
    """Build the command-line parser and return the parsed arguments.

    Starts from ``default_argument_parser()``, adds the deepspeed config
    arguments when deepspeed could be imported, adds the MoE arguments, and
    finally the script-specific options ``--init_method``, ``--local_rank``
    and ``--eval-ema``.

    Returns:
        The namespace produced by ``parser.parse_args()`` (note: the parsed
        arguments, not the parser itself).
    """
    parser = default_argument_parser()
    if DEEPSPEED_INSTALLED:
        parser = deepspeed.add_config_arguments(parser)
    parser = add_moe_arguments(parser)
    parser.add_argument('--init_method', default='slurm', type=str)
    parser.add_argument('--local_rank', default=0, type=int)
    parser.add_argument("--eval-ema", action="store_true", help="perform evaluation using ema")
    args = parser.parse_args()
    return args
if __name__ == "__main__":
    args = get_args_parser()
    print("Command Line Args:", args)
    if args.init_method == 'slurm':
        # slurm init
        check_dist_portfile()
        init_distributed_mode(args)
        main(args)
    elif args.init_method == 'pytorch':
        # No distributed initialisation is done here before main() runs.
        main(args)
    else:
        # follow 'd2' use default `mp.spawn` to init dist training
        print('using \'mp.spawn\' for dist init! ')
        launch(
            main,
            args.num_gpus,
            num_machines=args.num_machines,
            machine_rank=args.machine_rank,
            dist_url=args.dist_url,
            args=(args,),
        )