import re
import json
import os

import torch
import torch.distributed as dist

import utils


def pre_caption(caption, max_words=50):
    """Normalize a caption: lowercase, replace punctuation with spaces,
    collapse repeated whitespace, strip, and truncate to `max_words` words."""
    caption = re.sub(
        r"([.!\"()*#:;~])",
        ' ',
        caption.lower(),
    )
    caption = re.sub(
        r"\s{2,}",
        ' ',
        caption,
    )
    caption = caption.rstrip('\n')
    caption = caption.strip(' ')

    # truncate caption
    caption_words = caption.split(' ')
    if len(caption_words) > max_words:
        caption = ' '.join(caption_words[:max_words])

    return caption


def pre_question(question, max_ques_words=50):
    """Normalize a question: lowercase, delete punctuation, strip trailing
    spaces, and truncate to `max_ques_words` words."""
    question = re.sub(
        r"([.!\"()*#:;~])",
        '',
        question.lower(),
    )
    question = question.rstrip(' ')

    # truncate question
    question_words = question.split(' ')
    if len(question_words) > max_ques_words:
        question = ' '.join(question_words[:max_ques_words])

    return question


def save_result(result, result_dir, filename, remove_duplicate=''):
    """Dump this rank's `result` list to a per-rank JSON file, then (on the
    main process only) merge all ranks' files into `<filename>.json`.

    Args:
        result: list of JSON-serializable dicts produced by this process.
        result_dir: directory holding both per-rank and merged files.
        filename: base name (without extension) for the output files.
        remove_duplicate: optional dict key; when non-empty, only the first
            entry seen for each distinct value of that key is kept.

    Returns:
        Path of the merged result file (same value on every rank, but the
        file is only written by the main process).
    """
    result_file = os.path.join(result_dir, '%s_rank%d.json' % (filename, utils.get_rank()))
    final_result_file = os.path.join(result_dir, '%s.json' % filename)

    # FIX: the original passed bare open(...) handles into json.dump/json.load
    # and never closed them; use context managers so handles are released
    # deterministically (matters on Windows and non-refcounting interpreters).
    with open(result_file, 'w') as f:
        json.dump(result, f)

    # wait until every rank has written its file before merging
    dist.barrier()

    if utils.is_main_process():
        # combine results from all processes
        result = []
        for rank in range(utils.get_world_size()):
            result_file = os.path.join(result_dir, '%s_rank%d.json' % (filename, rank))
            with open(result_file, 'r') as f:
                res = json.load(f)
            result += res

        if remove_duplicate:
            result_new = []
            id_list = []
            for res in result:
                if res[remove_duplicate] not in id_list:
                    id_list.append(res[remove_duplicate])
                    result_new.append(res)
            result = result_new

        with open(final_result_file, 'w') as f:
            json.dump(result, f)
        print('result file saved to %s' % final_result_file)

    return final_result_file


# NOTE(review): these third-party imports are deliberately kept mid-file (as
# in the original) so the preprocessing helpers above stay importable without
# pycocotools / pycocoevalcap / torchvision installed.
from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap
from torchvision.datasets.utils import download_url


def coco_caption_eval(coco_gt_root, results_file, split):
    """Evaluate COCO caption results against the Karpathy-split ground truth.

    Downloads the ground-truth annotations for `split` ('val' or 'test') into
    `coco_gt_root` if needed, scores `results_file` with COCOEvalCap, prints
    each metric, and returns the evaluator.

    Args:
        coco_gt_root: directory to store/find the ground-truth JSON.
        results_file: path to a COCO-format results JSON.
        split: 'val' or 'test' (raises KeyError otherwise).

    Returns:
        The populated COCOEvalCap object (scores in `coco_eval.eval`).
    """
    urls = {'val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json',
            'test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'}
    filenames = {'val': 'coco_karpathy_val_gt.json', 'test': 'coco_karpathy_test_gt.json'}

    download_url(urls[split], coco_gt_root)
    annotation_file = os.path.join(coco_gt_root, filenames[split])

    # create coco object and coco_result object
    coco = COCO(annotation_file)
    coco_result = coco.loadRes(results_file)

    # create coco_eval object by taking coco and coco_result
    coco_eval = COCOEvalCap(coco, coco_result)

    # evaluate on a subset of images by uncommenting the line below
    # (remove it when evaluating the full validation set)
    # coco_eval.params['image_id'] = coco_result.getImgIds()

    # evaluate results
    # SPICE will take a few minutes the first time, but speeds up due to caching
    coco_eval.evaluate()

    # print output evaluation scores
    for metric, score in coco_eval.eval.items():
        print(f'{metric}: {score:.3f}')

    return coco_eval