"""
Main training script.
"""

import warnings

# Silence noisy third-party warnings before the heavy imports below.
warnings.filterwarnings('ignore')

import copy
import logging
import os
from collections import OrderedDict

import torch

import uniperceiver.utils.comm as comm
from uniperceiver.config import get_cfg, CfgNode
from uniperceiver.engine import (
    DefaultTrainer,
    default_argument_parser,
    default_setup,
    launch,
    build_engine,
    add_moe_arguments,
)
from uniperceiver.engine import hooks
from uniperceiver.modeling import add_config
from uniperceiver.utils.env import init_distributed_mode, check_dist_portfile

# DeepSpeed is an optional dependency; its config and MoE flags are only added
# to the argument parser when the package is importable.
try:
    import deepspeed
    DEEPSPEED_INSTALLED = True
except ImportError:
    DEEPSPEED_INSTALLED = False


def _needs_data_prefix(path, whitelist):
    """Return True for non-empty relative paths that should get the DATA_PATH prefix."""
    if path == '' or path in whitelist:
        return False
    return not path.startswith(('.', '/', 'work_dirs', 'cluster', 's3://'))


def add_data_prefix(cfg):
    """Prepend $DATA_PATH to relative data paths in the global, per-task and
    shared-target configs, leaving absolute paths, s3:// URLs and whitelisted
    vocabulary names untouched."""
    data_dir = os.getenv("DATA_PATH", None)
    mapping_list = [
        [cfg.DATALOADER, 'FEATS_FOLDER', ['DATALOADER']],
        [cfg.DATALOADER, 'ANNO_FOLDER', ['DATALOADER']],
        [cfg.DATALOADER, 'CLASS_NAME_FILE', ['DATALOADER']],
        [cfg.INFERENCE, 'VOCAB', ['INFERENCE']],
        [cfg.INFERENCE, 'VAL_ANNFILE', ['INFERENCE']],
        [cfg.INFERENCE, 'TEST_ANNFILE', ['INFERENCE']],
        [cfg.MODEL, 'WEIGHTS', ['MODEL']],
    ]
    whitelist = ["BERT", "CLIP", "CLIP_CAPTION"]
    if not data_dir:
        return

    # Top-level config nodes.
    for node, attr, _ in mapping_list:
        if _needs_data_prefix(node[attr], whitelist):
            setattr(node, attr, os.path.join(data_dir, node[attr]))

    # Per-task overrides: walk down key_list as far as the keys exist.
    for task in cfg.TASKS:
        for _, item, key_list in mapping_list:
            config_tmp = task
            for key in key_list:
                if key in config_tmp:
                    config_tmp = config_tmp[key]
            if item in config_tmp and _needs_data_prefix(config_tmp[item], whitelist):
                config_tmp[item] = os.path.join(data_dir, config_tmp[item])

    # Shared classification targets.
    mapping_list = [
        ['', 'FILE_PATH', ['SHARED_TARGETS_CFG']],
    ]
    if cfg.SHARED_TARGETS is None:
        cfg.SHARED_TARGETS = []
    for shared_targets in cfg.SHARED_TARGETS:
        for _, item, key_list in mapping_list:
            config_tmp = shared_targets
            for key in key_list:
                config_tmp = config_tmp[key]
            if item in config_tmp and _needs_data_prefix(config_tmp[item], whitelist):
                config_tmp[item] = os.path.join(data_dir, config_tmp[item])
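
# Example (hypothetical values): with DATA_PATH=/mnt/data and
# cfg.DATALOADER.ANNO_FOLDER = 'mscoco_dataset/annotations', the folder becomes
# '/mnt/data/mscoco_dataset/annotations'; entries like '/abs/path', './local',
# 'work_dirs/...', 's3://bucket/key' or the whitelisted name 'CLIP' are left
# unchanged.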


def add_default_setting_for_multitask_config(cfg):
    """Expand cfg.TASKS so each task holds a complete config: a deep copy of
    the global config with the task-specific entries merged on top."""
    tasks_config_temp = cfg.TASKS
    num_tasks = len(tasks_config_temp)
    cfg.pop('TASKS', None)

    cfg.TASKS = [copy.deepcopy(cfg) for _ in range(num_tasks)]

    for i, task_config in enumerate(tasks_config_temp):
        cfg.TASKS[i].merge_from_other_cfg(CfgNode(task_config))
        cfg.TASKS[i] = cfg.TASKS[i].to_dict_object()
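
# Example (hypothetical YAML): with a global DATALOADER.TRAIN_BATCH_SIZE of 256
# and
#
#   TASKS:
#     - NAME: imagenet
#       DATALOADER: {TRAIN_BATCH_SIZE: 128}
#     - NAME: mscoco_caption
#
# cfg.TASKS becomes two complete per-task config dicts: the first overrides the
# batch size to 128, the second inherits 256 from the global config.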


def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()
    # Load the raw config once so add_config can register model-specific
    # default keys before the real merge.
    tmp_cfg = cfg.load_from_file_tmp(args.config_file)
    add_config(cfg, tmp_cfg)

    cfg.merge_from_file(args.config_file)
    add_data_prefix(cfg)

    # Command-line overrides take precedence over the config file.
    cfg.merge_from_list(args.opts)

    add_default_setting_for_multitask_config(cfg)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg
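
# Merge precedence in setup() (later wins): built-in defaults from get_cfg(),
# then the YAML file passed via --config-file, then key/value overrides given
# as trailing command-line opts (e.g. `SOLVER.BASE_LR 1e-4`; illustrative key).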


def _evaluate(trainer, data_loader, evaluator, eval_ema):
    """Run trainer.test, preferring the EMA weights when requested and available."""
    if trainer.model_ema is not None and eval_ema:
        if comm.is_main_process():
            print('using ema model for evaluation')
        model = trainer.model_ema.ema
    else:
        if eval_ema and comm.is_main_process():
            print('no ema model exists! using master model for evaluation')
        model = trainer.model
    return trainer.test(trainer.cfg, model, data_loader, evaluator, epoch=-1)


def main(args):
    cfg = setup(args)

    # If you'd like to do anything fancier than the standard training logic,
    # consider writing your own training loop (see plain_train_net.py) or
    # subclassing the trainer.
    trainer = build_engine(cfg)
    trainer.resume_or_load(resume=args.resume)
    trainer.cast_layers()

    if args.eval_only:
        print('---------------------------')
        print('eval model only')
        print('---------------------------\n')
        res = None
        if trainer.val_data_loader is not None:
            res = _evaluate(trainer, trainer.val_data_loader, trainer.val_evaluator, args.eval_ema)
            if comm.is_main_process():
                print(res)
        if trainer.test_data_loader is not None:
            res = _evaluate(trainer, trainer.test_data_loader, trainer.test_evaluator, args.eval_ema)
            if comm.is_main_process():
                print(res)
        return res

    return trainer.train()


def get_args_parser():
    parser = default_argument_parser()
    if DEEPSPEED_INSTALLED:
        # DeepSpeed config flags plus the MoE-specific options.
        parser = deepspeed.add_config_arguments(parser)
        parser = add_moe_arguments(parser)

    parser.add_argument('--init_method', default='slurm', type=str)
    parser.add_argument('--local_rank', default=0, type=int)
    parser.add_argument("--eval-ema", action="store_true", help="perform evaluation using ema")
    args = parser.parse_args()

    return args


if __name__ == "__main__":
    args = get_args_parser()
    print("Command Line Args:", args)
    if args.init_method == 'slurm':
        # SLURM launch: read the distributed settings from the SLURM environment.
        check_dist_portfile()
        init_distributed_mode(args)
        main(args)
    elif args.init_method == 'pytorch':
        # Launched via torch.distributed (e.g. torchrun); no extra init here.
        main(args)
    else:
        # Fall back to multiprocess spawning via the engine's launch helper.
        print("using 'mp.spawn' for dist init!")
        launch(
            main,
            args.num_gpus,
            num_machines=args.num_machines,
            machine_rank=args.machine_rank,
            dist_url=args.dist_url,
            args=(args,),
        )