import torch |
import torch.distributed as dist |
from vlmeval.config import supported_VLM |
from vlmeval.dataset import build_dataset |
from vlmeval.inference import infer_data_job |
from vlmeval.inference_video import infer_data_job_video |
from vlmeval.inference_mt import infer_data_job_mt |
from vlmeval.smp import * |
from vlmeval.utils.result_transfer import MMMU_result_transfer, MMTBench_result_transfer |
def build_model_from_config(cfg):
    """Instantiate a model from one entry of the ``model`` section of a ``--config`` file.

    Args:
        cfg: Mapping with a ``class`` key naming a class in ``vlmeval.api`` or ``vlmeval.vlm``.
            Every other key is forwarded to that class's constructor as a keyword argument.
            ``cfg`` itself is not modified (a deep copy is consumed instead).

    Returns:
        The constructed model instance. ``vlmeval.api`` is searched before ``vlmeval.vlm``.

    Raises:
        ValueError: If ``class`` is missing, or names a class found in neither module.
    """
    import vlmeval.api
    import vlmeval.vlm

    config = cp.deepcopy(cfg)
    # Explicit check instead of `assert`: asserts vanish under `python -O`, which would turn this
    # into an opaque KeyError from `pop` below.
    if 'class' not in config:
        raise ValueError(f'Model config must contain a `class` key, got keys: {list(config)}')
    cls_name = config.pop('class')
    for module in (vlmeval.api, vlmeval.vlm):
        if hasattr(module, cls_name):
            return getattr(module, cls_name)(**config)
    raise ValueError(f'Class {cls_name} is not supported in `vlmeval.api` or `vlmeval.vlm`')
def build_dataset_from_config(cfg):
    """Instantiate a dataset from one entry of the ``data`` section of a ``--config`` file.

    Args:
        cfg: Mapping with a ``class`` key naming a class in ``vlmeval.dataset``. Every other key
            (typically including ``dataset``) is forwarded to that class's constructor as a keyword
            argument. ``cfg`` itself is not modified (a deep copy is consumed instead).

    Returns:
        The constructed dataset instance.

    Raises:
        ValueError: If ``class`` is missing, or names a class not found in ``vlmeval.dataset``.
    """
    import vlmeval.dataset

    config = cp.deepcopy(cfg)
    # Explicit check instead of `assert`: asserts vanish under `python -O`.
    if 'class' not in config:
        raise ValueError(f'Dataset config must contain a `class` key, got keys: {list(config)}')
    cls_name = config.pop('class')
    if not hasattr(vlmeval.dataset, cls_name):
        raise ValueError(f'Class {cls_name} is not supported in `vlmeval.dataset`')
    return getattr(vlmeval.dataset, cls_name)(**config)
def parse_args():
    """Parse the command-line arguments of the evaluation launcher.

    Returns:
        argparse.Namespace with one attribute per option defined below. ``data``, ``model``,
        ``config``, ``retry`` and ``judge`` are ``None`` when not given on the command line.
    """
    help_msg = """\
You can launch the evaluation by setting either --data and --model or --config.

--data and --model:
    Each arg should be a list of strings, specifying the names of datasets and models.
    To find all supported model names, please refer to `vlmeval/config.py` or check the output of the command \
`vlmutil mlist all` in the terminal (you should first have vlmeval installed).
    To find all supported dataset names, please refer to the `vlmeval/dataset/__init__.py` file. The python script \
to print all supported dataset names is as follows:
        ```python
        from vlmeval.dataset import SUPPORTED_DATASETS
        print(SUPPORTED_DATASETS)
        ```
    or you can check the output of the command `vlmutil dlist all` in the terminal.

--config:
    Launch the evaluation by specifying the path to the config json file. Sample Json Content:
    ```json
    {
        "model": {
            "GPT4o_20240806_T00_HIGH": {
                "class": "GPT4V",
                "model": "gpt-4o-2024-08-06",
                "temperature": 0,
                "img_detail": "high"
            },
            "GPT4o_20240806_T10_Low": {
                "class": "GPT4V",
                "model": "gpt-4o-2024-08-06",
                "temperature": 1.0,
                "img_detail": "low"
            }
        },
        "data": {
            "MME-RealWorld-Lite": {
                "class": "MMERealWorld",
                "dataset": "MME-RealWorld-Lite"
            },
            "MMBench_DEV_EN_V11": {
                "class": "ImageMCQDataset",
                "dataset": "MMBench_DEV_EN_V11"
            }
        }
    }
    ```
    Currently, only `model` and `data` are supported fields. The content of each field is a dictionary.
    For `model`, the key is the name of the model, and the value is a dictionary containing the following keys:
    - `class`: The class name of the model, which should be a class in `vlmeval.vlm` or `vlmeval.api`.
    - Other keys are specific to the model, please refer to the corresponding class.
    For `data`, the key is the name of the dataset (should be the same as the `dataset` field in most cases, \
except for video datasets), and the value is a dictionary containing the following keys:
    - `class`: The class name of the dataset, which should be a class in `vlmeval.dataset`.
    - `dataset`: The name of the dataset, which should be a string that is accepted by the `dataset` argument of the \
corresponding class.
    - Other keys are specific to the dataset, please refer to the corresponding class.
    The keys in the `model` and `data` fields will be used for naming the prediction files and evaluation results.
    When launching with `--config`, args for video datasets, such as `--nframe`, `--pack`, `--use-subtitle`, `--fps`, \
and args for API VLMs, such as `--retry`, `--verbose`, will be ignored.
"""
    parser = argparse.ArgumentParser(description=help_msg, formatter_class=argparse.RawTextHelpFormatter)
    # Selection of what to run.
    parser.add_argument('--data', type=str, nargs='+', help='Names of Datasets')
    parser.add_argument('--model', type=str, nargs='+', help='Names of Models')
    parser.add_argument('--config', type=str, help='Path to the Config Json File')
    # Video-dataset options.
    parser.add_argument('--nframe', type=int, default=8,
                        help='Number of frames for video datasets (set to 0 when --fps > 0)')
    parser.add_argument('--pack', action='store_true',
                        help='Pack mode for video datasets (only MMBench-Video keeps it; others fall back to unpack)')
    parser.add_argument('--use-subtitle', action='store_true',
                        help='Use subtitles for video datasets that support them (e.g. Video-MME)')
    parser.add_argument('--fps', type=float, default=-1,
                        help='Frame sampling rate for video datasets; a value > 0 overrides --nframe')
    # Run control.
    parser.add_argument('--work-dir', type=str, default='./outputs', help='select the output directory')
    parser.add_argument('--mode', type=str, default='all', choices=['all', 'infer'],
                        help="'all': inference + evaluation; 'infer': inference only")
    parser.add_argument('--nproc', type=int, default=4, help='Parallel API calling')
    parser.add_argument('--retry', type=int, default=None, help='retry numbers for API VLMs')
    parser.add_argument('--judge', type=str, default=None,
                        help='Judge model used during evaluation (default depends on the dataset)')
    parser.add_argument('--verbose', action='store_true', help='Verbose output')
    parser.add_argument('--ignore', action='store_true', help='Ignore failed indices. ')
    parser.add_argument('--reuse', action='store_true',
                        help='Reuse the latest prediction & temporary pickle files')
    return parser.parse_args()
def _apply_api_overrides(args):
    """Push --retry / --verbose into every ``supported_VLM`` entry that exposes those keywords.

    Only entries with a ``keywords`` attribute (presumably ``functools.partial`` constructors) are
    touched; their ``keywords`` dict is updated in place.
    """
    for constructor in supported_VLM.values():
        keywords = getattr(constructor, 'keywords', None)
        if keywords is None:
            continue
        if args.retry is not None and 'retry' in keywords:
            keywords['retry'] = args.retry
        # --verbose is a store_true flag, so it is always a bool and always overrides the entry's value.
        if 'verbose' in keywords:
            keywords['verbose'] = args.verbose


def _build_run_dataset(dataset_name, model_name, args, cfg, use_config, rank, world_size):
    """Build the dataset for one run, either from the --config entry or by name via ``build_dataset``.

    Returns whatever the builder returns (``None`` signals an invalid dataset to the caller).
    """
    if use_config:
        def _build():
            return build_dataset_from_config(cfg['data'][dataset_name])
    else:
        dataset_kwargs = {}
        if dataset_name in ['MMLongBench_DOC', 'DUDE', 'DUDE_MINI', 'SLIDEVQA', 'SLIDEVQA_MINI']:
            dataset_kwargs['model'] = model_name
        if dataset_name == 'MMBench-Video':
            dataset_kwargs['pack'] = args.pack
        if dataset_name == 'Video-MME':
            dataset_kwargs['use_subtitle'] = args.use_subtitle

        def _build():
            return build_dataset(dataset_name, **dataset_kwargs)

    if world_size > 1:
        # Rank 0 builds first while the others wait, presumably so one-off downloads / preprocessing
        # happen once before every rank builds its own copy — TODO confirm.
        if rank == 0:
            _build()
        dist.barrier()
    return _build()


def _result_file_base(model_name, dataset_name, dataset, nframe, fps, pack, use_subtitle):
    """Return the prediction file name for one model x dataset run.

    Video runs encode the frame sampling (``{nframe}frame`` or ``{fps}fps``), the pack mode and, for
    Video-MME / LongVideoBench, the subtitle setting. ``MT`` datasets use ``.tsv``; others ``.xlsx``.
    """
    stem = f'{model_name}_{dataset_name}'
    if dataset_name == 'MMBench-Video' or dataset.MODALITY == 'VIDEO':
        sampling = f'{nframe}frame' if nframe > 0 else f'{fps}fps'
        packstr = 'pack' if pack else 'nopack'
        stem = f'{stem}_{sampling}_{packstr}'
        if dataset_name in ['Video-MME', 'LongVideoBench']:
            stem += '_subs' if use_subtitle else '_nosubs'
    extension = '.tsv' if dataset.TYPE == 'MT' else '.xlsx'
    return stem + extension


def _reuse_previous_predictions(prev_pred_roots, pred_root, result_file, result_file_base,
                                dataset_name, commit_id, logger):
    """Copy the newest reusable prediction (or same-commit pickle files) into ``pred_root``.

    ``prev_pred_roots`` is scanned from its last element backwards (the caller sorts it first).
    A finished prediction file named ``result_file_base`` wins. Otherwise the first other directory
    whose path contains ``commit_id`` and holds files matching ``[dataset_name, '.pkl']`` supplies
    pickle files.
    """
    prev_result_file = None
    prev_pkl_file_list = []
    for root in prev_pred_roots[::-1]:
        candidate = osp.join(root, result_file_base)
        if osp.exists(candidate):
            prev_result_file = candidate
            break
        if commit_id in root and len(ls(root)) and root != pred_root:
            temp_files = ls(root, match=[dataset_name, '.pkl'])
            if len(temp_files):
                prev_pkl_file_list.extend(temp_files)
                break

    if prev_result_file is not None:
        logger.warning(
            f'--reuse is set, will reuse the prediction file {prev_result_file}.')
        if prev_result_file != result_file:
            shutil.copy(prev_result_file, result_file)
        return

    for fname in prev_pkl_file_list:
        target_path = osp.join(pred_root, osp.basename(fname))
        if osp.exists(target_path):
            logger.warning(f'File already exists: {target_path}')
            continue
        shutil.copy(fname, target_path)
        logger.info(f'--reuse is set, will reuse the prediction pickle file {fname}.')


def _build_judge_kwargs(args, dataset, dataset_name):
    """Assemble the keyword arguments passed to ``dataset.evaluate``.

    ``--judge`` wins when given. Otherwise a default judge model is chosen from the dataset type or
    name; datasets matching no rule get no ``model`` key. ``retry`` defaults to 3.
    """
    judge_kwargs = {
        'nproc': args.nproc,
        'verbose': args.verbose,
        'retry': args.retry if args.retry is not None else 3,
    }
    if args.judge is not None:
        judge_kwargs['model'] = args.judge
    elif dataset.TYPE in ['MCQ', 'Y/N']:
        judge_kwargs['model'] = 'chatgpt-0125'
    elif listinstr(['MMVet', 'LLaVABench', 'MMBench-Video'], dataset_name):
        judge_kwargs['model'] = 'gpt-4-turbo'
    elif listinstr(['MathVista', 'MathVerse', 'MathVision', 'DynaMath'], dataset_name):
        judge_kwargs['model'] = 'gpt-4o-mini'
    elif listinstr(['MMLongBench', 'MMDU', 'DUDE', 'SLIDEVQA', 'MIA-Bench', 'WildVision'], dataset_name):
        judge_kwargs['model'] = 'gpt-4o'
    return judge_kwargs


def _evaluate_and_report(dataset, result_file, judge_kwargs, model_name, dataset_name, logger):
    """Run ``dataset.evaluate`` (through EVAL_PROXY when set) and log the results.

    The previous HTTP_PROXY value is restored even if evaluation raises, so the evaluation proxy
    cannot leak into later inference runs.
    """
    eval_proxy = os.environ.get('EVAL_PROXY', None)
    old_proxy = os.environ.get('HTTP_PROXY', '')
    if eval_proxy is not None:
        proxy_set(eval_proxy)
    try:
        eval_results = dataset.evaluate(result_file, **judge_kwargs)
    finally:
        if eval_proxy is not None:
            proxy_set(old_proxy)

    if eval_results is None:
        return
    if not isinstance(eval_results, (dict, pd.DataFrame)):
        raise TypeError(f'Unexpected evaluation result type: {type(eval_results).__name__}')
    logger.info(f'The evaluation of model {model_name} x dataset {dataset_name} has finished! ')
    logger.info('Evaluation Results:')
    if isinstance(eval_results, dict):
        logger.info('\n' + json.dumps(eval_results, indent=4))
    else:
        # Wide tables (more columns than rows) are transposed so they print vertically.
        if len(eval_results) < len(eval_results.columns):
            eval_results = eval_results.T
        logger.info('\n' + tabulate(eval_results))


def _link_prediction_files(pred_root, pred_root_meta, model_name, dataset_name):
    """Symlink this model x dataset's files from ``pred_root`` into ``pred_root_meta``.

    Every file in ``pred_root`` whose name contains ``f'{model_name}_{dataset_name}'`` gets a link of
    the same name in ``pred_root_meta``; an existing file or (possibly dangling) link is replaced.
    """
    cwd = os.getcwd()
    tag = f'{model_name}_{dataset_name}'
    for fname in os.listdir(pred_root):
        if tag not in fname:
            continue
        file_addr = osp.join(cwd, pred_root, fname)
        link_addr = osp.join(cwd, pred_root_meta, fname)
        # islink() also catches dangling links, for which exists() is False.
        if osp.exists(link_addr) or osp.islink(link_addr):
            os.remove(link_addr)
        os.symlink(file_addr, link_addr)


def main():
    """Run inference (and, with ``--mode all``, evaluation) for every model x dataset pair.

    Models and datasets come either from ``--model`` / ``--data`` or from a ``--config`` JSON file.
    With more than one process (per ``get_rank_and_world_size``) an NCCL process group is created.
    Inference runs on every rank; reuse of previous results, evaluation and symlink creation happen
    on rank 0 only. A failure for one model x dataset pair is logged and that pair is skipped.

    Raises:
        ValueError: If ``--config`` is combined with ``--data``/``--model``, or if ``--data`` or
            ``--model`` is missing without ``--config``.
    """
    logger = get_logger('RUN')
    rank, world_size = get_rank_and_world_size()
    args = parse_args()

    use_config, cfg = False, None
    if args.config is not None:
        if args.data is not None or args.model is not None:
            raise ValueError('--data and --model should not be set when using --config')
        use_config, cfg = True, load(args.config)
        args.model = list(cfg['model'].keys())
        args.data = list(cfg['data'].keys())
    else:
        # `assert len(args.data)` used to raise TypeError when --data was omitted (None).
        if not args.data:
            raise ValueError('--data should be a list of data files')
        if not args.model:
            raise ValueError('--model should be a list of model names')

    if rank == 0:
        if not args.reuse:
            logger.warning('--reuse is not set, will not reuse previous (before one day) temporary files')
        else:
            logger.warning('--reuse is set, will reuse the latest prediction & temporary pickle files')

    if 'MMEVAL_ROOT' in os.environ:
        args.work_dir = os.environ['MMEVAL_ROOT']

    if not use_config:
        _apply_api_overrides(args)

    if world_size > 1:
        local_rank = os.environ.get('LOCAL_RANK', 0)
        torch.cuda.set_device(int(local_rank))
        dist.init_process_group(backend='nccl', timeout=datetime.timedelta(seconds=3600))

    for model_name in args.model:
        model = None
        date, commit_id = timestr('day'), githash(digits=8)
        eval_id = f"T{date}_G{commit_id}"
        pred_root = osp.join(args.work_dir, model_name, eval_id)
        pred_root_meta = osp.join(args.work_dir, model_name)
        os.makedirs(pred_root_meta, exist_ok=True)

        # Listed before pred_root is created, so it only contains pred_root if that already existed.
        prev_pred_roots = ls(pred_root_meta, mode='dir')
        if len(prev_pred_roots) and args.reuse:
            prev_pred_roots.sort()
        os.makedirs(pred_root, exist_ok=True)

        if use_config:
            model = build_model_from_config(cfg['model'][model_name])

        for dataset_name in args.data:
            try:
                dataset = _build_run_dataset(dataset_name, model_name, args, cfg, use_config, rank, world_size)
                if dataset is None:
                    logger.error(f'Dataset {dataset_name} is not valid, will be skipped. ')
                    continue

                if args.fps > 0:
                    if dataset_name == 'MVBench':
                        raise ValueError('MVBench does not support fps setting, please transfer to MVBench_MP4!')
                    args.nframe = 0

                # Per-dataset copy: the old code set `args.pack = False` here, silently disabling
                # --pack for every later dataset (including MMBench-Video).
                pack = args.pack
                if pack and dataset_name != 'MMBench-Video' and dataset.MODALITY == 'VIDEO':
                    logger.info(f'{dataset_name} not support Pack Mode, directly change to unpack')
                    pack = False

                result_file_base = _result_file_base(
                    model_name, dataset_name, dataset, args.nframe, args.fps, pack, args.use_subtitle)
                result_file = osp.join(pred_root, result_file_base)

                if rank == 0 and args.reuse and len(prev_pred_roots):
                    _reuse_previous_predictions(
                        prev_pred_roots, pred_root, result_file, result_file_base,
                        dataset_name, commit_id, logger)
                if world_size > 1:
                    dist.barrier()

                if model is None:
                    model = model_name

                # The infer_* helpers return the model so later datasets can reuse it.
                if dataset.MODALITY == 'VIDEO':
                    model = infer_data_job_video(
                        model,
                        work_dir=pred_root,
                        model_name=model_name,
                        dataset=dataset,
                        nframe=args.nframe,
                        pack=pack,
                        verbose=args.verbose,
                        subtitle=args.use_subtitle,
                        api_nproc=args.nproc,
                        fps=args.fps)
                elif dataset.TYPE == 'MT':
                    model = infer_data_job_mt(
                        model,
                        work_dir=pred_root,
                        model_name=model_name,
                        dataset=dataset,
                        verbose=args.verbose,
                        api_nproc=args.nproc,
                        ignore_failed=args.ignore)
                else:
                    model = infer_data_job(
                        model,
                        work_dir=pred_root,
                        model_name=model_name,
                        dataset=dataset,
                        verbose=args.verbose,
                        api_nproc=args.nproc,
                        ignore_failed=args.ignore)

                judge_kwargs = _build_judge_kwargs(args, dataset, dataset_name)
                if rank == 0:
                    logger.info(judge_kwargs)
                if world_size > 1:
                    dist.barrier()

                # Everything below (submission files, evaluation, symlinks) is rank 0 only.
                if rank != 0:
                    continue

                if dataset_name in ['MMMU_TEST']:
                    result_json = MMMU_result_transfer(result_file)
                    logger.info(f'Transfer MMMU_TEST result to json for official evaluation, '
                                f'json file saved in {result_json}')
                    continue
                if 'MMT-Bench_ALL' in dataset_name:
                    submission_file = MMTBench_result_transfer(result_file, **judge_kwargs)
                    logger.info(f'Extract options from prediction of MMT-Bench FULL split for official evaluation '
                                f'(https://eval.ai/web/challenges/challenge-page/2328/overview), '
                                f'submission file saved in {submission_file}')
                    continue

                if args.mode == 'infer':
                    continue

                if 'MLLMGuard_DS' in dataset_name:
                    logger.info('The evaluation of MLLMGuard_DS is not supported yet. ')
                    continue
                if dataset_name == 'AesBench_TEST':
                    logger.info(f'The results are saved in {result_file}. '
                                f'Please send it to the AesBench Team via huangyipo@hotmail.com.')
                    continue
                if dataset_name in ['DocVQA_TEST', 'InfoVQA_TEST', 'Q-Bench1_TEST', 'A-Bench_TEST']:
                    logger.info(f'{dataset_name} is a test split without ground-truth. '
                                'Thus only the inference part is supported for those datasets. ')
                    continue
                if dataset_name in [
                    'MMBench_TEST_CN', 'MMBench_TEST_EN', 'MMBench', 'MMBench_CN',
                    'MMBench_TEST_CN_V11', 'MMBench_TEST_EN_V11', 'MMBench_V11', 'MMBench_CN_V11'
                ] and not MMBenchOfficialServer(dataset_name):
                    logger.error(
                        f'Can not evaluate {dataset_name} on non-official servers, will skip the evaluation.')
                    continue

                _evaluate_and_report(dataset, result_file, judge_kwargs, model_name, dataset_name, logger)
                _link_prediction_files(pred_root, pred_root_meta, model_name, dataset_name)
            except Exception as e:
                logger.exception(f'Model {model_name} x Dataset {dataset_name} combination failed: {e}, '
                                 'skipping this combination.')
                continue

        if world_size > 1:
            dist.barrier()

    if world_size > 1:
        dist.destroy_process_group()
if __name__ == '__main__':
    # Script entry point: call vlmeval's `load_env` first, then launch all model x dataset runs.
    load_env()
    main()