Spaces:
Runtime error
Runtime error
Upload 164 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- LICENSE.md +14 -0
- LICENSE_Lavis.md +14 -0
- datasets/training_datasets/video_text_data/video_instruct_100/download_script.py +94 -0
- demo_job.sh +21 -0
- environment.yml +331 -0
- eval_video.py +221 -0
- jobs_video/eval/choose_best_ckpt/choose_best_ckpt.py +14 -0
- jobs_video/eval/choose_best_ckpt/evalualtion_ckpt.sh +17 -0
- jobs_video/eval/llama2_evalualtion.sh +37 -0
- jobs_video/eval/mistral_evalualtion.sh +39 -0
- jobs_video/eval/submit_job.py +19 -0
- jobs_video/train/stage_2_llama2.sh +23 -0
- jobs_video/train/stage_2_mistral.sh +23 -0
- jobs_video/train/stage_3_llama2.sh +23 -0
- jobs_video/train/stage_3_mistral.sh +23 -0
- minigpt4/__init__.py +31 -0
- minigpt4/common/__init__.py +0 -0
- minigpt4/common/config.py +474 -0
- minigpt4/common/dist_utils.py +146 -0
- minigpt4/common/eval_utils.py +224 -0
- minigpt4/common/gradcam.py +24 -0
- minigpt4/common/logger.py +195 -0
- minigpt4/common/optims.py +119 -0
- minigpt4/common/registry.py +330 -0
- minigpt4/common/utils.py +424 -0
- minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py +89 -0
- minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py +1 -0
- minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/vqaEval.py +192 -0
- minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py +73 -0
- minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py +1 -0
- minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py +179 -0
- minigpt4/common/vqa_tools/VQA/README.md +80 -0
- minigpt4/common/vqa_tools/__init__.py +8 -0
- minigpt4/common/vqa_tools/aokvqa/LICENSE +201 -0
- minigpt4/common/vqa_tools/aokvqa/README.md +207 -0
- minigpt4/common/vqa_tools/aokvqa/data_scripts/build_vocab.py +45 -0
- minigpt4/common/vqa_tools/aokvqa/data_scripts/encode_vocab_clip.py +26 -0
- minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_bert_features.py +50 -0
- minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_clip_features.py +51 -0
- minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_resnet_features.py +62 -0
- minigpt4/common/vqa_tools/aokvqa/environment.yml +36 -0
- minigpt4/common/vqa_tools/aokvqa/evaluation/eval_predictions.py +97 -0
- minigpt4/common/vqa_tools/aokvqa/evaluation/load_aokvqa.py +13 -0
- minigpt4/common/vqa_tools/aokvqa/evaluation/prepare_predictions.py +31 -0
- minigpt4/common/vqa_tools/aokvqa/evaluation/remap_predictions.py +44 -0
- minigpt4/common/vqa_tools/aokvqa/gpt3/README.md +14 -0
- minigpt4/common/vqa_tools/aokvqa/gpt3/caption_inputs.py +23 -0
- minigpt4/common/vqa_tools/aokvqa/gpt3/query_gpt3.py +79 -0
- minigpt4/common/vqa_tools/aokvqa/gpt3/rationale_inputs.py +16 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
repo_imgs/sample_1.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
+
repo_imgs/sample_2.gif filter=lfs diff=lfs merge=lfs -text
|
38 |
+
repo_imgs/sample_3.gif filter=lfs diff=lfs merge=lfs -text
|
LICENSE.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
BSD 3-Clause License
|
2 |
+
|
3 |
+
Copyright 2023 Deyao Zhu
|
4 |
+
All rights reserved.
|
5 |
+
|
6 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
7 |
+
|
8 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
9 |
+
|
10 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
11 |
+
|
12 |
+
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
13 |
+
|
14 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
LICENSE_Lavis.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
BSD 3-Clause License
|
2 |
+
|
3 |
+
Copyright (c) 2022 Salesforce, Inc.
|
4 |
+
All rights reserved.
|
5 |
+
|
6 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
7 |
+
|
8 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
9 |
+
|
10 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
11 |
+
|
12 |
+
3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
13 |
+
|
14 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
datasets/training_datasets/video_text_data/video_instruct_100/download_script.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from tqdm import tqdm
|
3 |
+
from pytubefix import YouTube
|
4 |
+
|
5 |
+
import xml.etree.ElementTree as ET
|
6 |
+
import os
|
7 |
+
|
8 |
+
with open ('VideoInstruct100K.json','r') as f :
|
9 |
+
data=json.load(f)
|
10 |
+
|
11 |
+
# Usage
|
12 |
+
existed_video_id={}
|
13 |
+
for video_name in os.listdir('videos'):
|
14 |
+
video_id = video_name.split('.')[0]
|
15 |
+
existed_video_id[video_id]=True
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
def download_video_with_subtitles(video_id):
|
20 |
+
# Create a YouTube object.
|
21 |
+
yt = YouTube(f'https://www.youtube.com/watch?v={video_id}')
|
22 |
+
|
23 |
+
video_filename = f"{video_id}.mp4"
|
24 |
+
video_downloaded=False
|
25 |
+
try :
|
26 |
+
# Get the video stream with the highest resolution and download the video.
|
27 |
+
stream = yt.streams.get_highest_resolution()
|
28 |
+
stream.download(output_path='videos', filename=video_filename)
|
29 |
+
video_downloaded=True
|
30 |
+
except Exception as e:
|
31 |
+
print(f"Error downloading video {video_id}: {str(e)}")
|
32 |
+
video_downloaded=False
|
33 |
+
if not video_downloaded:
|
34 |
+
return False,False
|
35 |
+
|
36 |
+
# Get the video's available captions (subtitles).
|
37 |
+
captions = yt.captions.all()
|
38 |
+
|
39 |
+
# Download the captions if available in xml format.
|
40 |
+
caption_downloaded = False
|
41 |
+
for caption in captions:
|
42 |
+
caption_code = caption.code
|
43 |
+
# select only english captions
|
44 |
+
if 'en' in caption_code:
|
45 |
+
caption.download(title=f"{video_id}", output_path='subtitles_xml',srt=False)
|
46 |
+
caption_downloaded = True
|
47 |
+
return video_downloaded,caption_downloaded
|
48 |
+
def convert_xml_vtt(xml_path, vtt_path):
|
49 |
+
# Parse the XML subtitle file
|
50 |
+
tree = ET.parse(xml_path)
|
51 |
+
root = tree.getroot()
|
52 |
+
|
53 |
+
# Initialize a list to store VTT subtitle entries
|
54 |
+
vtt_subtitle = []
|
55 |
+
|
56 |
+
# Function to convert time in milliseconds to WebVTT format
|
57 |
+
def ms_to_vtt_time(milliseconds):
|
58 |
+
seconds, milliseconds = divmod(milliseconds, 1000)
|
59 |
+
minutes, seconds = divmod(seconds, 60)
|
60 |
+
return f"{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
|
61 |
+
|
62 |
+
# Iterate through subtitle elements
|
63 |
+
toggle = True
|
64 |
+
for p in root.findall(".//p"):
|
65 |
+
if toggle:
|
66 |
+
start_time = int(p.get("t"))
|
67 |
+
subtitle_text = " ".join(s.text.strip() for s in p.findall(".//s"))
|
68 |
+
# duration = int(p.get("d")) if p.get("d") is not None else 0
|
69 |
+
if not toggle:
|
70 |
+
end_time = int(p.get("t"))
|
71 |
+
# Format and append the VTT entry to the list
|
72 |
+
vtt_subtitle.append(f"{ms_to_vtt_time(start_time)} --> {ms_to_vtt_time(end_time)}\n{subtitle_text}\n")
|
73 |
+
toggle = not toggle
|
74 |
+
# Join the VTT entries into a single string
|
75 |
+
vtt_content = "WEBVTT\n\n" + "\n".join(vtt_subtitle)
|
76 |
+
|
77 |
+
# Save the VTT content to a file
|
78 |
+
with open(vtt_path, "w", encoding="utf-8") as vtt_file:
|
79 |
+
vtt_file.write(vtt_content)
|
80 |
+
import os
|
81 |
+
os.makedirs('videos', exist_ok=True)
|
82 |
+
os.makedirs('subtitles_vtt', exist_ok=True)
|
83 |
+
os.makedirs('subtitles_xml', exist_ok=True)
|
84 |
+
for video_path in tqdm(data,desc='Downloading videos') :
|
85 |
+
video_id=video_path.split('/')[-1].split('.')[0]
|
86 |
+
if existed_video_id.get(video_id,False):
|
87 |
+
continue
|
88 |
+
video_downloaded,caption_downloaded=download_video_with_subtitles(video_id)
|
89 |
+
if caption_downloaded:
|
90 |
+
# convert xml to vtt
|
91 |
+
xml_file_path=f'subtitles_xml/{video_id} (a.en).xml'
|
92 |
+
convert_xml_vtt(xml_file_path,f'subtitles_vtt/{video_id}.vtt')
|
93 |
+
|
94 |
+
|
demo_job.sh
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
#SBATCH --partition=batch
|
3 |
+
#SBATCH --job-name=video_demo_llama2
|
4 |
+
#SBATCH --output=video_demo_llama2.out
|
5 |
+
#SBATCH --error=video_demo_llama2.err
|
6 |
+
#SBATCH --time=0-10:30:00
|
7 |
+
#SBATCH --mem=100G
|
8 |
+
#SBATCH --gres=gpu:a100:1
|
9 |
+
#SBATCH --nodes=1
|
10 |
+
|
11 |
+
# Choose the model to test
|
12 |
+
# Mistral
|
13 |
+
# ckpt="checkpoints/video_mistral_checkpoint_last.pth"
|
14 |
+
# config="test_configs/mistral_test_config.yaml"
|
15 |
+
|
16 |
+
# Llama2
|
17 |
+
ckpt="checkpoints/video_llama_checkpoint_last.pth"
|
18 |
+
config="test_configs/llama2_test_config.yaml"
|
19 |
+
|
20 |
+
|
21 |
+
python minigpt4_video_demo.py --cfg-path $config --ckpt $ckpt
|
environment.yml
ADDED
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: minigpt4_video_test_v100
|
2 |
+
channels:
|
3 |
+
- conda-forge
|
4 |
+
dependencies:
|
5 |
+
- _libgcc_mutex=0.1=conda_forge
|
6 |
+
- _openmp_mutex=4.5=2_gnu
|
7 |
+
- archspec=0.2.2=pyhd8ed1ab_0
|
8 |
+
- boltons=23.1.1=pyhd8ed1ab_0
|
9 |
+
- brotli-python=1.1.0=py39h3d6467e_1
|
10 |
+
- bzip2=1.0.8=hd590300_5
|
11 |
+
- c-ares=1.25.0=hd590300_0
|
12 |
+
- ca-certificates=2024.2.2=hbcca054_0
|
13 |
+
- certifi=2024.2.2=pyhd8ed1ab_0
|
14 |
+
- cffi=1.16.0=py39h7a31438_0
|
15 |
+
- charset-normalizer=3.3.2=pyhd8ed1ab_0
|
16 |
+
- colorama=0.4.6=pyhd8ed1ab_0
|
17 |
+
- conda=23.11.0=py39hf3d152e_1
|
18 |
+
- conda-libmamba-solver=23.12.0=pyhd8ed1ab_0
|
19 |
+
- conda-package-handling=2.2.0=pyh38be061_0
|
20 |
+
- conda-package-streaming=0.9.0=pyhd8ed1ab_0
|
21 |
+
- cudatoolkit=11.8.0=h4ba93d1_12
|
22 |
+
- cudatoolkit-dev=11.7.0=h1de0b5d_6
|
23 |
+
- distro=1.9.0=pyhd8ed1ab_0
|
24 |
+
- faiss=1.7.4=py39cuda112h460e57a_0_cuda
|
25 |
+
- fmt=10.1.1=h00ab1b0_1
|
26 |
+
- freetype=2.12.1=h267a509_2
|
27 |
+
- gmp=6.1.2=hf484d3e_1000
|
28 |
+
- gnutls=3.5.19=h2a4e5f8_1
|
29 |
+
- icu=73.2=h59595ed_0
|
30 |
+
- idna=3.6=pyhd8ed1ab_0
|
31 |
+
- jsonpatch=1.33=pyhd8ed1ab_0
|
32 |
+
- jsonpointer=2.4=py39hf3d152e_3
|
33 |
+
- keyutils=1.6.1=h166bdaf_0
|
34 |
+
- krb5=1.21.2=h659d440_0
|
35 |
+
- ld_impl_linux-64=2.40=h41732ed_0
|
36 |
+
- libarchive=3.7.2=h2aa1ff5_1
|
37 |
+
- libblas=3.9.0=20_linux64_openblas
|
38 |
+
- libcblas=3.9.0=20_linux64_openblas
|
39 |
+
- libcurl=8.5.0=hca28451_0
|
40 |
+
- libedit=3.1.20191231=he28a2e2_2
|
41 |
+
- libev=4.33=hd590300_2
|
42 |
+
- libfaiss=1.7.4=cuda112hb18a002_0_cuda
|
43 |
+
- libfaiss-avx2=1.7.4=cuda112h1234567_0_cuda
|
44 |
+
- libffi=3.4.2=h7f98852_5
|
45 |
+
- libgcc-ng=13.2.0=h807b86a_3
|
46 |
+
- libgfortran-ng=13.2.0=h69a702a_3
|
47 |
+
- libgfortran5=13.2.0=ha4646dd_3
|
48 |
+
- libgomp=13.2.0=h807b86a_3
|
49 |
+
- libiconv=1.17=hd590300_2
|
50 |
+
- liblapack=3.9.0=20_linux64_openblas
|
51 |
+
- libmamba=1.5.6=had39da4_0
|
52 |
+
- libmambapy=1.5.6=py39h10defb6_0
|
53 |
+
- libnghttp2=1.58.0=h47da74e_1
|
54 |
+
- libnsl=2.0.1=hd590300_0
|
55 |
+
- libopenblas=0.3.25=pthreads_h413a1c8_0
|
56 |
+
- libpng=1.6.39=h753d276_0
|
57 |
+
- libsolv=0.7.27=hfc55251_0
|
58 |
+
- libsqlite=3.44.2=h2797004_0
|
59 |
+
- libssh2=1.11.0=h0841786_0
|
60 |
+
- libstdcxx-ng=13.2.0=h7e041cc_3
|
61 |
+
- libuuid=2.38.1=h0b41bf4_0
|
62 |
+
- libxcrypt=4.4.36=hd590300_1
|
63 |
+
- libxml2=2.12.3=h232c23b_0
|
64 |
+
- libzlib=1.2.13=hd590300_5
|
65 |
+
- lz4-c=1.9.4=hcb278e6_0
|
66 |
+
- lzo=2.10=h516909a_1000
|
67 |
+
- menuinst=2.0.1=py39hf3d152e_0
|
68 |
+
- ncurses=6.4=h59595ed_2
|
69 |
+
- nettle=3.3=0
|
70 |
+
- numpy=1.26.3=py39h474f0d3_0
|
71 |
+
- openh264=1.8.0=hdbcaa40_1000
|
72 |
+
- openssl=3.2.1=hd590300_0
|
73 |
+
- packaging=23.2=pyhd8ed1ab_0
|
74 |
+
- pip=23.3.2=pyhd8ed1ab_0
|
75 |
+
- platformdirs=4.1.0=pyhd8ed1ab_0
|
76 |
+
- pluggy=1.3.0=pyhd8ed1ab_0
|
77 |
+
- pybind11-abi=4=hd8ed1ab_3
|
78 |
+
- pycosat=0.6.6=py39hd1e30aa_0
|
79 |
+
- pycparser=2.21=pyhd8ed1ab_0
|
80 |
+
- pysocks=1.7.1=pyha2e5f31_6
|
81 |
+
- python=3.9.18=h0755675_1_cpython
|
82 |
+
- python_abi=3.9=4_cp39
|
83 |
+
- readline=8.2=h8228510_1
|
84 |
+
- reproc=14.2.4.post0=hd590300_1
|
85 |
+
- reproc-cpp=14.2.4.post0=h59595ed_1
|
86 |
+
- requests=2.31.0=pyhd8ed1ab_0
|
87 |
+
- ruamel.yaml=0.18.5=py39hd1e30aa_0
|
88 |
+
- ruamel.yaml.clib=0.2.7=py39hd1e30aa_2
|
89 |
+
- tk=8.6.13=noxft_h4845f30_101
|
90 |
+
- tqdm=4.66.1=pyhd8ed1ab_0
|
91 |
+
- urllib3=2.1.0=pyhd8ed1ab_0
|
92 |
+
- wheel=0.42.0=pyhd8ed1ab_0
|
93 |
+
- x264=1!152.20180717=h14c3975_1001
|
94 |
+
- xz=5.2.6=h166bdaf_0
|
95 |
+
- yaml-cpp=0.8.0=h59595ed_0
|
96 |
+
- zlib=1.2.13=hd590300_5
|
97 |
+
- zstandard=0.22.0=py39h6e5214e_0
|
98 |
+
- zstd=1.5.5=hfc55251_0
|
99 |
+
- pip:
|
100 |
+
- accelerate==0.25.0
|
101 |
+
- aiofiles==23.2.1
|
102 |
+
- aiohttp==3.9.1
|
103 |
+
- aiosignal==1.3.1
|
104 |
+
- altair==5.2.0
|
105 |
+
- annotated-types==0.6.0
|
106 |
+
- antlr4-python3-runtime==4.9.3
|
107 |
+
- anyio==4.2.0
|
108 |
+
- appdirs==1.4.4
|
109 |
+
- asgiref==3.7.2
|
110 |
+
- async-timeout==4.0.3
|
111 |
+
- attrs==23.2.0
|
112 |
+
- backoff==2.2.1
|
113 |
+
- bcrypt==4.1.2
|
114 |
+
- beautifulsoup4==4.12.2
|
115 |
+
- bitarray==2.9.2
|
116 |
+
- bitsandbytes==0.42.0
|
117 |
+
- bleach==6.1.0
|
118 |
+
- blinker==1.7.0
|
119 |
+
- braceexpand==0.1.7
|
120 |
+
- build==1.0.3
|
121 |
+
- cachetools==5.3.2
|
122 |
+
- chardet==5.2.0
|
123 |
+
- chroma-hnswlib==0.7.3
|
124 |
+
- chromadb==0.4.22
|
125 |
+
- click==8.1.7
|
126 |
+
- cmake==3.25.0
|
127 |
+
- colbert-ai==0.2.18
|
128 |
+
- coloredlogs==15.0.1
|
129 |
+
- contourpy==1.2.0
|
130 |
+
- cycler==0.12.1
|
131 |
+
- datasets==2.17.0
|
132 |
+
- decorator==4.4.2
|
133 |
+
- decord==0.6.0
|
134 |
+
- deprecated==1.2.14
|
135 |
+
- dill==0.3.8
|
136 |
+
- docker-pycreds==0.4.0
|
137 |
+
- docopt==0.6.2
|
138 |
+
- einops==0.7.0
|
139 |
+
- exceptiongroup==1.2.0
|
140 |
+
- faiss-gpu==1.7.2
|
141 |
+
- fastapi==0.108.0
|
142 |
+
- ffmpeg==1.4
|
143 |
+
- ffmpeg-python==0.2.0
|
144 |
+
- ffmpy==0.3.1
|
145 |
+
- filelock==3.13.1
|
146 |
+
- flash-attn==2.5.4
|
147 |
+
- flask==3.0.2
|
148 |
+
- flatbuffers==23.5.26
|
149 |
+
- fonttools==4.47.0
|
150 |
+
- frozenlist==1.4.1
|
151 |
+
- fsspec==2023.10.0
|
152 |
+
- ftfy==6.1.3
|
153 |
+
- future==0.18.3
|
154 |
+
- gdown==4.7.1
|
155 |
+
- git-python==1.0.3
|
156 |
+
- gitdb==4.0.11
|
157 |
+
- gitpython==3.1.40
|
158 |
+
- google-auth==2.26.1
|
159 |
+
- googleapis-common-protos==1.62.0
|
160 |
+
- gradio
|
161 |
+
- gradio-client
|
162 |
+
- h11==0.14.0
|
163 |
+
- h5py==3.10.0
|
164 |
+
- httpcore==1.0.2
|
165 |
+
- httptools==0.6.1
|
166 |
+
- httpx==0.26.0
|
167 |
+
- huggingface-hub==0.21.1
|
168 |
+
- humanfriendly==10.0
|
169 |
+
- imageio==2.33.1
|
170 |
+
- imageio-ffmpeg==0.4.9
|
171 |
+
- importlib-metadata==6.11.0
|
172 |
+
- importlib-resources==6.1.1
|
173 |
+
- inquirerpy==0.3.4
|
174 |
+
- iopath==0.1.10
|
175 |
+
- itsdangerous==2.1.2
|
176 |
+
- jinja2==3.1.2
|
177 |
+
- joblib==1.3.2
|
178 |
+
- jsonschema==4.20.0
|
179 |
+
- jsonschema-specifications==2023.12.1
|
180 |
+
- kaggle==1.6.0
|
181 |
+
- kiwisolver==1.4.5
|
182 |
+
- kubernetes==29.0.0
|
183 |
+
- lazy-loader==0.3
|
184 |
+
- lit==15.0.7
|
185 |
+
- llvmlite==0.41.1
|
186 |
+
- markdown-it-py==3.0.0
|
187 |
+
- matplotlib==3.8.2
|
188 |
+
- mdurl==0.1.2
|
189 |
+
- mmh3==4.1.0
|
190 |
+
- monotonic==1.6
|
191 |
+
- more-itertools==10.1.0
|
192 |
+
- moviepy==1.0.3
|
193 |
+
- mpmath==1.3.0
|
194 |
+
- multidict==6.0.4
|
195 |
+
- multiprocess==0.70.16
|
196 |
+
- mutagen==1.47.0
|
197 |
+
- networkx==3.2.1
|
198 |
+
- ninja==1.11.1.1
|
199 |
+
- nltk==3.8.1
|
200 |
+
- numba==0.58.1
|
201 |
+
- nvidia-cublas-cu11==11.10.3.66
|
202 |
+
- nvidia-cublas-cu12==12.1.3.1
|
203 |
+
- nvidia-cuda-cupti-cu12==12.1.105
|
204 |
+
- nvidia-cuda-nvrtc-cu11==11.7.99
|
205 |
+
- nvidia-cuda-nvrtc-cu12==12.1.105
|
206 |
+
- nvidia-cuda-runtime-cu11==11.7.99
|
207 |
+
- nvidia-cuda-runtime-cu12==12.1.105
|
208 |
+
- nvidia-cudnn-cu11==8.5.0.96
|
209 |
+
- nvidia-cudnn-cu12==8.9.2.26
|
210 |
+
- nvidia-cufft-cu12==11.0.2.54
|
211 |
+
- nvidia-curand-cu12==10.3.2.106
|
212 |
+
- nvidia-cusolver-cu12==11.4.5.107
|
213 |
+
- nvidia-cusparse-cu12==12.1.0.106
|
214 |
+
- nvidia-nccl-cu12==2.18.1
|
215 |
+
- nvidia-nvjitlink-cu12==12.3.101
|
216 |
+
- nvidia-nvtx-cu12==12.1.105
|
217 |
+
- omegaconf==2.3.0
|
218 |
+
- onnxruntime==1.16.3
|
219 |
+
- openai==0.28.0
|
220 |
+
- openai-whisper==20231117
|
221 |
+
- opencv-python==4.7.0.72
|
222 |
+
- opentelemetry-api==1.22.0
|
223 |
+
- opentelemetry-exporter-otlp-proto-common==1.22.0
|
224 |
+
- opentelemetry-exporter-otlp-proto-grpc==1.22.0
|
225 |
+
- opentelemetry-instrumentation==0.43b0
|
226 |
+
- opentelemetry-instrumentation-asgi==0.43b0
|
227 |
+
- opentelemetry-instrumentation-fastapi==0.43b0
|
228 |
+
- opentelemetry-proto==1.22.0
|
229 |
+
- opentelemetry-sdk==1.22.0
|
230 |
+
- opentelemetry-semantic-conventions==0.43b0
|
231 |
+
- opentelemetry-util-http==0.43b0
|
232 |
+
- orjson==3.9.10
|
233 |
+
- overrides==7.4.0
|
234 |
+
- pandas==2.0.0
|
235 |
+
- pathtools==0.1.2
|
236 |
+
- peft==0.2.0
|
237 |
+
- pfzy==0.3.4
|
238 |
+
- pillow==10.2.0
|
239 |
+
- plotly==5.18.0
|
240 |
+
- portalocker==2.8.2
|
241 |
+
- posthog==3.3.0
|
242 |
+
- proglog==0.1.10
|
243 |
+
- progressbar2==4.3.2
|
244 |
+
- prompt-toolkit==3.0.43
|
245 |
+
- protobuf==4.25.1
|
246 |
+
- psutil==5.9.7
|
247 |
+
- pulsar-client==3.4.0
|
248 |
+
- pyarrow==15.0.0
|
249 |
+
- pyarrow-hotfix==0.6
|
250 |
+
- pyasn1==0.5.1
|
251 |
+
- pyasn1-modules==0.3.0
|
252 |
+
- pycocoevalcap==1.2
|
253 |
+
- pycocotools==2.0.6
|
254 |
+
- pycryptodomex==3.19.1
|
255 |
+
- pydantic==2.5.3
|
256 |
+
- pydantic-core==2.14.6
|
257 |
+
- pydub==0.25.1
|
258 |
+
- pygments==2.17.2
|
259 |
+
- pyparsing==3.1.1
|
260 |
+
- pypika==0.48.9
|
261 |
+
- pyproject-hooks==1.0.0
|
262 |
+
- pysrt==1.1.2
|
263 |
+
- python-dateutil==2.8.2
|
264 |
+
- python-dotenv==1.0.0
|
265 |
+
- python-multipart==0.0.6
|
266 |
+
- python-slugify==8.0.1
|
267 |
+
- python-utils==3.8.1
|
268 |
+
- pytubefix
|
269 |
+
- pytz==2023.3.post1
|
270 |
+
- pyyaml==6.0.1
|
271 |
+
- referencing==0.32.0
|
272 |
+
- regex==2023.12.25
|
273 |
+
- rich==13.7.0
|
274 |
+
- rouge==1.0.1
|
275 |
+
- rpds-py==0.16.2
|
276 |
+
- rsa==4.9
|
277 |
+
- safetensors==0.4.1
|
278 |
+
- scikit-image==0.22.0
|
279 |
+
- scikit-learn==1.3.2
|
280 |
+
- scipy==1.11.4
|
281 |
+
- seaborn==0.13.1
|
282 |
+
- semantic-version==2.10.0
|
283 |
+
- sentence-transformers==2.2.2
|
284 |
+
- sentencepiece==0.1.97
|
285 |
+
- sentry-sdk==1.39.1
|
286 |
+
- setproctitle==1.3.3
|
287 |
+
- setuptools==69.0.3
|
288 |
+
- shellingham==1.5.4
|
289 |
+
- six==1.16.0
|
290 |
+
- smmap==5.0.1
|
291 |
+
- sniffio==1.3.0
|
292 |
+
- soundfile==0.12.1
|
293 |
+
- soupsieve==2.5
|
294 |
+
- starlette==0.32.0.post1
|
295 |
+
- sympy==1.12
|
296 |
+
- tenacity==8.2.3
|
297 |
+
- text-unidecode==1.3
|
298 |
+
- threadpoolctl==3.2.0
|
299 |
+
- tifffile==2023.12.9
|
300 |
+
- tiktoken==0.5.2
|
301 |
+
- timm==0.6.13
|
302 |
+
- tokenizers==0.15.2
|
303 |
+
- tomli==2.0.1
|
304 |
+
- tomlkit==0.12.0
|
305 |
+
- toolz==0.12.0
|
306 |
+
- torch==2.0.1
|
307 |
+
- torchaudio==2.0.2
|
308 |
+
- torchvision==0.15.2
|
309 |
+
- transformers==4.37.2
|
310 |
+
- triton==2.0.0
|
311 |
+
- typer==0.9.0
|
312 |
+
- typing-extensions==4.9.0
|
313 |
+
- tzdata==2023.4
|
314 |
+
- ujson==5.9.0
|
315 |
+
- uvicorn==0.25.0
|
316 |
+
- uvloop==0.19.0
|
317 |
+
- visual-genome==1.1.1
|
318 |
+
- wandb==0.14.2
|
319 |
+
- watchfiles==0.21.0
|
320 |
+
- wcwidth==0.2.13
|
321 |
+
- webdataset==0.2.48
|
322 |
+
- webencodings==0.5.1
|
323 |
+
- websocket-client==1.7.0
|
324 |
+
- websockets
|
325 |
+
- webvtt-py==0.4.6
|
326 |
+
- wrapt==1.16.0
|
327 |
+
- xxhash==3.4.1
|
328 |
+
- yarl==1.9.4
|
329 |
+
- youtube-dl==2021.12.17
|
330 |
+
- yt-dlp
|
331 |
+
- zipp
|
eval_video.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
from tqdm import tqdm
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser
|
6 |
+
from minigpt4.conversation.conversation import CONV_VISION
|
7 |
+
from minigpt4.processors.blip_processors import Blip2ImageTrainProcessor,BlipCaptionProcessor
|
8 |
+
from minigpt4.datasets.datasets.video_datasets import VideoChatGPTEvalDataset,VideoChatGPTEval_consistancy,Video_validation_Dataset,TVQAEVAL,TVQAEVAL_Long
|
9 |
+
|
10 |
+
parser = eval_parser()
|
11 |
+
parser.add_argument("--dataset", type=str, default='msvd', help="dataset to evaluate")
|
12 |
+
parser.add_argument("--add_subtitles",action='store_true',help="whether to add subtitles to the video")
|
13 |
+
parser.add_argument("--name", type=str, default='3_datasets', help="evaluation name")
|
14 |
+
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
15 |
+
parser.add_argument("--start", type=int, default=0, help="start from video number")
|
16 |
+
parser.add_argument("--end", type=int, default=10000000, help="end at video number")
|
17 |
+
args = parser.parse_args()
|
18 |
+
|
19 |
+
print(args.ckpt)
|
20 |
+
print(args.name)
|
21 |
+
print(args.cfg_path)
|
22 |
+
if "test_configs/mistral_test_config.yaml" == args.cfg_path:
|
23 |
+
llm_name="mistral"
|
24 |
+
else:
|
25 |
+
llm_name="llama2"
|
26 |
+
print("using captions",args.add_subtitles)
|
27 |
+
|
28 |
+
model, vis_processor = init_model(args)
|
29 |
+
conv_temp = CONV_VISION.copy()
|
30 |
+
conv_temp.system = ""
|
31 |
+
if args.dataset == 'video_chatgpt_generic':
|
32 |
+
ann_path="datasets/evaluation_datasets/videochatgpt_benchmark/generic_qa.json"
|
33 |
+
videos_path="/ibex/project/c2090/datasets/VideoInstruct100K/test_videos/Test_Videos"
|
34 |
+
subtitles_path="/home/ataallka/minigpt_video/minigpt_multi_img/inference_subtitles"
|
35 |
+
videos_features_path="/ibex/project/c2106/kirolos/videos_features/evaluation/benchmark/generic"
|
36 |
+
annotations_keys=['Q','A','video_name']
|
37 |
+
data = VideoChatGPTEvalDataset(vis_processor, videos_path, ann_path,subtitles_path,annotations_keys,videos_features_path, add_subtitles=args.add_subtitles,llm_name=llm_name)
|
38 |
+
elif args.dataset == 'video_chatgpt_temporal':
|
39 |
+
ann_path="datasets/evaluation_datasets/videochatgpt_benchmark/temporal_qa.json"
|
40 |
+
videos_path="/ibex/project/c2090/datasets/VideoInstruct100K/test_videos/Test_Videos"
|
41 |
+
subtitles_path="/home/ataallka/minigpt_video/minigpt_multi_img/inference_subtitles"
|
42 |
+
videos_features_path="/ibex/project/c2106/kirolos/videos_features/evaluation/benchmark/temporal"
|
43 |
+
annotations_keys=['Q','A','video_name']
|
44 |
+
data = VideoChatGPTEvalDataset(vis_processor, videos_path, ann_path,subtitles_path,annotations_keys,videos_features_path, add_subtitles=args.add_subtitles,llm_name=llm_name)
|
45 |
+
elif args.dataset == 'video_chatgpt_consistency':
|
46 |
+
ann_path="datasets/evaluation_datasets/videochatgpt_benchmark/consistency_qa.json"
|
47 |
+
videos_path="/ibex/project/c2090/datasets/VideoInstruct100K/test_videos/Test_Videos"
|
48 |
+
subtitles_path="/home/ataallka/minigpt_video/minigpt_multi_img/inference_subtitles"
|
49 |
+
annotations_keys=[['Q1','Q2'],'A','video_name']
|
50 |
+
data = VideoChatGPTEval_consistancy(vis_processor, videos_path, ann_path,subtitles_path,annotations_keys, add_subtitles=args.add_subtitles,llm_name=llm_name)
|
51 |
+
|
52 |
+
elif args.dataset == 'msrvtt':
|
53 |
+
ann_path="datasets/evaluation_datasets/msrvtt/val_qa_edited.json"
|
54 |
+
videos_path="/ibex/project/c2090/datasets/VideoInstruct100K/test_videos/MSRVTT/videos/all"
|
55 |
+
subtitles_path="/home/ataallka/minigpt_video/minigpt_multi_img/inference_subtitles"
|
56 |
+
videos_features_path="/ibex/project/c2106/kirolos/videos_features/evaluation/msrvtt"
|
57 |
+
annotations_keys=['question','answer','video_id']
|
58 |
+
data = VideoChatGPTEvalDataset(vis_processor, videos_path, ann_path,subtitles_path,annotations_keys,videos_features_path, add_subtitles=args.add_subtitles,llm_name=llm_name)
|
59 |
+
|
60 |
+
elif args.dataset == 'msvd':
|
61 |
+
ann_path="datasets/evaluation_datasets/msvd/val_qa_edited.json"
|
62 |
+
videos_path="/ibex/project/c2090/datasets/VideoInstruct100K/test_videos/MSVD-QA/videos"
|
63 |
+
subtitles_path="/home/ataallka/minigpt_video/minigpt_multi_img/inference_subtitles"
|
64 |
+
videos_features_path="/ibex/project/c2106/kirolos/videos_features/evaluation/msvd"
|
65 |
+
annotations_keys=['question','answer','video_id']
|
66 |
+
data = VideoChatGPTEvalDataset(vis_processor, videos_path, ann_path,subtitles_path,annotations_keys,videos_features_path, add_subtitles=args.add_subtitles,llm_name=llm_name)
|
67 |
+
elif args.dataset == 'activitynet':
|
68 |
+
ann_path="datasets/evaluation_datasets/activityNet/test_qa.json"
|
69 |
+
videos_path="/ibex/project/c2090/datasets/VideoInstruct100K/test_videos/Activity_net/Activity_net_videos"
|
70 |
+
subtitles_path="/home/ataallka/minigpt_video/minigpt_multi_img/inference_subtitles/"
|
71 |
+
videos_features_path="/ibex/project/c2106/kirolos/videos_features/evaluation/activity_net"
|
72 |
+
annotations_keys=['question','answer','video_id']
|
73 |
+
data = VideoChatGPTEvalDataset(vis_processor, videos_path, ann_path,subtitles_path,annotations_keys,videos_features_path, add_subtitles=args.add_subtitles,llm_name=llm_name)
|
74 |
+
elif args.dataset == 'tgif':
|
75 |
+
ann_path="datasets/evaluation_datasets/tgif/Test_frameqa_question.json"
|
76 |
+
videos_path="/ibex/project/c2090/datasets/VideoInstruct100K/test_videos/TGIF/mp4s"
|
77 |
+
subtitles_path="/home/ataallka/minigpt_video/minigpt_multi_img/inference_subtitles"
|
78 |
+
videos_features_path="/ibex/project/c2106/kirolos/videos_features/evaluation/tgif"
|
79 |
+
annotations_keys=['question','answer','gif_name']
|
80 |
+
# annotations_keys=['question','description','gif_name']
|
81 |
+
data = VideoChatGPTEvalDataset(vis_processor, videos_path, ann_path,subtitles_path,annotations_keys,videos_features_path, add_subtitles=False,llm_name=llm_name)
|
82 |
+
elif args.dataset == 'tvqa':
|
83 |
+
# TVQA dataset
|
84 |
+
ann_path="datasets/evaluation_datasets/tvqa_short/tvqa_val.json"
|
85 |
+
videos_path= "/ibex/project/c2090/datasets/TVR_dataset/videos/video_files/frames_hq/"
|
86 |
+
subtitles_path="/ibex/project/c2090/datasets/TVR_dataset/TVRetrieval/data/tvqa_preprocessed_subtitles.json"
|
87 |
+
videos_features_path="/ibex/project/c2106/kirolos/videos_features/evaluation/tvqa"
|
88 |
+
data = TVQAEVAL(vis_processor, videos_path, ann_path,subtitles_path,videos_features_path,add_subtitles=args.add_subtitles,llm_name=llm_name)
|
89 |
+
|
90 |
+
eval_dataloader = DataLoader(data, batch_size=args.batch_size, shuffle=False)
|
91 |
+
|
92 |
+
minigpt4_predict = []
|
93 |
+
sub="subtitles" if args.add_subtitles else "no_subtitles"
|
94 |
+
if args.start == 0 and args.end == 10000000:
|
95 |
+
save_path = f'results/{args.name}_{args.dataset}_{sub}.json'
|
96 |
+
else:
|
97 |
+
print("start from video number",args.start)
|
98 |
+
print("end at video number",args.end)
|
99 |
+
save_path = f'results/{args.name}_{args.dataset}_{sub}_{args.start}_{args.end}.json'
|
100 |
+
|
101 |
+
os.makedirs("results", exist_ok=True)
|
102 |
+
c=0
|
103 |
+
pred_result = {}
|
104 |
+
gt_result = {}
|
105 |
+
if args.dataset == 'video_chatgpt_consistency':
|
106 |
+
for images, texts_1,texts_2, gt_answers, lengths,videos_ids in tqdm(eval_dataloader,desc=f"Eval {args.dataset}"):
|
107 |
+
if args.start<= c <args.end :
|
108 |
+
texts_q1 = prepare_texts(texts_1, conv_temp, template='', lengths=lengths) # warp the texts with conversation template
|
109 |
+
texts_q2 = prepare_texts(texts_2, conv_temp, template='', lengths=lengths) # warp the texts with conversation template
|
110 |
+
models_answers_q1 = model.generate(images, texts_q1, max_new_tokens=args.max_new_tokens, do_sample=False, lengths=lengths,num_beams=1)
|
111 |
+
models_answers_q2 = model.generate(images, texts_q2, max_new_tokens=args.max_new_tokens, do_sample=False, lengths=lengths,num_beams=1)
|
112 |
+
for video_id,model_answer_q1,model_answer_q2, gt_answer,text_q1,text_q2 in zip(videos_ids,models_answers_q1,models_answers_q2, gt_answers,texts_q1,texts_q2):
|
113 |
+
result = dict()
|
114 |
+
result['video_name'] = video_id
|
115 |
+
result['Q1'] = text_q1.split('\n')[-1].replace('[/INST]','')
|
116 |
+
result['Q2'] = text_q2.split('\n')[-1].replace('[/INST]','')
|
117 |
+
result['A'] = gt_answer
|
118 |
+
result['pred1'] = model_answer_q1
|
119 |
+
result['pred2'] = model_answer_q2
|
120 |
+
pred_result[video_id] = [model_answer_q1,model_answer_q2]
|
121 |
+
gt_result[video_id] = [gt_answer]
|
122 |
+
minigpt4_predict.append(result)
|
123 |
+
# save results every 100 videos to avoid losing results
|
124 |
+
if c%100==0:
|
125 |
+
with open(save_path, 'w') as f:
|
126 |
+
json.dump(minigpt4_predict, f)
|
127 |
+
if c >= args.end :
|
128 |
+
break
|
129 |
+
c+=1
|
130 |
+
|
131 |
+
elif args.dataset == 'tvr':
|
132 |
+
for images, texts, gt_answers, lengths,videos_ids in tqdm(eval_dataloader,desc=f"Eval {args.dataset}"):
|
133 |
+
if args.start<= c <args.end :
|
134 |
+
texts = prepare_texts(texts, conv_temp, template='', lengths=lengths) # warp the texts with conversation template
|
135 |
+
models_answers = model.generate(images, texts, max_new_tokens=args.max_new_tokens, do_sample=False, lengths=lengths,num_beams=1)
|
136 |
+
for video_id,model_answer, gt_answer,text in zip(videos_ids,models_answers, gt_answers,texts):
|
137 |
+
result = dict()
|
138 |
+
result['video_name'] = video_id
|
139 |
+
result['Q'] = text.split('\n')[-1].replace('[/INST]','')
|
140 |
+
result['A'] = gt_answer
|
141 |
+
result['pred'] = model_answer
|
142 |
+
pred_result[video_id] = [model_answer]
|
143 |
+
gt_result[video_id] = [gt_answer]
|
144 |
+
minigpt4_predict.append(result)
|
145 |
+
# save results every 100 videos to avoid losing results
|
146 |
+
if c%100==0:
|
147 |
+
with open(save_path, 'w') as f:
|
148 |
+
json.dump(minigpt4_predict, f)
|
149 |
+
if c >= args.end :
|
150 |
+
break
|
151 |
+
c+=1
|
152 |
+
elif args.dataset == 'ego_schema' or args.dataset == 'tvqa' or args.dataset == 'tvqa_long_videos':
|
153 |
+
for images, texts, gt_answers, lengths,videos_ids in tqdm(eval_dataloader,desc=f"Eval {args.dataset}"):
|
154 |
+
if args.start<= c <args.end :
|
155 |
+
texts = prepare_texts(texts, conv_temp, template='', lengths=lengths) # warp the texts with conversation template
|
156 |
+
models_answers = model.generate(images, texts, max_new_tokens=args.max_new_tokens, do_sample=False, lengths=lengths,num_beams=1)
|
157 |
+
for video_id,model_answer, gt_answer,text in zip(videos_ids,models_answers, gt_answers,texts):
|
158 |
+
result = dict()
|
159 |
+
result['video_name'] = video_id
|
160 |
+
if args.dataset == 'tvqa_long_videos':
|
161 |
+
result['Q'] = text.split('\n\n')[1:]
|
162 |
+
else:
|
163 |
+
result['Q'] = text.split('\n')[1:]
|
164 |
+
result['A'] = gt_answer
|
165 |
+
result['pred'] = model_answer
|
166 |
+
pred_result[video_id] = [model_answer]
|
167 |
+
gt_result[video_id] = [gt_answer]
|
168 |
+
minigpt4_predict.append(result)
|
169 |
+
# save results every 100 videos to avoid losing results
|
170 |
+
if c%100==0:
|
171 |
+
with open(save_path, 'w') as f:
|
172 |
+
json.dump(minigpt4_predict, f)
|
173 |
+
if c >= args.end :
|
174 |
+
break
|
175 |
+
c+=1
|
176 |
+
else:
|
177 |
+
for images, texts, gt_answers, lengths,videos_ids in tqdm(eval_dataloader,desc=f"Eval {args.dataset}"):
|
178 |
+
if args.start<= c <args.end :
|
179 |
+
texts = prepare_texts(texts, conv_temp, template='', lengths=lengths) # warp the texts with conversation template
|
180 |
+
models_answers = model.generate(images, texts, max_new_tokens=args.max_new_tokens, do_sample=False, lengths=lengths,num_beams=1)
|
181 |
+
for video_id,model_answer, gt_answer,text in zip(videos_ids,models_answers, gt_answers,texts):
|
182 |
+
result = dict()
|
183 |
+
result['video_name'] = video_id
|
184 |
+
result['Q'] = text.split('\n')[-1].replace('[/INST]','')
|
185 |
+
result['A'] = gt_answer
|
186 |
+
result['pred'] = model_answer
|
187 |
+
pred_result[video_id] = [model_answer]
|
188 |
+
gt_result[video_id] = [gt_answer]
|
189 |
+
minigpt4_predict.append(result)
|
190 |
+
# save results every 100 videos to avoid losing results
|
191 |
+
if c%100==0:
|
192 |
+
with open(save_path, 'w') as f:
|
193 |
+
json.dump(minigpt4_predict, f)
|
194 |
+
if c >= args.end :
|
195 |
+
break
|
196 |
+
c+=1
|
197 |
+
|
198 |
+
with open(save_path, 'w') as f:
|
199 |
+
json.dump(minigpt4_predict, f)
|
200 |
+
print("saved results to",save_path)
|
201 |
+
# save results
|
202 |
+
# bleu_save_path = f'results/{args.name}_{args.dataset}_bleu.json'
|
203 |
+
# cider_save_path = f'results/{args.name}_{args.dataset}_cider.json'
|
204 |
+
# chatgpt_eval_save_path = f'results/{args.name}_{args.dataset}_chatgpt_eval.json'
|
205 |
+
# bleu_results=eval_bleu(minigpt4_predict)
|
206 |
+
# with open(bleu_save_path, 'w') as f:
|
207 |
+
# json.dump(bleu_results, f)
|
208 |
+
# print("bleu_results",bleu_results)
|
209 |
+
# cider_results=eval_cider(pred_result,gt_result)
|
210 |
+
# with open(cider_save_path, 'w') as f:
|
211 |
+
# json.dump(cider_results, f)
|
212 |
+
# print("mean_cider_scores:",cider_results['mean_cider_scores'])
|
213 |
+
|
214 |
+
# chatgpt_results=chat_gpt_eval(pred_result,gt_result)
|
215 |
+
|
216 |
+
# with open(chatgpt_eval_save_path, 'w') as f:
|
217 |
+
# json.dump(chatgpt_results, f)
|
218 |
+
# print("avg_chatgpt_score",chatgpt_results['avg_chatgpt_score'])
|
219 |
+
# print(chatgpt_results)
|
220 |
+
|
221 |
+
|
jobs_video/eval/choose_best_ckpt/choose_best_ckpt.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
ckpt_dir = 'ckpt_dir'
|
4 |
+
print(f'number of ckpts: {len(os.listdir(ckpt_dir))}')
|
5 |
+
for ckpt in sorted(os.listdir(ckpt_dir)):
|
6 |
+
if not ckpt.endswith('.pth'):
|
7 |
+
continue
|
8 |
+
ckpt_path = os.path.join(ckpt_dir,ckpt)
|
9 |
+
job_name="cmd_webvid_video_instruct_"+ckpt.split(".")[0]
|
10 |
+
# submit a job with this ckpt file
|
11 |
+
os.system(f'sbatch ./evalualtion_ckpt.sh {ckpt_path} {job_name}')
|
12 |
+
# print(f'sbatch ./evalualtion_ckpt.sh {ckpt_path} {job_name}')
|
13 |
+
# print(f'job {job_name} submitted')
|
14 |
+
# break
|
jobs_video/eval/choose_best_ckpt/evalualtion_ckpt.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
#SBATCH --partition=batch
|
3 |
+
#SBATCH --job-name=val%j
|
4 |
+
#SBATCH --output=val%j.out
|
5 |
+
#SBATCH --error=val%j.err
|
6 |
+
#SBATCH --time=0-10:00:00
|
7 |
+
#SBATCH --mem=100G
|
8 |
+
#SBATCH --gres=gpu:a100:1
|
9 |
+
#SBATCH --nodes=1
|
10 |
+
## run the application:
|
11 |
+
NAME=$2 # Name of the experiment
|
12 |
+
DATASET="dataset_name" # available datasets: tvqa, msrvtt, msvd, activitynet,tgif,video_chatgpt_generic,video_chatgpt_temporal,video_chatgpt_consistency
|
13 |
+
BATCH_SIZE=2 # batch size
|
14 |
+
CKPT_PATH=$1 # path to the checkpoint
|
15 |
+
cfg_path="test_configs/mistral_test_config.yaml" # path to the config file
|
16 |
+
cd ../../../
|
17 |
+
python eval_video.py --dataset $DATASET --batch_size $BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --cfg-path=$cfg_path --add_subtitles
|
jobs_video/eval/llama2_evalualtion.sh
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
#SBATCH --partition=batch
|
3 |
+
#SBATCH --job-name=llama2_best%j
|
4 |
+
#SBATCH --output=llama2_best%j.out
|
5 |
+
#SBATCH --error=llama2_best%j.err
|
6 |
+
#SBATCH --time=0-23:00:00
|
7 |
+
#SBATCH --mem=100G
|
8 |
+
#SBATCH --gres=gpu:a100:1
|
9 |
+
#SBATCH --nodes=1
|
10 |
+
## run the application:
|
11 |
+
NAME="llama2_best" # Name of the experiment
|
12 |
+
DATASET="tvqa" # available datasets: tvqa, msrvtt, msvd, activitynet,tgif ,video_chatgpt_generic,video_chatgpt_temporal,video_chatgpt_consistency
|
13 |
+
BATCH_SIZE=8
|
14 |
+
CKPT_PATH="checkpoints/video_llama_checkpoint_last.pth" # path to the checkpoint
|
15 |
+
cfg_path="test_configs/llama2_test_config.yaml" # path to the config file
|
16 |
+
# # if the number of samples are large you can specify the start and end index to evaluate on several machines
|
17 |
+
# pass the start and end index as arguments
|
18 |
+
start=$1 # start index
|
19 |
+
end=$2 # end index
|
20 |
+
# if start and end are not provided, then use the whole dataset
|
21 |
+
if [ -z "$START" ]
|
22 |
+
then
|
23 |
+
START=0
|
24 |
+
fi
|
25 |
+
if [ -z "$END" ]
|
26 |
+
then
|
27 |
+
END=10000000
|
28 |
+
fi
|
29 |
+
echo "Start: $START"
|
30 |
+
echo "End: $END"
|
31 |
+
|
32 |
+
cd ../../
|
33 |
+
# without subtitles
|
34 |
+
python eval_video.py --dataset $DATASET --batch_size $BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --cfg-path=$cfg_path --start $start --end $end
|
35 |
+
|
36 |
+
# with subtitles
|
37 |
+
# python eval_video.py --dataset $DATASET --batch_size $BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --cfg-path=$cfg_path --add_subtitles --start $start --end $end
|
jobs_video/eval/mistral_evalualtion.sh
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
#SBATCH --partition=batch
|
3 |
+
#SBATCH --mail-user=kirolos.ataallah@kaust.edu.sa
|
4 |
+
#SBATCH --mail-type=ALL
|
5 |
+
#SBATCH --job-name=mistral_best%j
|
6 |
+
#SBATCH --output=mistral_best%j.out
|
7 |
+
#SBATCH --error=mistral_best%j.err
|
8 |
+
#SBATCH --time=0-23:00:00
|
9 |
+
#SBATCH --mem=100G
|
10 |
+
#SBATCH --gres=gpu:a100:1
|
11 |
+
#SBATCH --nodes=1
|
12 |
+
## run the application:
|
13 |
+
NAME="mistral_best" # Name of the experiment
|
14 |
+
DATASET="tvqa" # available datasets: tvqa, msrvtt, msvd, activitynet,tgif,video_chatgpt_generic,video_chatgpt_temporal,video_chatgpt_consistency
|
15 |
+
BATCH_SIZE=4 # batch size for A100 by using subtiles is 2 and without subtitles is 4
|
16 |
+
CKPT_PATH="checkpoints/video_mistral_checkpoint_best.pth" # path to the checkpoint
|
17 |
+
cfg_path="test_configs/mistral_test_config.yaml" # path to the config file
|
18 |
+
# # if the number of samples are large you can specify the start and end index to evaluate on several machines
|
19 |
+
# pass the start and end index as arguments
|
20 |
+
start=$1 # start index
|
21 |
+
end=$2 # end index
|
22 |
+
# if start and end are not provided, then use the whole dataset
|
23 |
+
if [ -z "$START" ]
|
24 |
+
then
|
25 |
+
START=0
|
26 |
+
fi
|
27 |
+
if [ -z "$END" ]
|
28 |
+
then
|
29 |
+
END=10000000
|
30 |
+
fi
|
31 |
+
echo "Start: $START"
|
32 |
+
echo "End: $END"
|
33 |
+
|
34 |
+
cd ../../
|
35 |
+
# without subtitles
|
36 |
+
python eval_video.py --dataset $DATASET --batch_size $BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --cfg-path=$cfg_path --start $start --end $end
|
37 |
+
|
38 |
+
# with subtitles
|
39 |
+
# python eval_video.py --dataset $DATASET --batch_size $BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --cfg-path=$cfg_path --add_subtitles --start $start --end $end
|
jobs_video/eval/submit_job.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
import sys
|
4 |
+
|
5 |
+
start=0
|
6 |
+
end=7800
|
7 |
+
step=800
|
8 |
+
|
9 |
+
# Mistral
|
10 |
+
for i in range(start,end,step):
|
11 |
+
cmd=f'sbatch ./mistral_evalualtion.sh {i} {i+step}'
|
12 |
+
# print(cmd)
|
13 |
+
os.system(cmd)
|
14 |
+
|
15 |
+
# Llama 2
|
16 |
+
# for i in range(start,end,step):
|
17 |
+
# cmd=f'sbatch ./llama2_evalualtion.sh {i} {i+step}'
|
18 |
+
# # print(cmd)
|
19 |
+
# os.system(cmd)
|
jobs_video/train/stage_2_llama2.sh
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
#SBATCH --partition=batch
|
3 |
+
#SBATCH --job-name=test
|
4 |
+
#SBATCH --output=test.out
|
5 |
+
#SBATCH --error=test.err
|
6 |
+
#SBATCH --time=23:00:00
|
7 |
+
#SBATCH --mem=110G
|
8 |
+
#SBATCH --gres=gpu:a100:4
|
9 |
+
#SBATCH --cpus-per-task=16
|
10 |
+
## run the application:
|
11 |
+
job_name=test # Name of the experiment
|
12 |
+
cfg_path="train_configs/224_v2_llama2_video_stage_2.yaml" # path to the config file
|
13 |
+
number_of_gpus=1 # number of gpus
|
14 |
+
# cd ../../
|
15 |
+
|
16 |
+
read LOWERPORT UPPERPORT < /proc/sys/net/ipv4/ip_local_port_range
|
17 |
+
while :
|
18 |
+
do
|
19 |
+
PORT="`shuf -i $LOWERPORT-$UPPERPORT -n 1`"
|
20 |
+
ss -lpn | grep -q ":$PORT " || break
|
21 |
+
done
|
22 |
+
echo "Port is $PORT"
|
23 |
+
torchrun --master-port ${PORT} --nproc-per-node $number_of_gpus train.py --job_name ${job_name} --cfg-path ${cfg_path}
|
jobs_video/train/stage_2_mistral.sh
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
#SBATCH --partition=batch
|
3 |
+
#SBATCH --job-name=test
|
4 |
+
#SBATCH --output=test.out
|
5 |
+
#SBATCH --error=test.err
|
6 |
+
#SBATCH --time=23:00:00
|
7 |
+
#SBATCH --mem=110G
|
8 |
+
#SBATCH --gres=gpu:a100:4
|
9 |
+
#SBATCH --cpus-per-task=16
|
10 |
+
## run the application:
|
11 |
+
job_name=test # Name of the experiment
|
12 |
+
cfg_path="train_configs/224_v2_mistral_video_stage_2.yaml" # path to the config file
|
13 |
+
number_of_gpus=1 # number of gpus
|
14 |
+
# cd ../../
|
15 |
+
|
16 |
+
read LOWERPORT UPPERPORT < /proc/sys/net/ipv4/ip_local_port_range
|
17 |
+
while :
|
18 |
+
do
|
19 |
+
PORT="`shuf -i $LOWERPORT-$UPPERPORT -n 1`"
|
20 |
+
ss -lpn | grep -q ":$PORT " || break
|
21 |
+
done
|
22 |
+
echo "Port is $PORT"
|
23 |
+
torchrun --master-port ${PORT} --nproc-per-node $number_of_gpus train.py --job_name ${job_name} --cfg-path ${cfg_path}
|
jobs_video/train/stage_3_llama2.sh
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
#SBATCH --partition=batch
|
3 |
+
#SBATCH --job-name=test
|
4 |
+
#SBATCH --output=test.out
|
5 |
+
#SBATCH --error=test.err
|
6 |
+
#SBATCH --time=23:00:00
|
7 |
+
#SBATCH --mem=110G
|
8 |
+
#SBATCH --gres=gpu:a100:4
|
9 |
+
#SBATCH --cpus-per-task=16
|
10 |
+
## run the application:
|
11 |
+
job_name="test" # Name of the experiment
|
12 |
+
cfg_path="train_configs/224_v2_llama2_video_stage_3.yaml" # path to the config file
|
13 |
+
number_of_gpus=1 # number of gpus
|
14 |
+
# cd ../../
|
15 |
+
|
16 |
+
read LOWERPORT UPPERPORT < /proc/sys/net/ipv4/ip_local_port_range
|
17 |
+
while :
|
18 |
+
do
|
19 |
+
PORT="`shuf -i $LOWERPORT-$UPPERPORT -n 1`"
|
20 |
+
ss -lpn | grep -q ":$PORT " || break
|
21 |
+
done
|
22 |
+
echo "Port is $PORT"
|
23 |
+
torchrun --master-port ${PORT} --nproc-per-node $number_of_gpus train.py --job_name ${job_name} --cfg-path ${cfg_path}
|
jobs_video/train/stage_3_mistral.sh
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
#SBATCH --partition=batch
|
3 |
+
#SBATCH --job-name=test
|
4 |
+
#SBATCH --output=test.out
|
5 |
+
#SBATCH --error=test.err
|
6 |
+
#SBATCH --time=23:00:00
|
7 |
+
#SBATCH --mem=110G
|
8 |
+
#SBATCH --gres=gpu:a100:4
|
9 |
+
#SBATCH --cpus-per-task=16
|
10 |
+
## run the application:
|
11 |
+
job_name="test" # Name of the experiment
|
12 |
+
cfg_path="train_configs/224_v2_mistral_video_stage_3.yaml" # path to the config file
|
13 |
+
number_of_gpus=1 # number of gpus
|
14 |
+
# cd ../../
|
15 |
+
|
16 |
+
read LOWERPORT UPPERPORT < /proc/sys/net/ipv4/ip_local_port_range
|
17 |
+
while :
|
18 |
+
do
|
19 |
+
PORT="`shuf -i $LOWERPORT-$UPPERPORT -n 1`"
|
20 |
+
ss -lpn | grep -q ":$PORT " || break
|
21 |
+
done
|
22 |
+
echo "Port is $PORT"
|
23 |
+
torchrun --master-port ${PORT} --nproc-per-node $number_of_gpus train.py --job_name ${job_name} --cfg-path ${cfg_path}
|
minigpt4/__init__.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
|
11 |
+
from omegaconf import OmegaConf
|
12 |
+
|
13 |
+
from minigpt4.common.registry import registry
|
14 |
+
|
15 |
+
from minigpt4.datasets.builders import *
|
16 |
+
from minigpt4.models import *
|
17 |
+
from minigpt4.processors import *
|
18 |
+
from minigpt4.tasks import *
|
19 |
+
|
20 |
+
|
21 |
+
root_dir = os.path.dirname(os.path.abspath(__file__))
|
22 |
+
default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
|
23 |
+
|
24 |
+
registry.register_path("library_root", root_dir)
|
25 |
+
repo_root = os.path.join(root_dir, "..")
|
26 |
+
registry.register_path("repo_root", repo_root)
|
27 |
+
cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
|
28 |
+
registry.register_path("cache_root", cache_root)
|
29 |
+
|
30 |
+
registry.register("MAX_INT", sys.maxsize)
|
31 |
+
registry.register("SPLIT_NAMES", ["train", "val", "test"])
|
minigpt4/common/__init__.py
ADDED
File without changes
|
minigpt4/common/config.py
ADDED
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import logging
|
9 |
+
import json
|
10 |
+
from typing import Dict
|
11 |
+
|
12 |
+
from omegaconf import OmegaConf
|
13 |
+
from minigpt4.common.registry import registry
|
14 |
+
|
15 |
+
|
16 |
+
class Config:
|
17 |
+
def __init__(self, args):
|
18 |
+
self.config = {}
|
19 |
+
|
20 |
+
self.args = args
|
21 |
+
|
22 |
+
# Register the config and configuration for setup
|
23 |
+
registry.register("configuration", self)
|
24 |
+
|
25 |
+
user_config = self._build_opt_list(self.args.options)
|
26 |
+
|
27 |
+
config = OmegaConf.load(self.args.cfg_path)
|
28 |
+
|
29 |
+
runner_config = self.build_runner_config(config)
|
30 |
+
model_config = self.build_model_config(config, **user_config)
|
31 |
+
dataset_config = self.build_dataset_config(config)
|
32 |
+
|
33 |
+
# Validate the user-provided runner configuration
|
34 |
+
# model and dataset configuration are supposed to be validated by the respective classes
|
35 |
+
# [TODO] validate the model/dataset configuration
|
36 |
+
# self._validate_runner_config(runner_config)
|
37 |
+
|
38 |
+
# Override the default configuration with user options.
|
39 |
+
self.config = OmegaConf.merge(
|
40 |
+
runner_config, model_config, dataset_config, user_config
|
41 |
+
)
|
42 |
+
|
43 |
+
def _validate_runner_config(self, runner_config):
|
44 |
+
"""
|
45 |
+
This method validates the configuration, such that
|
46 |
+
1) all the user specified options are valid;
|
47 |
+
2) no type mismatches between the user specified options and the config.
|
48 |
+
"""
|
49 |
+
runner_config_validator = create_runner_config_validator()
|
50 |
+
runner_config_validator.validate(runner_config)
|
51 |
+
|
52 |
+
def _build_opt_list(self, opts):
|
53 |
+
opts_dot_list = self._convert_to_dot_list(opts)
|
54 |
+
return OmegaConf.from_dotlist(opts_dot_list)
|
55 |
+
|
56 |
+
@staticmethod
|
57 |
+
def build_model_config(config, **kwargs):
|
58 |
+
model = config.get("model", None)
|
59 |
+
assert model is not None, "Missing model configuration file."
|
60 |
+
|
61 |
+
model_cls = registry.get_model_class(model.arch)
|
62 |
+
assert model_cls is not None, f"Model '{model.arch}' has not been registered."
|
63 |
+
|
64 |
+
model_type = kwargs.get("model.model_type", None)
|
65 |
+
if not model_type:
|
66 |
+
model_type = model.get("model_type", None)
|
67 |
+
# else use the model type selected by user.
|
68 |
+
|
69 |
+
assert model_type is not None, "Missing model_type."
|
70 |
+
|
71 |
+
print("--------------")
|
72 |
+
print("model arch",model.arch)
|
73 |
+
print("model cls",model_cls)
|
74 |
+
|
75 |
+
model_config_path = model_cls.default_config_path(model_type=model_type)
|
76 |
+
|
77 |
+
model_config = OmegaConf.create()
|
78 |
+
# hierarchy override, customized config > default config
|
79 |
+
model_config = OmegaConf.merge(
|
80 |
+
model_config,
|
81 |
+
OmegaConf.load(model_config_path),
|
82 |
+
{"model": config["model"]},
|
83 |
+
)
|
84 |
+
|
85 |
+
return model_config
|
86 |
+
|
87 |
+
@staticmethod
|
88 |
+
def build_runner_config(config):
|
89 |
+
return {"run": config.run}
|
90 |
+
|
91 |
+
@staticmethod
|
92 |
+
def build_dataset_config(config):
|
93 |
+
datasets = config.get("datasets", None)
|
94 |
+
if datasets is None:
|
95 |
+
raise KeyError(
|
96 |
+
"Expecting 'datasets' as the root key for dataset configuration."
|
97 |
+
)
|
98 |
+
|
99 |
+
dataset_config = OmegaConf.create()
|
100 |
+
|
101 |
+
for dataset_name in datasets:
|
102 |
+
|
103 |
+
print("dataset name", dataset_name)
|
104 |
+
builder_cls = registry.get_builder_class(dataset_name)
|
105 |
+
|
106 |
+
dataset_config_type = datasets[dataset_name].get("type", "default")
|
107 |
+
dataset_config_path = builder_cls.default_config_path(
|
108 |
+
type=dataset_config_type
|
109 |
+
)
|
110 |
+
|
111 |
+
# hierarchy override, customized config > default config
|
112 |
+
dataset_config = OmegaConf.merge(
|
113 |
+
dataset_config,
|
114 |
+
OmegaConf.load(dataset_config_path),
|
115 |
+
{"datasets": {dataset_name: config["datasets"][dataset_name]}},
|
116 |
+
)
|
117 |
+
|
118 |
+
return dataset_config
|
119 |
+
|
120 |
+
def _convert_to_dot_list(self, opts):
|
121 |
+
if opts is None:
|
122 |
+
opts = []
|
123 |
+
|
124 |
+
if len(opts) == 0:
|
125 |
+
return opts
|
126 |
+
|
127 |
+
has_equal = opts[0].find("=") != -1
|
128 |
+
|
129 |
+
if has_equal:
|
130 |
+
return opts
|
131 |
+
|
132 |
+
return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
|
133 |
+
|
134 |
+
def get_config(self):
|
135 |
+
return self.config
|
136 |
+
|
137 |
+
@property
|
138 |
+
def run_cfg(self):
|
139 |
+
return self.config.run
|
140 |
+
|
141 |
+
@property
|
142 |
+
def datasets_cfg(self):
|
143 |
+
return self.config.datasets
|
144 |
+
|
145 |
+
@property
|
146 |
+
def model_cfg(self):
|
147 |
+
return self.config.model
|
148 |
+
|
149 |
+
def pretty_print(self):
|
150 |
+
logging.info("\n===== Running Parameters =====")
|
151 |
+
logging.info(self._convert_node_to_json(self.config.run))
|
152 |
+
|
153 |
+
logging.info("\n====== Dataset Attributes ======")
|
154 |
+
datasets = self.config.datasets
|
155 |
+
|
156 |
+
for dataset in datasets:
|
157 |
+
if dataset in self.config.datasets:
|
158 |
+
logging.info(f"\n======== {dataset} =======")
|
159 |
+
dataset_config = self.config.datasets[dataset]
|
160 |
+
logging.info(self._convert_node_to_json(dataset_config))
|
161 |
+
else:
|
162 |
+
logging.warning(f"No dataset named '{dataset}' in config. Skipping")
|
163 |
+
|
164 |
+
logging.info(f"\n====== Model Attributes ======")
|
165 |
+
logging.info(self._convert_node_to_json(self.config.model))
|
166 |
+
|
167 |
+
def _convert_node_to_json(self, node):
|
168 |
+
container = OmegaConf.to_container(node, resolve=True)
|
169 |
+
return json.dumps(container, indent=4, sort_keys=True)
|
170 |
+
|
171 |
+
def to_dict(self):
|
172 |
+
return OmegaConf.to_container(self.config)
|
173 |
+
|
174 |
+
|
175 |
+
def node_to_dict(node):
|
176 |
+
return OmegaConf.to_container(node)
|
177 |
+
|
178 |
+
|
179 |
+
class ConfigValidator:
|
180 |
+
"""
|
181 |
+
This is a preliminary implementation to centralize and validate the configuration.
|
182 |
+
May be altered in the future.
|
183 |
+
|
184 |
+
A helper class to validate configurations from yaml file.
|
185 |
+
|
186 |
+
This serves the following purposes:
|
187 |
+
1. Ensure all the options in the yaml are defined, raise error if not.
|
188 |
+
2. when type mismatches are found, the validator will raise an error.
|
189 |
+
3. a central place to store and display helpful messages for supported configurations.
|
190 |
+
|
191 |
+
"""
|
192 |
+
|
193 |
+
class _Argument:
|
194 |
+
def __init__(self, name, choices=None, type=None, help=None):
|
195 |
+
self.name = name
|
196 |
+
self.val = None
|
197 |
+
self.choices = choices
|
198 |
+
self.type = type
|
199 |
+
self.help = help
|
200 |
+
|
201 |
+
def __str__(self):
|
202 |
+
s = f"{self.name}={self.val}"
|
203 |
+
if self.type is not None:
|
204 |
+
s += f", ({self.type})"
|
205 |
+
if self.choices is not None:
|
206 |
+
s += f", choices: {self.choices}"
|
207 |
+
if self.help is not None:
|
208 |
+
s += f", ({self.help})"
|
209 |
+
return s
|
210 |
+
|
211 |
+
def __init__(self, description):
|
212 |
+
self.description = description
|
213 |
+
|
214 |
+
self.arguments = dict()
|
215 |
+
|
216 |
+
self.parsed_args = None
|
217 |
+
|
218 |
+
def __getitem__(self, key):
|
219 |
+
assert self.parsed_args is not None, "No arguments parsed yet."
|
220 |
+
|
221 |
+
return self.parsed_args[key]
|
222 |
+
|
223 |
+
def __str__(self) -> str:
|
224 |
+
return self.format_help()
|
225 |
+
|
226 |
+
def add_argument(self, *args, **kwargs):
|
227 |
+
"""
|
228 |
+
Assume the first argument is the name of the argument.
|
229 |
+
"""
|
230 |
+
self.arguments[args[0]] = self._Argument(*args, **kwargs)
|
231 |
+
|
232 |
+
def validate(self, config=None):
|
233 |
+
"""
|
234 |
+
Convert yaml config (dict-like) to list, required by argparse.
|
235 |
+
"""
|
236 |
+
for k, v in config.items():
|
237 |
+
assert (
|
238 |
+
k in self.arguments
|
239 |
+
), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""
|
240 |
+
|
241 |
+
if self.arguments[k].type is not None:
|
242 |
+
try:
|
243 |
+
self.arguments[k].val = self.arguments[k].type(v)
|
244 |
+
except ValueError:
|
245 |
+
raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
|
246 |
+
|
247 |
+
if self.arguments[k].choices is not None:
|
248 |
+
assert (
|
249 |
+
v in self.arguments[k].choices
|
250 |
+
), f"""{k} must be one of {self.arguments[k].choices}."""
|
251 |
+
|
252 |
+
return config
|
253 |
+
|
254 |
+
def format_arguments(self):
|
255 |
+
return str([f"{k}" for k in sorted(self.arguments.keys())])
|
256 |
+
|
257 |
+
def format_help(self):
|
258 |
+
# description + key-value pair string for each argument
|
259 |
+
help_msg = str(self.description)
|
260 |
+
return help_msg + ", available arguments: " + self.format_arguments()
|
261 |
+
|
262 |
+
def print_help(self):
|
263 |
+
# display help message
|
264 |
+
print(self.format_help())
|
265 |
+
|
266 |
+
|
267 |
+
def create_runner_config_validator():
|
268 |
+
validator = ConfigValidator(description="Runner configurations")
|
269 |
+
|
270 |
+
validator.add_argument(
|
271 |
+
"runner",
|
272 |
+
type=str,
|
273 |
+
choices=["runner_base", "runner_iter"],
|
274 |
+
help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
|
275 |
+
runner runs based on iters. Default: runner_base""",
|
276 |
+
)
|
277 |
+
# add argumetns for training dataset ratios
|
278 |
+
validator.add_argument(
|
279 |
+
"train_dataset_ratios",
|
280 |
+
type=Dict[str, float],
|
281 |
+
help="""Ratios of training dataset. This is used in iteration-based runner.
|
282 |
+
Do not support for epoch-based runner because how to define an epoch becomes tricky.
|
283 |
+
Default: None""",
|
284 |
+
)
|
285 |
+
validator.add_argument(
|
286 |
+
"max_iters",
|
287 |
+
type=float,
|
288 |
+
help="Maximum number of iterations to run.",
|
289 |
+
)
|
290 |
+
validator.add_argument(
|
291 |
+
"max_epoch",
|
292 |
+
type=int,
|
293 |
+
help="Maximum number of epochs to run.",
|
294 |
+
)
|
295 |
+
# add arguments for iters_per_inner_epoch
|
296 |
+
validator.add_argument(
|
297 |
+
"iters_per_inner_epoch",
|
298 |
+
type=float,
|
299 |
+
help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
|
300 |
+
)
|
301 |
+
lr_scheds_choices = registry.list_lr_schedulers()
|
302 |
+
validator.add_argument(
|
303 |
+
"lr_sched",
|
304 |
+
type=str,
|
305 |
+
choices=lr_scheds_choices,
|
306 |
+
help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
|
307 |
+
)
|
308 |
+
task_choices = registry.list_tasks()
|
309 |
+
validator.add_argument(
|
310 |
+
"task",
|
311 |
+
type=str,
|
312 |
+
choices=task_choices,
|
313 |
+
help="Task to use, from {}".format(task_choices),
|
314 |
+
)
|
315 |
+
# add arguments for init_lr
|
316 |
+
validator.add_argument(
|
317 |
+
"init_lr",
|
318 |
+
type=float,
|
319 |
+
help="Initial learning rate. This will be the learning rate after warmup and before decay.",
|
320 |
+
)
|
321 |
+
# add arguments for min_lr
|
322 |
+
validator.add_argument(
|
323 |
+
"min_lr",
|
324 |
+
type=float,
|
325 |
+
help="Minimum learning rate (after decay).",
|
326 |
+
)
|
327 |
+
# add arguments for warmup_lr
|
328 |
+
validator.add_argument(
|
329 |
+
"warmup_lr",
|
330 |
+
type=float,
|
331 |
+
help="Starting learning rate for warmup.",
|
332 |
+
)
|
333 |
+
# add arguments for learning rate decay rate
|
334 |
+
validator.add_argument(
|
335 |
+
"lr_decay_rate",
|
336 |
+
type=float,
|
337 |
+
help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
|
338 |
+
)
|
339 |
+
# add arguments for weight decay
|
340 |
+
validator.add_argument(
|
341 |
+
"weight_decay",
|
342 |
+
type=float,
|
343 |
+
help="Weight decay rate.",
|
344 |
+
)
|
345 |
+
# add arguments for training batch size
|
346 |
+
validator.add_argument(
|
347 |
+
"batch_size_train",
|
348 |
+
type=int,
|
349 |
+
help="Training batch size.",
|
350 |
+
)
|
351 |
+
# add arguments for evaluation batch size
|
352 |
+
validator.add_argument(
|
353 |
+
"batch_size_eval",
|
354 |
+
type=int,
|
355 |
+
help="Evaluation batch size, including validation and testing.",
|
356 |
+
)
|
357 |
+
# add arguments for number of workers for data loading
|
358 |
+
validator.add_argument(
|
359 |
+
"num_workers",
|
360 |
+
help="Number of workers for data loading.",
|
361 |
+
)
|
362 |
+
# add arguments for warm up steps
|
363 |
+
validator.add_argument(
|
364 |
+
"warmup_steps",
|
365 |
+
type=int,
|
366 |
+
help="Number of warmup steps. Required if a warmup schedule is used.",
|
367 |
+
)
|
368 |
+
# add arguments for random seed
|
369 |
+
validator.add_argument(
|
370 |
+
"seed",
|
371 |
+
type=int,
|
372 |
+
help="Random seed.",
|
373 |
+
)
|
374 |
+
# add arguments for output directory
|
375 |
+
validator.add_argument(
|
376 |
+
"output_dir",
|
377 |
+
type=str,
|
378 |
+
help="Output directory to save checkpoints and logs.",
|
379 |
+
)
|
380 |
+
# add arguments for whether only use evaluation
|
381 |
+
validator.add_argument(
|
382 |
+
"evaluate",
|
383 |
+
help="Whether to only evaluate the model. If true, training will not be performed.",
|
384 |
+
)
|
385 |
+
# add arguments for splits used for training, e.g. ["train", "val"]
|
386 |
+
validator.add_argument(
|
387 |
+
"train_splits",
|
388 |
+
type=list,
|
389 |
+
help="Splits to use for training.",
|
390 |
+
)
|
391 |
+
# add arguments for splits used for validation, e.g. ["val"]
|
392 |
+
validator.add_argument(
|
393 |
+
"valid_splits",
|
394 |
+
type=list,
|
395 |
+
help="Splits to use for validation. If not provided, will skip the validation.",
|
396 |
+
)
|
397 |
+
# add arguments for splits used for testing, e.g. ["test"]
|
398 |
+
validator.add_argument(
|
399 |
+
"test_splits",
|
400 |
+
type=list,
|
401 |
+
help="Splits to use for testing. If not provided, will skip the testing.",
|
402 |
+
)
|
403 |
+
# add arguments for accumulating gradient for iterations
|
404 |
+
validator.add_argument(
|
405 |
+
"accum_grad_iters",
|
406 |
+
type=int,
|
407 |
+
help="Number of iterations to accumulate gradient for.",
|
408 |
+
)
|
409 |
+
|
410 |
+
# ====== distributed training ======
|
411 |
+
validator.add_argument(
|
412 |
+
"device",
|
413 |
+
type=str,
|
414 |
+
choices=["cpu", "cuda"],
|
415 |
+
help="Device to use. Support 'cuda' or 'cpu' as for now.",
|
416 |
+
)
|
417 |
+
validator.add_argument(
|
418 |
+
"world_size",
|
419 |
+
type=int,
|
420 |
+
help="Number of processes participating in the job.",
|
421 |
+
)
|
422 |
+
validator.add_argument("dist_url", type=str)
|
423 |
+
validator.add_argument("distributed", type=bool)
|
424 |
+
# add arguments to opt using distributed sampler during evaluation or not
|
425 |
+
validator.add_argument(
|
426 |
+
"use_dist_eval_sampler",
|
427 |
+
type=bool,
|
428 |
+
help="Whether to use distributed sampler during evaluation or not.",
|
429 |
+
)
|
430 |
+
|
431 |
+
# ====== task specific ======
|
432 |
+
# generation task specific arguments
|
433 |
+
# add arguments for maximal length of text output
|
434 |
+
validator.add_argument(
|
435 |
+
"max_len",
|
436 |
+
type=int,
|
437 |
+
help="Maximal length of text output.",
|
438 |
+
)
|
439 |
+
# add arguments for minimal length of text output
|
440 |
+
validator.add_argument(
|
441 |
+
"min_len",
|
442 |
+
type=int,
|
443 |
+
help="Minimal length of text output.",
|
444 |
+
)
|
445 |
+
# add arguments number of beams
|
446 |
+
validator.add_argument(
|
447 |
+
"num_beams",
|
448 |
+
type=int,
|
449 |
+
help="Number of beams used for beam search.",
|
450 |
+
)
|
451 |
+
|
452 |
+
# vqa task specific arguments
|
453 |
+
# add arguments for number of answer candidates
|
454 |
+
validator.add_argument(
|
455 |
+
"num_ans_candidates",
|
456 |
+
type=int,
|
457 |
+
help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
|
458 |
+
)
|
459 |
+
# add arguments for inference method
|
460 |
+
validator.add_argument(
|
461 |
+
"inference_method",
|
462 |
+
type=str,
|
463 |
+
choices=["genearte", "rank"],
|
464 |
+
help="""Inference method to use for question answering. If rank, requires a answer list.""",
|
465 |
+
)
|
466 |
+
|
467 |
+
# ====== model specific ======
|
468 |
+
validator.add_argument(
|
469 |
+
"k_test",
|
470 |
+
type=int,
|
471 |
+
help="Number of top k most similar samples from ITC/VTC selection to be tested.",
|
472 |
+
)
|
473 |
+
|
474 |
+
return validator
|
minigpt4/common/dist_utils.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import datetime
|
9 |
+
import functools
|
10 |
+
import os
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.distributed as dist
|
14 |
+
import timm.models.hub as timm_hub
|
15 |
+
|
16 |
+
|
17 |
+
def setup_for_distributed(is_master):
|
18 |
+
"""
|
19 |
+
This function disables printing when not in master process
|
20 |
+
"""
|
21 |
+
import builtins as __builtin__
|
22 |
+
|
23 |
+
builtin_print = __builtin__.print
|
24 |
+
|
25 |
+
def print(*args, **kwargs):
|
26 |
+
force = kwargs.pop("force", False)
|
27 |
+
if is_master or force:
|
28 |
+
builtin_print(*args, **kwargs)
|
29 |
+
|
30 |
+
__builtin__.print = print
|
31 |
+
|
32 |
+
|
33 |
+
def is_dist_avail_and_initialized():
|
34 |
+
if not dist.is_available():
|
35 |
+
return False
|
36 |
+
if not dist.is_initialized():
|
37 |
+
return False
|
38 |
+
return True
|
39 |
+
|
40 |
+
|
41 |
+
def get_world_size():
|
42 |
+
if not is_dist_avail_and_initialized():
|
43 |
+
return 1
|
44 |
+
return dist.get_world_size()
|
45 |
+
|
46 |
+
|
47 |
+
def get_rank():
|
48 |
+
if not is_dist_avail_and_initialized():
|
49 |
+
return 0
|
50 |
+
return dist.get_rank()
|
51 |
+
|
52 |
+
|
53 |
+
def is_main_process():
|
54 |
+
return get_rank() == 0
|
55 |
+
|
56 |
+
|
57 |
+
def init_distributed_mode(args):
|
58 |
+
if args.distributed is False:
|
59 |
+
print("Not using distributed mode")
|
60 |
+
args.rank = 0
|
61 |
+
return
|
62 |
+
|
63 |
+
if 'LOCAL_RANK' not in os.environ:
|
64 |
+
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
65 |
+
|
66 |
+
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
67 |
+
args.rank = int(os.environ["RANK"])
|
68 |
+
args.world_size = int(os.environ["WORLD_SIZE"])
|
69 |
+
args.gpu = int(os.environ["LOCAL_RANK"])
|
70 |
+
elif "SLURM_PROCID" in os.environ:
|
71 |
+
args.rank = int(os.environ["SLURM_PROCID"])
|
72 |
+
args.gpu = args.rank % torch.cuda.device_count()
|
73 |
+
else:
|
74 |
+
print("Not using distributed mode")
|
75 |
+
args.distributed = False
|
76 |
+
args.rank = 0
|
77 |
+
return
|
78 |
+
|
79 |
+
args.distributed = True
|
80 |
+
|
81 |
+
torch.cuda.set_device(args.gpu)
|
82 |
+
args.dist_backend = "nccl"
|
83 |
+
print(
|
84 |
+
"| distributed init (rank {}, world {}): {}".format(
|
85 |
+
args.rank, args.world_size, args.dist_url
|
86 |
+
),
|
87 |
+
flush=True,
|
88 |
+
)
|
89 |
+
torch.distributed.init_process_group(
|
90 |
+
backend=args.dist_backend,
|
91 |
+
init_method=args.dist_url,
|
92 |
+
world_size=args.world_size,
|
93 |
+
rank=args.rank,
|
94 |
+
timeout=datetime.timedelta(
|
95 |
+
days=365
|
96 |
+
), # allow auto-downloading and de-compressing
|
97 |
+
)
|
98 |
+
torch.distributed.barrier()
|
99 |
+
setup_for_distributed(args.rank == 0)
|
100 |
+
|
101 |
+
|
102 |
+
def get_dist_info():
|
103 |
+
if torch.__version__ < "1.0":
|
104 |
+
initialized = dist._initialized
|
105 |
+
else:
|
106 |
+
initialized = dist.is_initialized()
|
107 |
+
if initialized:
|
108 |
+
rank = dist.get_rank()
|
109 |
+
world_size = dist.get_world_size()
|
110 |
+
else: # non-distributed training
|
111 |
+
rank = 0
|
112 |
+
world_size = 1
|
113 |
+
return rank, world_size
|
114 |
+
|
115 |
+
|
116 |
+
def main_process(func):
|
117 |
+
@functools.wraps(func)
|
118 |
+
def wrapper(*args, **kwargs):
|
119 |
+
rank, _ = get_dist_info()
|
120 |
+
if rank == 0:
|
121 |
+
return func(*args, **kwargs)
|
122 |
+
|
123 |
+
return wrapper
|
124 |
+
|
125 |
+
|
126 |
+
def download_cached_file(url, check_hash=True, progress=False):
|
127 |
+
"""
|
128 |
+
Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
|
129 |
+
If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
|
130 |
+
"""
|
131 |
+
|
132 |
+
def get_cached_file_path():
|
133 |
+
# a hack to sync the file path across processes
|
134 |
+
parts = torch.hub.urlparse(url)
|
135 |
+
filename = os.path.basename(parts.path)
|
136 |
+
cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
|
137 |
+
|
138 |
+
return cached_file
|
139 |
+
|
140 |
+
if is_main_process():
|
141 |
+
timm_hub.download_cached_file(url, check_hash, progress)
|
142 |
+
|
143 |
+
if is_dist_avail_and_initialized():
|
144 |
+
dist.barrier()
|
145 |
+
|
146 |
+
return get_cached_file_path()
|
minigpt4/common/eval_utils.py
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import numpy as np
|
3 |
+
from nltk.translate.bleu_score import sentence_bleu
|
4 |
+
import sys
|
5 |
+
sys.path.append('/home/ataallka/minigpt_video/minigpt_multi_img')
|
6 |
+
from minigpt4.common.registry import registry
|
7 |
+
from minigpt4.common.config import Config
|
8 |
+
|
9 |
+
# imports modules for registration
|
10 |
+
from minigpt4.datasets.builders import *
|
11 |
+
from minigpt4.models import *
|
12 |
+
from minigpt4.processors import *
|
13 |
+
# from minigpt4.runners import *
|
14 |
+
from minigpt4.tasks import *
|
15 |
+
from pycocoevalcap.cider.cider import Cider
|
16 |
+
import os
|
17 |
+
import openai
|
18 |
+
from tqdm import tqdm
|
19 |
+
import json
|
20 |
+
import ast
|
21 |
+
import time
|
22 |
+
|
23 |
+
def eval_parser():
|
24 |
+
parser = argparse.ArgumentParser(description="Demo")
|
25 |
+
parser.add_argument("--cfg-path", help="path to configuration file.",default="test_configs/llama2_test_config.yaml")
|
26 |
+
parser.add_argument("--ckpt", type=str,default='checkpoints/video_llama_checkpoint_last.pth', help="path to checkpoint")
|
27 |
+
parser.add_argument("--eval_opt", type=str, default='all', help="path to configuration file.")
|
28 |
+
parser.add_argument("--max_new_tokens", type=int, default=512, help="max number of generated tokens")
|
29 |
+
parser.add_argument("--lora_r", type=int, default=64, help="lora rank of the model")
|
30 |
+
parser.add_argument("--lora_alpha", type=int, default=16, help="lora alpha")
|
31 |
+
parser.add_argument(
|
32 |
+
"--options",
|
33 |
+
nargs="+",
|
34 |
+
help="override some settings in the used config, the key-value pair "
|
35 |
+
"in xxx=yyy format will be merged into config file (deprecate), "
|
36 |
+
"change to --cfg-options instead.",
|
37 |
+
)
|
38 |
+
return parser
|
39 |
+
|
40 |
+
|
41 |
+
def prepare_texts(texts, conv_temp, template='<Img><ImageHere></Img>', lengths=None):
|
42 |
+
convs = [conv_temp.copy() for _ in range(len(texts))]
|
43 |
+
if lengths is None:
|
44 |
+
[conv.append_message(conv.roles[0], '{} {}'.format(template, text)) for conv, text in zip(convs, texts)]
|
45 |
+
else:
|
46 |
+
templates = [template * length for length in lengths]
|
47 |
+
[conv.append_message(conv.roles[0], '{} {}'.format(template, text)) for template, conv, text in zip(templates, convs, texts)]
|
48 |
+
[conv.append_message(conv.roles[1], None) for conv in convs]
|
49 |
+
texts = [conv.get_prompt() for conv in convs]
|
50 |
+
return texts
|
51 |
+
|
52 |
+
|
53 |
+
def init_model(args):
|
54 |
+
print('Initialization Model')
|
55 |
+
cfg = Config(args)
|
56 |
+
cfg.model_cfg.ckpt = args.ckpt
|
57 |
+
cfg.model_cfg.lora_r = args.lora_r
|
58 |
+
cfg.model_cfg.lora_alpha = args.lora_alpha
|
59 |
+
|
60 |
+
model_config = cfg.model_cfg
|
61 |
+
model_config.low_resource = True
|
62 |
+
model_cls = registry.get_model_class(model_config.arch)
|
63 |
+
model = model_cls.from_config(model_config).to('cuda:0')
|
64 |
+
|
65 |
+
# import pudb; pudb.set_trace()
|
66 |
+
key = list(cfg.datasets_cfg.keys())[0]
|
67 |
+
vis_processor_cfg = cfg.datasets_cfg.get(key).vis_processor.train
|
68 |
+
print(vis_processor_cfg)
|
69 |
+
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
|
70 |
+
print('Initialization Finished')
|
71 |
+
return model, vis_processor
|
72 |
+
|
73 |
+
def computeIoU(bbox1, bbox2):
|
74 |
+
x1, y1, x2, y2 = bbox1
|
75 |
+
x3, y3, x4, y4 = bbox2
|
76 |
+
intersection_x1 = max(x1, x3)
|
77 |
+
intersection_y1 = max(y1, y3)
|
78 |
+
intersection_x2 = min(x2, x4)
|
79 |
+
intersection_y2 = min(y2, y4)
|
80 |
+
intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
|
81 |
+
bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
|
82 |
+
bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
|
83 |
+
union_area = bbox1_area + bbox2_area - intersection_area
|
84 |
+
iou = intersection_area / union_area
|
85 |
+
return iou
|
86 |
+
|
87 |
+
def eval_bleu(results):
|
88 |
+
bleus1,bleus2,bleus3,bleus4 = [],[],[],[]
|
89 |
+
for result in tqdm (results,desc="bleu_eval"):
|
90 |
+
gt = result['gt']
|
91 |
+
pred = result['pred']
|
92 |
+
bleus1.append(sentence_bleu([gt.split()], pred.split(), weights=(1,0,0,0)))
|
93 |
+
bleus2.append(sentence_bleu([gt.split()], pred.split(), weights=(0.5,0.5,0,0)))
|
94 |
+
bleus3.append(sentence_bleu([gt.split()], pred.split(), weights=(0.33,0.33,0.33,0)))
|
95 |
+
bleus4.append(sentence_bleu([gt.split()], pred.split()))
|
96 |
+
# print(np.mean(bleus1),np.mean(bleus2),np.mean(bleus3),np.mean(bleus4),flush=True)
|
97 |
+
return {'bleu1':np.mean(bleus1),'bleu2':np.mean(bleus2),'bleu3':np.mean(bleus3),'bleu4':np.mean(bleus4)}
|
98 |
+
|
99 |
+
# Create a Cider object
|
100 |
+
cider_scorer = Cider()
|
101 |
+
def eval_cider(pred_result,gt_result):
|
102 |
+
# Compute CIDEr scores
|
103 |
+
mean_cider_scores, cider_scores = cider_scorer.compute_score(gt_result, pred_result)
|
104 |
+
cider_scores_dict={}
|
105 |
+
for score,pred_vid_id,gt_vid_id in tqdm(zip(cider_scores.tolist(),pred_result,gt_result),desc="cider_eval") :
|
106 |
+
assert pred_vid_id==gt_vid_id
|
107 |
+
cider_scores_dict[pred_vid_id] = score
|
108 |
+
return {'mean_cider_scores':mean_cider_scores,'cider_scores':cider_scores_dict}
|
109 |
+
|
110 |
+
|
111 |
+
openai.api_key_path = "/home/ataallka/chatgpt_api.txt"
|
112 |
+
|
113 |
+
|
114 |
+
def chat_gpt_eval(results,output_path):
|
115 |
+
trial=0
|
116 |
+
gpt_results=[]
|
117 |
+
avg_chatgpt_score=0
|
118 |
+
existed_files={}
|
119 |
+
# read previous results from output path
|
120 |
+
for file in os.listdir(output_path):
|
121 |
+
if file.endswith(".json"):
|
122 |
+
with open(f'{output_path}/{file}') as json_file:
|
123 |
+
data = json.load(json_file)
|
124 |
+
gpt_results.append(data[0])
|
125 |
+
avg_chatgpt_score+=float(data[0]['chatgpt_score'])
|
126 |
+
existed_files[data[0]['video_name']]=True
|
127 |
+
length_output_path=len(os.listdir(output_path))
|
128 |
+
while len (results)!= length_output_path:
|
129 |
+
for res in tqdm(results,desc="chatgpt_eval"):
|
130 |
+
if existed_files.get(res['video_name'],False):
|
131 |
+
continue
|
132 |
+
video_name=res['video_name']
|
133 |
+
sentence_1=res['A']
|
134 |
+
sentence_2=res['pred']
|
135 |
+
try:
|
136 |
+
# prompt=f"given these 2 sentences the first one is the ground truth text and the second sentence is the generated text ,give me a score from 0 to 1 to evaluate how much they are similar to each other, and have the same context and related to each other to evaluate the quality of this generated text.the output should be only the score float number without any additional information\nfirst sentence: {sentence_1}\nsecond sentence: {sentence_2}\nscore:"
|
137 |
+
prompt=f"given these 2 sentences the first one is the ground truth descrption of a video and the second sentence is the generated text from a video summarization model,give it a score from 0 to 5 to evaluate the model summarization performance.the output should be only the score number without any additional information\nfirst sentence: {sentence_1}\nsecond sentence: {sentence_2}\nscore:"
|
138 |
+
response = openai.ChatCompletion.create(
|
139 |
+
model="gpt-3.5-turbo",
|
140 |
+
messages=[
|
141 |
+
{
|
142 |
+
"role": "user",
|
143 |
+
"content": prompt
|
144 |
+
}],
|
145 |
+
)
|
146 |
+
res['chatgpt_score']=response.choices[0].message['content']
|
147 |
+
out={'video_name':video_name,'chatgpt_score':response.choices[0].message['content']}
|
148 |
+
gpt_results.append(out)
|
149 |
+
# save each video result in a json file
|
150 |
+
with open(f'{output_path}/{video_name}.json', 'w') as f:
|
151 |
+
json.dump([out], f)
|
152 |
+
avg_chatgpt_score+=float(response.choices[0].message['content'])
|
153 |
+
except Exception as e:
|
154 |
+
print("chat gpt error",e)
|
155 |
+
print ("Finished chat gpt evaluation in trial",trial)
|
156 |
+
trial+=1
|
157 |
+
length_output_path=len(os.listdir(output_path))
|
158 |
+
return results,avg_chatgpt_score/len(results)
|
159 |
+
def GPT4_answer(question, answer,pred):
|
160 |
+
try:
|
161 |
+
# Compute the correctness score
|
162 |
+
completion = openai.ChatCompletion.create(
|
163 |
+
# model="gpt-3.5-turbo",
|
164 |
+
model='gpt-4',
|
165 |
+
messages=[
|
166 |
+
{
|
167 |
+
"role": "system",
|
168 |
+
"content":
|
169 |
+
"You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. "
|
170 |
+
"Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:"
|
171 |
+
"------"
|
172 |
+
"##INSTRUCTIONS: "
|
173 |
+
"- Focus on the meaningful match between the predicted answer and the correct answer.\n"
|
174 |
+
"- Consider synonyms or paraphrases as valid matches.\n"
|
175 |
+
"- Evaluate the correctness of the prediction compared to the answer."
|
176 |
+
},
|
177 |
+
{
|
178 |
+
"role": "user",
|
179 |
+
"content":
|
180 |
+
"Please evaluate the following video-based question-answer pair:\n\n"
|
181 |
+
f"Question: {question}\n"
|
182 |
+
f"Correct Answer: {answer}\n"
|
183 |
+
f"Predicted Answer: {pred}\n\n"
|
184 |
+
"Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. "
|
185 |
+
"Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING."
|
186 |
+
"DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
|
187 |
+
"For example, your response should look like this: {'pred': 'yes', 'score': 4.8}."
|
188 |
+
}
|
189 |
+
]
|
190 |
+
)
|
191 |
+
# Convert response to a Python dictionary.
|
192 |
+
response_message = completion["choices"][0]["message"]["content"]
|
193 |
+
response_dict = ast.literal_eval(response_message)
|
194 |
+
return response_dict
|
195 |
+
except Exception as e:
|
196 |
+
print(f"Error : {e}")
|
197 |
+
return None
|
198 |
+
def GPT4_evaluation(val_result):
|
199 |
+
scores=[]
|
200 |
+
yes_count=0
|
201 |
+
no_count=0
|
202 |
+
for res in val_result:
|
203 |
+
gpt_response=GPT4_answer(res['Q'],res['A'],res['pred'])
|
204 |
+
if gpt_response is None:
|
205 |
+
continue
|
206 |
+
try:
|
207 |
+
scores.append(float(gpt_response['score']))
|
208 |
+
if 'yes' in gpt_response['pred'].lower():
|
209 |
+
yes_count+=1
|
210 |
+
elif 'no' in gpt_response['pred'].lower():
|
211 |
+
no_count+=1
|
212 |
+
except:
|
213 |
+
continue
|
214 |
+
avg_score=sum(scores)/len(scores)
|
215 |
+
accuracy=(yes_count/(yes_count+no_count))*100
|
216 |
+
print(f"chatgpt score: {avg_score} accuracy: {accuracy}")
|
217 |
+
return avg_score,accuracy
|
218 |
+
|
219 |
+
# with open('results/ckpt_15_res89_res32_Video_validation_Dataset_subtitles.json','r') as f:
|
220 |
+
# results = json.load(f)
|
221 |
+
# t1=time.time()
|
222 |
+
# avg_score,accuracy=GPT4_evaluation(results)
|
223 |
+
# print(f"chatgpt score: {avg_score} accuracy: {accuracy}")
|
224 |
+
# print(f"Time taken: {time.time()-t1}")
|
minigpt4/common/gradcam.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from matplotlib import pyplot as plt
|
3 |
+
from scipy.ndimage import filters
|
4 |
+
from skimage import transform as skimage_transform
|
5 |
+
|
6 |
+
|
7 |
+
def getAttMap(img, attMap, blur=True, overlap=True):
|
8 |
+
attMap -= attMap.min()
|
9 |
+
if attMap.max() > 0:
|
10 |
+
attMap /= attMap.max()
|
11 |
+
attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
|
12 |
+
if blur:
|
13 |
+
attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
|
14 |
+
attMap -= attMap.min()
|
15 |
+
attMap /= attMap.max()
|
16 |
+
cmap = plt.get_cmap("jet")
|
17 |
+
attMapV = cmap(attMap)
|
18 |
+
attMapV = np.delete(attMapV, 3, 2)
|
19 |
+
if overlap:
|
20 |
+
attMap = (
|
21 |
+
1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
|
22 |
+
+ (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
|
23 |
+
)
|
24 |
+
return attMap
|
minigpt4/common/logger.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import datetime
|
9 |
+
import logging
|
10 |
+
import time
|
11 |
+
from collections import defaultdict, deque
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.distributed as dist
|
15 |
+
|
16 |
+
from minigpt4.common import dist_utils
|
17 |
+
|
18 |
+
|
19 |
+
class SmoothedValue(object):
|
20 |
+
"""Track a series of values and provide access to smoothed values over a
|
21 |
+
window or the global series average.
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self, window_size=20, fmt=None):
|
25 |
+
if fmt is None:
|
26 |
+
fmt = "{median:.4f} ({global_avg:.4f})"
|
27 |
+
self.deque = deque(maxlen=window_size)
|
28 |
+
self.total = 0.0
|
29 |
+
self.count = 0
|
30 |
+
self.fmt = fmt
|
31 |
+
|
32 |
+
def update(self, value, n=1):
|
33 |
+
self.deque.append(value)
|
34 |
+
self.count += n
|
35 |
+
self.total += value * n
|
36 |
+
|
37 |
+
def synchronize_between_processes(self):
|
38 |
+
"""
|
39 |
+
Warning: does not synchronize the deque!
|
40 |
+
"""
|
41 |
+
if not dist_utils.is_dist_avail_and_initialized():
|
42 |
+
return
|
43 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
|
44 |
+
dist.barrier()
|
45 |
+
dist.all_reduce(t)
|
46 |
+
t = t.tolist()
|
47 |
+
self.count = int(t[0])
|
48 |
+
self.total = t[1]
|
49 |
+
|
50 |
+
@property
|
51 |
+
def median(self):
|
52 |
+
d = torch.tensor(list(self.deque))
|
53 |
+
return d.median().item()
|
54 |
+
|
55 |
+
@property
|
56 |
+
def avg(self):
|
57 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
58 |
+
return d.mean().item()
|
59 |
+
|
60 |
+
@property
|
61 |
+
def global_avg(self):
|
62 |
+
return self.total / self.count
|
63 |
+
|
64 |
+
@property
|
65 |
+
def max(self):
|
66 |
+
return max(self.deque)
|
67 |
+
|
68 |
+
@property
|
69 |
+
def value(self):
|
70 |
+
return self.deque[-1]
|
71 |
+
|
72 |
+
def __str__(self):
|
73 |
+
return self.fmt.format(
|
74 |
+
median=self.median,
|
75 |
+
avg=self.avg,
|
76 |
+
global_avg=self.global_avg,
|
77 |
+
max=self.max,
|
78 |
+
value=self.value,
|
79 |
+
)
|
80 |
+
|
81 |
+
|
82 |
+
class MetricLogger(object):
|
83 |
+
def __init__(self, delimiter="\t"):
|
84 |
+
self.meters = defaultdict(SmoothedValue)
|
85 |
+
self.delimiter = delimiter
|
86 |
+
|
87 |
+
def update(self, **kwargs):
|
88 |
+
for k, v in kwargs.items():
|
89 |
+
if isinstance(v, torch.Tensor):
|
90 |
+
v = v.item()
|
91 |
+
assert isinstance(v, (float, int))
|
92 |
+
self.meters[k].update(v)
|
93 |
+
|
94 |
+
def __getattr__(self, attr):
|
95 |
+
if attr in self.meters:
|
96 |
+
return self.meters[attr]
|
97 |
+
if attr in self.__dict__:
|
98 |
+
return self.__dict__[attr]
|
99 |
+
raise AttributeError(
|
100 |
+
"'{}' object has no attribute '{}'".format(type(self).__name__, attr)
|
101 |
+
)
|
102 |
+
|
103 |
+
def __str__(self):
|
104 |
+
loss_str = []
|
105 |
+
for name, meter in self.meters.items():
|
106 |
+
loss_str.append("{}: {}".format(name, str(meter)))
|
107 |
+
return self.delimiter.join(loss_str)
|
108 |
+
|
109 |
+
def global_avg(self):
|
110 |
+
loss_str = []
|
111 |
+
for name, meter in self.meters.items():
|
112 |
+
loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
|
113 |
+
return self.delimiter.join(loss_str)
|
114 |
+
|
115 |
+
def synchronize_between_processes(self):
|
116 |
+
for meter in self.meters.values():
|
117 |
+
meter.synchronize_between_processes()
|
118 |
+
|
119 |
+
def add_meter(self, name, meter):
|
120 |
+
self.meters[name] = meter
|
121 |
+
|
122 |
+
def log_every(self, iterable, print_freq, header=None):
|
123 |
+
i = 0
|
124 |
+
if not header:
|
125 |
+
header = ""
|
126 |
+
start_time = time.time()
|
127 |
+
end = time.time()
|
128 |
+
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
129 |
+
data_time = SmoothedValue(fmt="{avg:.4f}")
|
130 |
+
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
131 |
+
log_msg = [
|
132 |
+
header,
|
133 |
+
"[{0" + space_fmt + "}/{1}]",
|
134 |
+
"eta: {eta}",
|
135 |
+
"{meters}",
|
136 |
+
"time: {time}",
|
137 |
+
"data: {data}",
|
138 |
+
]
|
139 |
+
if torch.cuda.is_available():
|
140 |
+
log_msg.append("max mem: {memory:.0f}")
|
141 |
+
log_msg = self.delimiter.join(log_msg)
|
142 |
+
MB = 1024.0 * 1024.0
|
143 |
+
for obj in iterable:
|
144 |
+
data_time.update(time.time() - end)
|
145 |
+
yield obj
|
146 |
+
iter_time.update(time.time() - end)
|
147 |
+
if i % print_freq == 0 or i == len(iterable) - 1:
|
148 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
149 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
150 |
+
if torch.cuda.is_available():
|
151 |
+
print(
|
152 |
+
log_msg.format(
|
153 |
+
i,
|
154 |
+
len(iterable),
|
155 |
+
eta=eta_string,
|
156 |
+
meters=str(self),
|
157 |
+
time=str(iter_time),
|
158 |
+
data=str(data_time),
|
159 |
+
memory=torch.cuda.max_memory_allocated() / MB,
|
160 |
+
)
|
161 |
+
)
|
162 |
+
else:
|
163 |
+
print(
|
164 |
+
log_msg.format(
|
165 |
+
i,
|
166 |
+
len(iterable),
|
167 |
+
eta=eta_string,
|
168 |
+
meters=str(self),
|
169 |
+
time=str(iter_time),
|
170 |
+
data=str(data_time),
|
171 |
+
)
|
172 |
+
)
|
173 |
+
i += 1
|
174 |
+
end = time.time()
|
175 |
+
total_time = time.time() - start_time
|
176 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
177 |
+
print(
|
178 |
+
"{} Total time: {} ({:.4f} s / it)".format(
|
179 |
+
header, total_time_str, total_time / len(iterable)
|
180 |
+
)
|
181 |
+
)
|
182 |
+
|
183 |
+
|
184 |
+
class AttrDict(dict):
|
185 |
+
def __init__(self, *args, **kwargs):
|
186 |
+
super(AttrDict, self).__init__(*args, **kwargs)
|
187 |
+
self.__dict__ = self
|
188 |
+
|
189 |
+
|
190 |
+
def setup_logger():
|
191 |
+
logging.basicConfig(
|
192 |
+
level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
|
193 |
+
format="%(asctime)s [%(levelname)s] %(message)s",
|
194 |
+
handlers=[logging.StreamHandler()],
|
195 |
+
)
|
minigpt4/common/optims.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import math
|
9 |
+
|
10 |
+
from minigpt4.common.registry import registry
|
11 |
+
|
12 |
+
|
13 |
+
@registry.register_lr_scheduler("linear_warmup_step_lr")
|
14 |
+
class LinearWarmupStepLRScheduler:
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
optimizer,
|
18 |
+
max_epoch,
|
19 |
+
min_lr,
|
20 |
+
init_lr,
|
21 |
+
decay_rate=1,
|
22 |
+
warmup_start_lr=-1,
|
23 |
+
warmup_steps=0,
|
24 |
+
**kwargs
|
25 |
+
):
|
26 |
+
self.optimizer = optimizer
|
27 |
+
|
28 |
+
self.max_epoch = max_epoch
|
29 |
+
self.min_lr = min_lr
|
30 |
+
|
31 |
+
self.decay_rate = decay_rate
|
32 |
+
|
33 |
+
self.init_lr = init_lr
|
34 |
+
self.warmup_steps = warmup_steps
|
35 |
+
self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
|
36 |
+
|
37 |
+
def step(self, cur_epoch, cur_step):
|
38 |
+
if cur_epoch == 0:
|
39 |
+
warmup_lr_schedule(
|
40 |
+
step=cur_step,
|
41 |
+
optimizer=self.optimizer,
|
42 |
+
max_step=self.warmup_steps,
|
43 |
+
init_lr=self.warmup_start_lr,
|
44 |
+
max_lr=self.init_lr,
|
45 |
+
)
|
46 |
+
else:
|
47 |
+
step_lr_schedule(
|
48 |
+
epoch=cur_epoch,
|
49 |
+
optimizer=self.optimizer,
|
50 |
+
init_lr=self.init_lr,
|
51 |
+
min_lr=self.min_lr,
|
52 |
+
decay_rate=self.decay_rate,
|
53 |
+
)
|
54 |
+
|
55 |
+
|
56 |
+
@registry.register_lr_scheduler("linear_warmup_cosine_lr")
|
57 |
+
class LinearWarmupCosineLRScheduler:
|
58 |
+
def __init__(
|
59 |
+
self,
|
60 |
+
optimizer,
|
61 |
+
max_epoch,
|
62 |
+
iters_per_epoch,
|
63 |
+
min_lr,
|
64 |
+
init_lr,
|
65 |
+
warmup_steps=0,
|
66 |
+
warmup_start_lr=-1,
|
67 |
+
**kwargs
|
68 |
+
):
|
69 |
+
self.optimizer = optimizer
|
70 |
+
|
71 |
+
self.max_epoch = max_epoch
|
72 |
+
self.iters_per_epoch = iters_per_epoch
|
73 |
+
self.min_lr = min_lr
|
74 |
+
|
75 |
+
self.init_lr = init_lr
|
76 |
+
self.warmup_steps = warmup_steps
|
77 |
+
self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
|
78 |
+
|
79 |
+
def step(self, cur_epoch, cur_step):
|
80 |
+
total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
|
81 |
+
if total_cur_step < self.warmup_steps:
|
82 |
+
warmup_lr_schedule(
|
83 |
+
step=total_cur_step,
|
84 |
+
optimizer=self.optimizer,
|
85 |
+
max_step=self.warmup_steps,
|
86 |
+
init_lr=self.warmup_start_lr,
|
87 |
+
max_lr=self.init_lr,
|
88 |
+
)
|
89 |
+
else:
|
90 |
+
cosine_lr_schedule(
|
91 |
+
epoch=total_cur_step,
|
92 |
+
optimizer=self.optimizer,
|
93 |
+
max_epoch=self.max_epoch * self.iters_per_epoch,
|
94 |
+
init_lr=self.init_lr,
|
95 |
+
min_lr=self.min_lr,
|
96 |
+
)
|
97 |
+
|
98 |
+
|
99 |
+
def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
|
100 |
+
"""Decay the learning rate"""
|
101 |
+
lr = (init_lr - min_lr) * 0.5 * (
|
102 |
+
1.0 + math.cos(math.pi * epoch / max_epoch)
|
103 |
+
) + min_lr
|
104 |
+
for param_group in optimizer.param_groups:
|
105 |
+
param_group["lr"] = lr
|
106 |
+
|
107 |
+
|
108 |
+
def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
|
109 |
+
"""Warmup the learning rate"""
|
110 |
+
lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
|
111 |
+
for param_group in optimizer.param_groups:
|
112 |
+
param_group["lr"] = lr
|
113 |
+
|
114 |
+
|
115 |
+
def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
|
116 |
+
"""Decay the learning rate"""
|
117 |
+
lr = max(min_lr, init_lr * (decay_rate**epoch))
|
118 |
+
for param_group in optimizer.param_groups:
|
119 |
+
param_group["lr"] = lr
|
minigpt4/common/registry.py
ADDED
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
|
9 |
+
class Registry:
|
10 |
+
mapping = {
|
11 |
+
"builder_name_mapping": {},
|
12 |
+
"task_name_mapping": {},
|
13 |
+
"processor_name_mapping": {},
|
14 |
+
"model_name_mapping": {},
|
15 |
+
"lr_scheduler_name_mapping": {},
|
16 |
+
"runner_name_mapping": {},
|
17 |
+
"state": {},
|
18 |
+
"paths": {},
|
19 |
+
}
|
20 |
+
|
21 |
+
@classmethod
|
22 |
+
def register_builder(cls, name):
|
23 |
+
r"""Register a dataset builder to registry with key 'name'
|
24 |
+
|
25 |
+
Args:
|
26 |
+
name: Key with which the builder will be registered.
|
27 |
+
|
28 |
+
Usage:
|
29 |
+
|
30 |
+
from minigpt4.common.registry import registry
|
31 |
+
from minigpt4.datasets.base_dataset_builder import BaseDatasetBuilder
|
32 |
+
"""
|
33 |
+
|
34 |
+
def wrap(builder_cls):
|
35 |
+
from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
|
36 |
+
|
37 |
+
assert issubclass(
|
38 |
+
builder_cls, BaseDatasetBuilder
|
39 |
+
), "All builders must inherit BaseDatasetBuilder class, found {}".format(
|
40 |
+
builder_cls
|
41 |
+
)
|
42 |
+
if name in cls.mapping["builder_name_mapping"]:
|
43 |
+
raise KeyError(
|
44 |
+
"Name '{}' already registered for {}.".format(
|
45 |
+
name, cls.mapping["builder_name_mapping"][name]
|
46 |
+
)
|
47 |
+
)
|
48 |
+
cls.mapping["builder_name_mapping"][name] = builder_cls
|
49 |
+
return builder_cls
|
50 |
+
|
51 |
+
return wrap
|
52 |
+
|
53 |
+
@classmethod
|
54 |
+
def register_task(cls, name):
|
55 |
+
r"""Register a task to registry with key 'name'
|
56 |
+
|
57 |
+
Args:
|
58 |
+
name: Key with which the task will be registered.
|
59 |
+
|
60 |
+
Usage:
|
61 |
+
|
62 |
+
from minigpt4.common.registry import registry
|
63 |
+
"""
|
64 |
+
|
65 |
+
def wrap(task_cls):
|
66 |
+
from minigpt4.tasks.base_task import BaseTask
|
67 |
+
|
68 |
+
assert issubclass(
|
69 |
+
task_cls, BaseTask
|
70 |
+
), "All tasks must inherit BaseTask class"
|
71 |
+
if name in cls.mapping["task_name_mapping"]:
|
72 |
+
raise KeyError(
|
73 |
+
"Name '{}' already registered for {}.".format(
|
74 |
+
name, cls.mapping["task_name_mapping"][name]
|
75 |
+
)
|
76 |
+
)
|
77 |
+
cls.mapping["task_name_mapping"][name] = task_cls
|
78 |
+
return task_cls
|
79 |
+
|
80 |
+
return wrap
|
81 |
+
|
82 |
+
@classmethod
|
83 |
+
def register_model(cls, name):
|
84 |
+
r"""Register a task to registry with key 'name'
|
85 |
+
|
86 |
+
Args:
|
87 |
+
name: Key with which the task will be registered.
|
88 |
+
|
89 |
+
Usage:
|
90 |
+
|
91 |
+
from minigpt4.common.registry import registry
|
92 |
+
"""
|
93 |
+
|
94 |
+
def wrap(model_cls):
|
95 |
+
# from minigpt4.models import BaseModel
|
96 |
+
|
97 |
+
# assert issubclass(
|
98 |
+
# model_cls, BaseModel
|
99 |
+
# ), "All models must inherit BaseModel class"
|
100 |
+
|
101 |
+
if name in cls.mapping["model_name_mapping"]:
|
102 |
+
raise KeyError(
|
103 |
+
"Name '{}' already registered for {}.".format(
|
104 |
+
name, cls.mapping["model_name_mapping"][name]
|
105 |
+
)
|
106 |
+
)
|
107 |
+
cls.mapping["model_name_mapping"][name] = model_cls
|
108 |
+
return model_cls
|
109 |
+
|
110 |
+
return wrap
|
111 |
+
|
112 |
+
@classmethod
|
113 |
+
def register_processor(cls, name):
|
114 |
+
r"""Register a processor to registry with key 'name'
|
115 |
+
|
116 |
+
Args:
|
117 |
+
name: Key with which the task will be registered.
|
118 |
+
|
119 |
+
Usage:
|
120 |
+
|
121 |
+
from minigpt4.common.registry import registry
|
122 |
+
"""
|
123 |
+
|
124 |
+
def wrap(processor_cls):
|
125 |
+
from minigpt4.processors import BaseProcessor
|
126 |
+
|
127 |
+
assert issubclass(
|
128 |
+
processor_cls, BaseProcessor
|
129 |
+
), "All processors must inherit BaseProcessor class"
|
130 |
+
if name in cls.mapping["processor_name_mapping"]:
|
131 |
+
raise KeyError(
|
132 |
+
"Name '{}' already registered for {}.".format(
|
133 |
+
name, cls.mapping["processor_name_mapping"][name]
|
134 |
+
)
|
135 |
+
)
|
136 |
+
cls.mapping["processor_name_mapping"][name] = processor_cls
|
137 |
+
return processor_cls
|
138 |
+
|
139 |
+
return wrap
|
140 |
+
|
141 |
+
@classmethod
|
142 |
+
def register_lr_scheduler(cls, name):
|
143 |
+
r"""Register a model to registry with key 'name'
|
144 |
+
|
145 |
+
Args:
|
146 |
+
name: Key with which the task will be registered.
|
147 |
+
|
148 |
+
Usage:
|
149 |
+
|
150 |
+
from minigpt4.common.registry import registry
|
151 |
+
"""
|
152 |
+
|
153 |
+
def wrap(lr_sched_cls):
|
154 |
+
if name in cls.mapping["lr_scheduler_name_mapping"]:
|
155 |
+
raise KeyError(
|
156 |
+
"Name '{}' already registered for {}.".format(
|
157 |
+
name, cls.mapping["lr_scheduler_name_mapping"][name]
|
158 |
+
)
|
159 |
+
)
|
160 |
+
cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
|
161 |
+
return lr_sched_cls
|
162 |
+
|
163 |
+
return wrap
|
164 |
+
|
165 |
+
@classmethod
|
166 |
+
def register_runner(cls, name):
|
167 |
+
r"""Register a model to registry with key 'name'
|
168 |
+
|
169 |
+
Args:
|
170 |
+
name: Key with which the task will be registered.
|
171 |
+
|
172 |
+
Usage:
|
173 |
+
|
174 |
+
from minigpt4.common.registry import registry
|
175 |
+
"""
|
176 |
+
|
177 |
+
def wrap(runner_cls):
|
178 |
+
if name in cls.mapping["runner_name_mapping"]:
|
179 |
+
raise KeyError(
|
180 |
+
"Name '{}' already registered for {}.".format(
|
181 |
+
name, cls.mapping["runner_name_mapping"][name]
|
182 |
+
)
|
183 |
+
)
|
184 |
+
cls.mapping["runner_name_mapping"][name] = runner_cls
|
185 |
+
return runner_cls
|
186 |
+
|
187 |
+
return wrap
|
188 |
+
|
189 |
+
@classmethod
|
190 |
+
def register_path(cls, name, path):
|
191 |
+
r"""Register a path to registry with key 'name'
|
192 |
+
|
193 |
+
Args:
|
194 |
+
name: Key with which the path will be registered.
|
195 |
+
|
196 |
+
Usage:
|
197 |
+
|
198 |
+
from minigpt4.common.registry import registry
|
199 |
+
"""
|
200 |
+
assert isinstance(path, str), "All path must be str."
|
201 |
+
if name in cls.mapping["paths"]:
|
202 |
+
raise KeyError("Name '{}' already registered.".format(name))
|
203 |
+
cls.mapping["paths"][name] = path
|
204 |
+
|
205 |
+
@classmethod
|
206 |
+
def register(cls, name, obj):
|
207 |
+
r"""Register an item to registry with key 'name'
|
208 |
+
|
209 |
+
Args:
|
210 |
+
name: Key with which the item will be registered.
|
211 |
+
|
212 |
+
Usage::
|
213 |
+
|
214 |
+
from minigpt4.common.registry import registry
|
215 |
+
|
216 |
+
registry.register("config", {})
|
217 |
+
"""
|
218 |
+
path = name.split(".")
|
219 |
+
current = cls.mapping["state"]
|
220 |
+
|
221 |
+
for part in path[:-1]:
|
222 |
+
if part not in current:
|
223 |
+
current[part] = {}
|
224 |
+
current = current[part]
|
225 |
+
|
226 |
+
current[path[-1]] = obj
|
227 |
+
|
228 |
+
# @classmethod
|
229 |
+
# def get_trainer_class(cls, name):
|
230 |
+
# return cls.mapping["trainer_name_mapping"].get(name, None)
|
231 |
+
|
232 |
+
@classmethod
|
233 |
+
def get_builder_class(cls, name):
|
234 |
+
return cls.mapping["builder_name_mapping"].get(name, None)
|
235 |
+
|
236 |
+
@classmethod
|
237 |
+
def get_model_class(cls, name):
|
238 |
+
return cls.mapping["model_name_mapping"].get(name, None)
|
239 |
+
|
240 |
+
@classmethod
|
241 |
+
def get_task_class(cls, name):
|
242 |
+
return cls.mapping["task_name_mapping"].get(name, None)
|
243 |
+
|
244 |
+
@classmethod
|
245 |
+
def get_processor_class(cls, name):
|
246 |
+
return cls.mapping["processor_name_mapping"].get(name, None)
|
247 |
+
|
248 |
+
@classmethod
|
249 |
+
def get_lr_scheduler_class(cls, name):
|
250 |
+
return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
|
251 |
+
|
252 |
+
@classmethod
|
253 |
+
def get_runner_class(cls, name):
|
254 |
+
return cls.mapping["runner_name_mapping"].get(name, None)
|
255 |
+
|
256 |
+
@classmethod
|
257 |
+
def list_runners(cls):
|
258 |
+
return sorted(cls.mapping["runner_name_mapping"].keys())
|
259 |
+
|
260 |
+
@classmethod
|
261 |
+
def list_models(cls):
|
262 |
+
return sorted(cls.mapping["model_name_mapping"].keys())
|
263 |
+
|
264 |
+
@classmethod
|
265 |
+
def list_tasks(cls):
|
266 |
+
return sorted(cls.mapping["task_name_mapping"].keys())
|
267 |
+
|
268 |
+
@classmethod
|
269 |
+
def list_processors(cls):
|
270 |
+
return sorted(cls.mapping["processor_name_mapping"].keys())
|
271 |
+
|
272 |
+
@classmethod
|
273 |
+
def list_lr_schedulers(cls):
|
274 |
+
return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
|
275 |
+
|
276 |
+
@classmethod
|
277 |
+
def list_datasets(cls):
|
278 |
+
return sorted(cls.mapping["builder_name_mapping"].keys())
|
279 |
+
|
280 |
+
@classmethod
|
281 |
+
def get_path(cls, name):
|
282 |
+
return cls.mapping["paths"].get(name, None)
|
283 |
+
|
284 |
+
@classmethod
|
285 |
+
def get(cls, name, default=None, no_warning=False):
|
286 |
+
r"""Get an item from registry with key 'name'
|
287 |
+
|
288 |
+
Args:
|
289 |
+
name (string): Key whose value needs to be retrieved.
|
290 |
+
default: If passed and key is not in registry, default value will
|
291 |
+
be returned with a warning. Default: None
|
292 |
+
no_warning (bool): If passed as True, warning when key doesn't exist
|
293 |
+
will not be generated. Useful for MMF's
|
294 |
+
internal operations. Default: False
|
295 |
+
"""
|
296 |
+
original_name = name
|
297 |
+
name = name.split(".")
|
298 |
+
value = cls.mapping["state"]
|
299 |
+
for subname in name:
|
300 |
+
value = value.get(subname, default)
|
301 |
+
if value is default:
|
302 |
+
break
|
303 |
+
|
304 |
+
if (
|
305 |
+
"writer" in cls.mapping["state"]
|
306 |
+
and value == default
|
307 |
+
and no_warning is False
|
308 |
+
):
|
309 |
+
cls.mapping["state"]["writer"].warning(
|
310 |
+
"Key {} is not present in registry, returning default value "
|
311 |
+
"of {}".format(original_name, default)
|
312 |
+
)
|
313 |
+
return value
|
314 |
+
|
315 |
+
@classmethod
|
316 |
+
def unregister(cls, name):
|
317 |
+
r"""Remove an item from registry with key 'name'
|
318 |
+
|
319 |
+
Args:
|
320 |
+
name: Key which needs to be removed.
|
321 |
+
Usage::
|
322 |
+
|
323 |
+
from mmf.common.registry import registry
|
324 |
+
|
325 |
+
config = registry.unregister("config")
|
326 |
+
"""
|
327 |
+
return cls.mapping["state"].pop(name, None)
|
328 |
+
|
329 |
+
|
330 |
+
registry = Registry()
|
minigpt4/common/utils.py
ADDED
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import io
|
9 |
+
import json
|
10 |
+
import logging
|
11 |
+
import os
|
12 |
+
import pickle
|
13 |
+
import re
|
14 |
+
import shutil
|
15 |
+
import urllib
|
16 |
+
import urllib.error
|
17 |
+
import urllib.request
|
18 |
+
from typing import Optional
|
19 |
+
from urllib.parse import urlparse
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
import pandas as pd
|
23 |
+
import yaml
|
24 |
+
from iopath.common.download import download
|
25 |
+
from iopath.common.file_io import file_lock, g_pathmgr
|
26 |
+
from minigpt4.common.registry import registry
|
27 |
+
from torch.utils.model_zoo import tqdm
|
28 |
+
from torchvision.datasets.utils import (
|
29 |
+
check_integrity,
|
30 |
+
download_file_from_google_drive,
|
31 |
+
extract_archive,
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
def now():
|
36 |
+
from datetime import datetime
|
37 |
+
|
38 |
+
return datetime.now().strftime("%Y%m%d%H%M")
|
39 |
+
|
40 |
+
|
41 |
+
def is_url(url_or_filename):
|
42 |
+
parsed = urlparse(url_or_filename)
|
43 |
+
return parsed.scheme in ("http", "https")
|
44 |
+
|
45 |
+
|
46 |
+
def get_cache_path(rel_path):
|
47 |
+
return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
|
48 |
+
|
49 |
+
|
50 |
+
def get_abs_path(rel_path):
|
51 |
+
return os.path.join(registry.get_path("library_root"), rel_path)
|
52 |
+
|
53 |
+
|
54 |
+
def load_json(filename):
|
55 |
+
with open(filename, "r") as f:
|
56 |
+
return json.load(f)
|
57 |
+
|
58 |
+
|
59 |
+
# The following are adapted from torchvision and vissl
|
60 |
+
# torchvision: https://github.com/pytorch/vision
|
61 |
+
# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
|
62 |
+
|
63 |
+
|
64 |
+
def makedir(dir_path):
|
65 |
+
"""
|
66 |
+
Create the directory if it does not exist.
|
67 |
+
"""
|
68 |
+
is_success = False
|
69 |
+
try:
|
70 |
+
if not g_pathmgr.exists(dir_path):
|
71 |
+
g_pathmgr.mkdirs(dir_path)
|
72 |
+
is_success = True
|
73 |
+
except BaseException:
|
74 |
+
print(f"Error creating directory: {dir_path}")
|
75 |
+
return is_success
|
76 |
+
|
77 |
+
|
78 |
+
def get_redirected_url(url: str):
|
79 |
+
"""
|
80 |
+
Given a URL, returns the URL it redirects to or the
|
81 |
+
original URL in case of no indirection
|
82 |
+
"""
|
83 |
+
import requests
|
84 |
+
|
85 |
+
with requests.Session() as session:
|
86 |
+
with session.get(url, stream=True, allow_redirects=True) as response:
|
87 |
+
if response.history:
|
88 |
+
return response.url
|
89 |
+
else:
|
90 |
+
return url
|
91 |
+
|
92 |
+
|
93 |
+
def to_google_drive_download_url(view_url: str) -> str:
|
94 |
+
"""
|
95 |
+
Utility function to transform a view URL of google drive
|
96 |
+
to a download URL for google drive
|
97 |
+
Example input:
|
98 |
+
https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
|
99 |
+
Example output:
|
100 |
+
https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
|
101 |
+
"""
|
102 |
+
splits = view_url.split("/")
|
103 |
+
assert splits[-1] == "view"
|
104 |
+
file_id = splits[-2]
|
105 |
+
return f"https://drive.google.com/uc?export=download&id={file_id}"
|
106 |
+
|
107 |
+
|
108 |
+
def download_google_drive_url(url: str, output_path: str, output_file_name: str):
|
109 |
+
"""
|
110 |
+
Download a file from google drive
|
111 |
+
Downloading an URL from google drive requires confirmation when
|
112 |
+
the file of the size is too big (google drive notifies that
|
113 |
+
anti-viral checks cannot be performed on such files)
|
114 |
+
"""
|
115 |
+
import requests
|
116 |
+
|
117 |
+
with requests.Session() as session:
|
118 |
+
|
119 |
+
# First get the confirmation token and append it to the URL
|
120 |
+
with session.get(url, stream=True, allow_redirects=True) as response:
|
121 |
+
for k, v in response.cookies.items():
|
122 |
+
if k.startswith("download_warning"):
|
123 |
+
url = url + "&confirm=" + v
|
124 |
+
|
125 |
+
# Then download the content of the file
|
126 |
+
with session.get(url, stream=True, verify=True) as response:
|
127 |
+
makedir(output_path)
|
128 |
+
path = os.path.join(output_path, output_file_name)
|
129 |
+
total_size = int(response.headers.get("Content-length", 0))
|
130 |
+
with open(path, "wb") as file:
|
131 |
+
from tqdm import tqdm
|
132 |
+
|
133 |
+
with tqdm(total=total_size) as progress_bar:
|
134 |
+
for block in response.iter_content(
|
135 |
+
chunk_size=io.DEFAULT_BUFFER_SIZE
|
136 |
+
):
|
137 |
+
file.write(block)
|
138 |
+
progress_bar.update(len(block))
|
139 |
+
|
140 |
+
|
141 |
+
def _get_google_drive_file_id(url: str) -> Optional[str]:
|
142 |
+
parts = urlparse(url)
|
143 |
+
|
144 |
+
if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
|
145 |
+
return None
|
146 |
+
|
147 |
+
match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
|
148 |
+
if match is None:
|
149 |
+
return None
|
150 |
+
|
151 |
+
return match.group("id")
|
152 |
+
|
153 |
+
|
154 |
+
def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
|
155 |
+
with open(filename, "wb") as fh:
|
156 |
+
with urllib.request.urlopen(
|
157 |
+
urllib.request.Request(url, headers={"User-Agent": "vissl"})
|
158 |
+
) as response:
|
159 |
+
with tqdm(total=response.length) as pbar:
|
160 |
+
for chunk in iter(lambda: response.read(chunk_size), ""):
|
161 |
+
if not chunk:
|
162 |
+
break
|
163 |
+
pbar.update(chunk_size)
|
164 |
+
fh.write(chunk)
|
165 |
+
|
166 |
+
|
167 |
+
def download_url(
|
168 |
+
url: str,
|
169 |
+
root: str,
|
170 |
+
filename: Optional[str] = None,
|
171 |
+
md5: Optional[str] = None,
|
172 |
+
) -> None:
|
173 |
+
"""Download a file from a url and place it in root.
|
174 |
+
Args:
|
175 |
+
url (str): URL to download file from
|
176 |
+
root (str): Directory to place downloaded file in
|
177 |
+
filename (str, optional): Name to save the file under.
|
178 |
+
If None, use the basename of the URL.
|
179 |
+
md5 (str, optional): MD5 checksum of the download. If None, do not check
|
180 |
+
"""
|
181 |
+
root = os.path.expanduser(root)
|
182 |
+
if not filename:
|
183 |
+
filename = os.path.basename(url)
|
184 |
+
fpath = os.path.join(root, filename)
|
185 |
+
|
186 |
+
makedir(root)
|
187 |
+
|
188 |
+
# check if file is already present locally
|
189 |
+
if check_integrity(fpath, md5):
|
190 |
+
print("Using downloaded and verified file: " + fpath)
|
191 |
+
return
|
192 |
+
|
193 |
+
# expand redirect chain if needed
|
194 |
+
url = get_redirected_url(url)
|
195 |
+
|
196 |
+
# check if file is located on Google Drive
|
197 |
+
file_id = _get_google_drive_file_id(url)
|
198 |
+
if file_id is not None:
|
199 |
+
return download_file_from_google_drive(file_id, root, filename, md5)
|
200 |
+
|
201 |
+
# download the file
|
202 |
+
try:
|
203 |
+
print("Downloading " + url + " to " + fpath)
|
204 |
+
_urlretrieve(url, fpath)
|
205 |
+
except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
|
206 |
+
if url[:5] == "https":
|
207 |
+
url = url.replace("https:", "http:")
|
208 |
+
print(
|
209 |
+
"Failed download. Trying https -> http instead."
|
210 |
+
" Downloading " + url + " to " + fpath
|
211 |
+
)
|
212 |
+
_urlretrieve(url, fpath)
|
213 |
+
else:
|
214 |
+
raise e
|
215 |
+
|
216 |
+
# check integrity of downloaded file
|
217 |
+
if not check_integrity(fpath, md5):
|
218 |
+
raise RuntimeError("File not found or corrupted.")
|
219 |
+
|
220 |
+
|
221 |
+
def download_and_extract_archive(
|
222 |
+
url: str,
|
223 |
+
download_root: str,
|
224 |
+
extract_root: Optional[str] = None,
|
225 |
+
filename: Optional[str] = None,
|
226 |
+
md5: Optional[str] = None,
|
227 |
+
remove_finished: bool = False,
|
228 |
+
) -> None:
|
229 |
+
download_root = os.path.expanduser(download_root)
|
230 |
+
if extract_root is None:
|
231 |
+
extract_root = download_root
|
232 |
+
if not filename:
|
233 |
+
filename = os.path.basename(url)
|
234 |
+
|
235 |
+
download_url(url, download_root, filename, md5)
|
236 |
+
|
237 |
+
archive = os.path.join(download_root, filename)
|
238 |
+
print("Extracting {} to {}".format(archive, extract_root))
|
239 |
+
extract_archive(archive, extract_root, remove_finished)
|
240 |
+
|
241 |
+
|
242 |
+
def cache_url(url: str, cache_dir: str) -> str:
|
243 |
+
"""
|
244 |
+
This implementation downloads the remote resource and caches it locally.
|
245 |
+
The resource will only be downloaded if not previously requested.
|
246 |
+
"""
|
247 |
+
parsed_url = urlparse(url)
|
248 |
+
dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
|
249 |
+
makedir(dirname)
|
250 |
+
filename = url.split("/")[-1]
|
251 |
+
cached = os.path.join(dirname, filename)
|
252 |
+
with file_lock(cached):
|
253 |
+
if not os.path.isfile(cached):
|
254 |
+
logging.info(f"Downloading {url} to {cached} ...")
|
255 |
+
cached = download(url, dirname, filename=filename)
|
256 |
+
logging.info(f"URL {url} cached in {cached}")
|
257 |
+
return cached
|
258 |
+
|
259 |
+
|
260 |
+
# TODO (prigoyal): convert this into RAII-style API
|
261 |
+
def create_file_symlink(file1, file2):
|
262 |
+
"""
|
263 |
+
Simply create the symlinks for a given file1 to file2.
|
264 |
+
Useful during model checkpointing to symlinks to the
|
265 |
+
latest successful checkpoint.
|
266 |
+
"""
|
267 |
+
try:
|
268 |
+
if g_pathmgr.exists(file2):
|
269 |
+
g_pathmgr.rm(file2)
|
270 |
+
g_pathmgr.symlink(file1, file2)
|
271 |
+
except Exception as e:
|
272 |
+
logging.info(f"Could NOT create symlink. Error: {e}")
|
273 |
+
|
274 |
+
|
275 |
+
def save_file(data, filename, append_to_json=True, verbose=True):
|
276 |
+
"""
|
277 |
+
Common i/o utility to handle saving data to various file formats.
|
278 |
+
Supported:
|
279 |
+
.pkl, .pickle, .npy, .json
|
280 |
+
Specifically for .json, users have the option to either append (default)
|
281 |
+
or rewrite by passing in Boolean value to append_to_json.
|
282 |
+
"""
|
283 |
+
if verbose:
|
284 |
+
logging.info(f"Saving data to file: {filename}")
|
285 |
+
file_ext = os.path.splitext(filename)[1]
|
286 |
+
if file_ext in [".pkl", ".pickle"]:
|
287 |
+
with g_pathmgr.open(filename, "wb") as fopen:
|
288 |
+
pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
|
289 |
+
elif file_ext == ".npy":
|
290 |
+
with g_pathmgr.open(filename, "wb") as fopen:
|
291 |
+
np.save(fopen, data)
|
292 |
+
elif file_ext == ".json":
|
293 |
+
if append_to_json:
|
294 |
+
with g_pathmgr.open(filename, "a") as fopen:
|
295 |
+
fopen.write(json.dumps(data, sort_keys=True) + "\n")
|
296 |
+
fopen.flush()
|
297 |
+
else:
|
298 |
+
with g_pathmgr.open(filename, "w") as fopen:
|
299 |
+
fopen.write(json.dumps(data, sort_keys=True) + "\n")
|
300 |
+
fopen.flush()
|
301 |
+
elif file_ext == ".yaml":
|
302 |
+
with g_pathmgr.open(filename, "w") as fopen:
|
303 |
+
dump = yaml.dump(data)
|
304 |
+
fopen.write(dump)
|
305 |
+
fopen.flush()
|
306 |
+
else:
|
307 |
+
raise Exception(f"Saving {file_ext} is not supported yet")
|
308 |
+
|
309 |
+
if verbose:
|
310 |
+
logging.info(f"Saved data to file: {filename}")
|
311 |
+
|
312 |
+
|
313 |
+
def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
|
314 |
+
"""
|
315 |
+
Common i/o utility to handle loading data from various file formats.
|
316 |
+
Supported:
|
317 |
+
.pkl, .pickle, .npy, .json
|
318 |
+
For the npy files, we support reading the files in mmap_mode.
|
319 |
+
If the mmap_mode of reading is not successful, we load data without the
|
320 |
+
mmap_mode.
|
321 |
+
"""
|
322 |
+
if verbose:
|
323 |
+
logging.info(f"Loading data from file: {filename}")
|
324 |
+
|
325 |
+
file_ext = os.path.splitext(filename)[1]
|
326 |
+
if file_ext == ".txt":
|
327 |
+
with g_pathmgr.open(filename, "r") as fopen:
|
328 |
+
data = fopen.readlines()
|
329 |
+
elif file_ext in [".pkl", ".pickle"]:
|
330 |
+
with g_pathmgr.open(filename, "rb") as fopen:
|
331 |
+
data = pickle.load(fopen, encoding="latin1")
|
332 |
+
elif file_ext == ".npy":
|
333 |
+
if mmap_mode:
|
334 |
+
try:
|
335 |
+
with g_pathmgr.open(filename, "rb") as fopen:
|
336 |
+
data = np.load(
|
337 |
+
fopen,
|
338 |
+
allow_pickle=allow_pickle,
|
339 |
+
encoding="latin1",
|
340 |
+
mmap_mode=mmap_mode,
|
341 |
+
)
|
342 |
+
except ValueError as e:
|
343 |
+
logging.info(
|
344 |
+
f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
|
345 |
+
)
|
346 |
+
data = np.load(
|
347 |
+
filename,
|
348 |
+
allow_pickle=allow_pickle,
|
349 |
+
encoding="latin1",
|
350 |
+
mmap_mode=mmap_mode,
|
351 |
+
)
|
352 |
+
logging.info("Successfully loaded without g_pathmgr")
|
353 |
+
except Exception:
|
354 |
+
logging.info("Could not mmap without g_pathmgr. Trying without mmap")
|
355 |
+
with g_pathmgr.open(filename, "rb") as fopen:
|
356 |
+
data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
|
357 |
+
else:
|
358 |
+
with g_pathmgr.open(filename, "rb") as fopen:
|
359 |
+
data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
|
360 |
+
elif file_ext == ".json":
|
361 |
+
with g_pathmgr.open(filename, "r") as fopen:
|
362 |
+
data = json.load(fopen)
|
363 |
+
elif file_ext == ".yaml":
|
364 |
+
with g_pathmgr.open(filename, "r") as fopen:
|
365 |
+
data = yaml.load(fopen, Loader=yaml.FullLoader)
|
366 |
+
elif file_ext == ".csv":
|
367 |
+
with g_pathmgr.open(filename, "r") as fopen:
|
368 |
+
data = pd.read_csv(fopen)
|
369 |
+
else:
|
370 |
+
raise Exception(f"Reading from {file_ext} is not supported yet")
|
371 |
+
return data
|
372 |
+
|
373 |
+
|
374 |
+
def abspath(resource_path: str):
|
375 |
+
"""
|
376 |
+
Make a path absolute, but take into account prefixes like
|
377 |
+
"http://" or "manifold://"
|
378 |
+
"""
|
379 |
+
regex = re.compile(r"^\w+://")
|
380 |
+
if regex.match(resource_path) is None:
|
381 |
+
return os.path.abspath(resource_path)
|
382 |
+
else:
|
383 |
+
return resource_path
|
384 |
+
|
385 |
+
|
386 |
+
def makedir(dir_path):
|
387 |
+
"""
|
388 |
+
Create the directory if it does not exist.
|
389 |
+
"""
|
390 |
+
is_success = False
|
391 |
+
try:
|
392 |
+
if not g_pathmgr.exists(dir_path):
|
393 |
+
g_pathmgr.mkdirs(dir_path)
|
394 |
+
is_success = True
|
395 |
+
except BaseException:
|
396 |
+
logging.info(f"Error creating directory: {dir_path}")
|
397 |
+
return is_success
|
398 |
+
|
399 |
+
|
400 |
+
def is_url(input_url):
|
401 |
+
"""
|
402 |
+
Check if an input string is a url. look for http(s):// and ignoring the case
|
403 |
+
"""
|
404 |
+
is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
|
405 |
+
return is_url
|
406 |
+
|
407 |
+
|
408 |
+
def cleanup_dir(dir):
|
409 |
+
"""
|
410 |
+
Utility for deleting a directory. Useful for cleaning the storage space
|
411 |
+
that contains various training artifacts like checkpoints, data etc.
|
412 |
+
"""
|
413 |
+
if os.path.exists(dir):
|
414 |
+
logging.info(f"Deleting directory: {dir}")
|
415 |
+
shutil.rmtree(dir)
|
416 |
+
logging.info(f"Deleted contents of directory: {dir}")
|
417 |
+
|
418 |
+
|
419 |
+
def get_file_size(filename):
|
420 |
+
"""
|
421 |
+
Given a file, get the size of file in MB
|
422 |
+
"""
|
423 |
+
size_in_mb = os.path.getsize(filename) / float(1024**2)
|
424 |
+
return size_in_mb
|
minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
import sys
|
4 |
+
dataDir = '../../VQA'
|
5 |
+
sys.path.insert(0, '%s/PythonHelperTools/vqaTools' %(dataDir))
|
6 |
+
from vqa import VQA
|
7 |
+
from vqaEvaluation.vqaEval import VQAEval
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import skimage.io as io
|
10 |
+
import json
|
11 |
+
import random
|
12 |
+
import os
|
13 |
+
|
14 |
+
# set up file names and paths
|
15 |
+
versionType ='v2_' # this should be '' when using VQA v2.0 dataset
|
16 |
+
taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
|
17 |
+
dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
|
18 |
+
dataSubType ='train2014'
|
19 |
+
annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType)
|
20 |
+
quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType)
|
21 |
+
imgDir ='%s/Images/%s/%s/' %(dataDir, dataType, dataSubType)
|
22 |
+
resultType ='fake'
|
23 |
+
fileTypes = ['results', 'accuracy', 'evalQA', 'evalQuesType', 'evalAnsType']
|
24 |
+
|
25 |
+
# An example result json file has been provided in './Results' folder.
|
26 |
+
|
27 |
+
[resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = ['%s/Results/%s%s_%s_%s_%s_%s.json'%(dataDir, versionType, taskType, dataType, dataSubType, \
|
28 |
+
resultType, fileType) for fileType in fileTypes]
|
29 |
+
|
30 |
+
# create vqa object and vqaRes object
|
31 |
+
vqa = VQA(annFile, quesFile)
|
32 |
+
vqaRes = vqa.loadRes(resFile, quesFile)
|
33 |
+
|
34 |
+
# create vqaEval object by taking vqa and vqaRes
|
35 |
+
vqaEval = VQAEval(vqa, vqaRes, n=2) #n is precision of accuracy (number of places after decimal), default is 2
|
36 |
+
|
37 |
+
# evaluate results
|
38 |
+
"""
|
39 |
+
If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function
|
40 |
+
By default it uses all the question ids in annotation file
|
41 |
+
"""
|
42 |
+
vqaEval.evaluate()
|
43 |
+
|
44 |
+
# print accuracies
|
45 |
+
print "\n"
|
46 |
+
print "Overall Accuracy is: %.02f\n" %(vqaEval.accuracy['overall'])
|
47 |
+
print "Per Question Type Accuracy is the following:"
|
48 |
+
for quesType in vqaEval.accuracy['perQuestionType']:
|
49 |
+
print "%s : %.02f" %(quesType, vqaEval.accuracy['perQuestionType'][quesType])
|
50 |
+
print "\n"
|
51 |
+
print "Per Answer Type Accuracy is the following:"
|
52 |
+
for ansType in vqaEval.accuracy['perAnswerType']:
|
53 |
+
print "%s : %.02f" %(ansType, vqaEval.accuracy['perAnswerType'][ansType])
|
54 |
+
print "\n"
|
55 |
+
# demo how to use evalQA to retrieve low score result
|
56 |
+
evals = [quesId for quesId in vqaEval.evalQA if vqaEval.evalQA[quesId]<35] #35 is per question percentage accuracy
|
57 |
+
if len(evals) > 0:
|
58 |
+
print 'ground truth answers'
|
59 |
+
randomEval = random.choice(evals)
|
60 |
+
randomAnn = vqa.loadQA(randomEval)
|
61 |
+
vqa.showQA(randomAnn)
|
62 |
+
|
63 |
+
print '\n'
|
64 |
+
print 'generated answer (accuracy %.02f)'%(vqaEval.evalQA[randomEval])
|
65 |
+
ann = vqaRes.loadQA(randomEval)[0]
|
66 |
+
print "Answer: %s\n" %(ann['answer'])
|
67 |
+
|
68 |
+
imgId = randomAnn[0]['image_id']
|
69 |
+
imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
|
70 |
+
if os.path.isfile(imgDir + imgFilename):
|
71 |
+
I = io.imread(imgDir + imgFilename)
|
72 |
+
plt.imshow(I)
|
73 |
+
plt.axis('off')
|
74 |
+
plt.show()
|
75 |
+
|
76 |
+
# plot accuracy for various question types
|
77 |
+
plt.bar(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].values(), align='center')
|
78 |
+
plt.xticks(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].keys(), rotation='0',fontsize=10)
|
79 |
+
plt.title('Per Question Type Accuracy', fontsize=10)
|
80 |
+
plt.xlabel('Question Types', fontsize=10)
|
81 |
+
plt.ylabel('Accuracy', fontsize=10)
|
82 |
+
plt.show()
|
83 |
+
|
84 |
+
# save evaluation results to ./Results folder
|
85 |
+
json.dump(vqaEval.accuracy, open(accuracyFile, 'w'))
|
86 |
+
json.dump(vqaEval.evalQA, open(evalQAFile, 'w'))
|
87 |
+
json.dump(vqaEval.evalQuesType, open(evalQuesTypeFile, 'w'))
|
88 |
+
json.dump(vqaEval.evalAnsType, open(evalAnsTypeFile, 'w'))
|
89 |
+
|
minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
author='aagrawal'
|
minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/vqaEval.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
|
3 |
+
__author__='aagrawal'
|
4 |
+
|
5 |
+
import re
|
6 |
+
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
|
7 |
+
# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
|
8 |
+
import sys
|
9 |
+
|
10 |
+
|
11 |
+
class VQAEval:
|
12 |
+
def __init__(self, vqa, vqaRes, n=2):
|
13 |
+
self.n = n
|
14 |
+
self.accuracy = {}
|
15 |
+
self.evalQA = {}
|
16 |
+
self.evalQuesType = {}
|
17 |
+
self.evalAnsType = {}
|
18 |
+
self.vqa = vqa
|
19 |
+
self.vqaRes = vqaRes
|
20 |
+
self.params = {'question_id': vqa.getQuesIds()}
|
21 |
+
self.contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't", \
|
22 |
+
"couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", \
|
23 |
+
"hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've", \
|
24 |
+
"he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", \
|
25 |
+
"Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", \
|
26 |
+
"maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", \
|
27 |
+
"mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", \
|
28 |
+
"ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", \
|
29 |
+
"she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", \
|
30 |
+
"somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll", \
|
31 |
+
"somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've", \
|
32 |
+
"someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", \
|
33 |
+
"something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've", \
|
34 |
+
"there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", \
|
35 |
+
"they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", \
|
36 |
+
"wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", \
|
37 |
+
"whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've", \
|
38 |
+
"whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", \
|
39 |
+
"whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", \
|
40 |
+
"wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", \
|
41 |
+
"y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", \
|
42 |
+
"youll": "you'll", "youre": "you're", "youve": "you've"}
|
43 |
+
self.manualMap = { 'none': '0',
|
44 |
+
'zero': '0',
|
45 |
+
'one': '1',
|
46 |
+
'two': '2',
|
47 |
+
'three': '3',
|
48 |
+
'four': '4',
|
49 |
+
'five': '5',
|
50 |
+
'six': '6',
|
51 |
+
'seven': '7',
|
52 |
+
'eight': '8',
|
53 |
+
'nine': '9',
|
54 |
+
'ten': '10'
|
55 |
+
}
|
56 |
+
self.articles = ['a',
|
57 |
+
'an',
|
58 |
+
'the'
|
59 |
+
]
|
60 |
+
|
61 |
+
|
62 |
+
self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
|
63 |
+
self.commaStrip = re.compile("(\d)(\,)(\d)")
|
64 |
+
self.punct = [';', r"/", '[', ']', '"', '{', '}',
|
65 |
+
'(', ')', '=', '+', '\\', '_', '-',
|
66 |
+
'>', '<', '@', '`', ',', '?', '!']
|
67 |
+
|
68 |
+
|
69 |
+
def evaluate(self, quesIds=None):
|
70 |
+
if quesIds == None:
|
71 |
+
quesIds = [quesId for quesId in self.params['question_id']]
|
72 |
+
gts = {}
|
73 |
+
res = {}
|
74 |
+
for quesId in quesIds:
|
75 |
+
gts[quesId] = self.vqa.qa[quesId]
|
76 |
+
res[quesId] = self.vqaRes.qa[quesId]
|
77 |
+
|
78 |
+
# =================================================
|
79 |
+
# Compute accuracy
|
80 |
+
# =================================================
|
81 |
+
accQA = []
|
82 |
+
accQuesType = {}
|
83 |
+
accAnsType = {}
|
84 |
+
# print "computing accuracy"
|
85 |
+
step = 0
|
86 |
+
for quesId in quesIds:
|
87 |
+
for ansDic in gts[quesId]['answers']:
|
88 |
+
ansDic['answer'] = ansDic['answer'].replace('\n', ' ')
|
89 |
+
ansDic['answer'] = ansDic['answer'].replace('\t', ' ')
|
90 |
+
ansDic['answer'] = ansDic['answer'].strip()
|
91 |
+
resAns = res[quesId]['answer']
|
92 |
+
resAns = resAns.replace('\n', ' ')
|
93 |
+
resAns = resAns.replace('\t', ' ')
|
94 |
+
resAns = resAns.strip()
|
95 |
+
gtAcc = []
|
96 |
+
gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']]
|
97 |
+
|
98 |
+
if len(set(gtAnswers)) > 1:
|
99 |
+
for ansDic in gts[quesId]['answers']:
|
100 |
+
ansDic['answer'] = self.processPunctuation(ansDic['answer'])
|
101 |
+
ansDic['answer'] = self.processDigitArticle(ansDic['answer'])
|
102 |
+
resAns = self.processPunctuation(resAns)
|
103 |
+
resAns = self.processDigitArticle(resAns)
|
104 |
+
|
105 |
+
for gtAnsDatum in gts[quesId]['answers']:
|
106 |
+
otherGTAns = [item for item in gts[quesId]['answers'] if item!=gtAnsDatum]
|
107 |
+
matchingAns = [item for item in otherGTAns if item['answer'].lower()==resAns.lower()]
|
108 |
+
acc = min(1, float(len(matchingAns))/3)
|
109 |
+
gtAcc.append(acc)
|
110 |
+
quesType = gts[quesId]['question_type']
|
111 |
+
ansType = gts[quesId]['answer_type']
|
112 |
+
avgGTAcc = float(sum(gtAcc))/len(gtAcc)
|
113 |
+
accQA.append(avgGTAcc)
|
114 |
+
if quesType not in accQuesType:
|
115 |
+
accQuesType[quesType] = []
|
116 |
+
accQuesType[quesType].append(avgGTAcc)
|
117 |
+
if ansType not in accAnsType:
|
118 |
+
accAnsType[ansType] = []
|
119 |
+
accAnsType[ansType].append(avgGTAcc)
|
120 |
+
self.setEvalQA(quesId, avgGTAcc)
|
121 |
+
self.setEvalQuesType(quesId, quesType, avgGTAcc)
|
122 |
+
self.setEvalAnsType(quesId, ansType, avgGTAcc)
|
123 |
+
if step%100 == 0:
|
124 |
+
self.updateProgress(step/float(len(quesIds)))
|
125 |
+
step = step + 1
|
126 |
+
|
127 |
+
self.setAccuracy(accQA, accQuesType, accAnsType)
|
128 |
+
# print "Done computing accuracy"
|
129 |
+
|
130 |
+
def processPunctuation(self, inText):
|
131 |
+
outText = inText
|
132 |
+
for p in self.punct:
|
133 |
+
if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) != None):
|
134 |
+
outText = outText.replace(p, '')
|
135 |
+
else:
|
136 |
+
outText = outText.replace(p, ' ')
|
137 |
+
outText = self.periodStrip.sub("",
|
138 |
+
outText,
|
139 |
+
re.UNICODE)
|
140 |
+
return outText
|
141 |
+
|
142 |
+
def processDigitArticle(self, inText):
|
143 |
+
outText = []
|
144 |
+
tempText = inText.lower().split()
|
145 |
+
for word in tempText:
|
146 |
+
word = self.manualMap.setdefault(word, word)
|
147 |
+
if word not in self.articles:
|
148 |
+
outText.append(word)
|
149 |
+
else:
|
150 |
+
pass
|
151 |
+
for wordId, word in enumerate(outText):
|
152 |
+
if word in self.contractions:
|
153 |
+
outText[wordId] = self.contractions[word]
|
154 |
+
outText = ' '.join(outText)
|
155 |
+
return outText
|
156 |
+
|
157 |
+
def setAccuracy(self, accQA, accQuesType, accAnsType):
|
158 |
+
self.accuracy['overall'] = round(100*float(sum(accQA))/len(accQA), self.n)
|
159 |
+
self.accuracy['perQuestionType'] = {quesType: round(100*float(sum(accQuesType[quesType]))/len(accQuesType[quesType]), self.n) for quesType in accQuesType}
|
160 |
+
self.accuracy['perAnswerType'] = {ansType: round(100*float(sum(accAnsType[ansType]))/len(accAnsType[ansType]), self.n) for ansType in accAnsType}
|
161 |
+
|
162 |
+
def setEvalQA(self, quesId, acc):
|
163 |
+
self.evalQA[quesId] = round(100*acc, self.n)
|
164 |
+
|
165 |
+
def setEvalQuesType(self, quesId, quesType, acc):
|
166 |
+
if quesType not in self.evalQuesType:
|
167 |
+
self.evalQuesType[quesType] = {}
|
168 |
+
self.evalQuesType[quesType][quesId] = round(100*acc, self.n)
|
169 |
+
|
170 |
+
def setEvalAnsType(self, quesId, ansType, acc):
|
171 |
+
if ansType not in self.evalAnsType:
|
172 |
+
self.evalAnsType[ansType] = {}
|
173 |
+
self.evalAnsType[ansType][quesId] = round(100*acc, self.n)
|
174 |
+
|
175 |
+
def updateProgress(self, progress):
|
176 |
+
barLength = 20
|
177 |
+
status = ""
|
178 |
+
if isinstance(progress, int):
|
179 |
+
progress = float(progress)
|
180 |
+
if not isinstance(progress, float):
|
181 |
+
progress = 0
|
182 |
+
status = "error: progress var must be float\r\n"
|
183 |
+
if progress < 0:
|
184 |
+
progress = 0
|
185 |
+
status = "Halt...\r\n"
|
186 |
+
if progress >= 1:
|
187 |
+
progress = 1
|
188 |
+
status = "Done...\r\n"
|
189 |
+
block = int(round(barLength*progress))
|
190 |
+
text = "\rFinshed Percent: [{0}] {1}% {2}".format( "#"*block + "-"*(barLength-block), int(progress*100), status)
|
191 |
+
sys.stdout.write(text)
|
192 |
+
sys.stdout.flush()
|
minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
from vqaTools.vqa import VQA
|
4 |
+
import random
|
5 |
+
import skimage.io as io
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import os
|
8 |
+
|
9 |
+
dataDir ='../../VQA'
|
10 |
+
versionType ='v2_' # this should be '' when using VQA v2.0 dataset
|
11 |
+
taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
|
12 |
+
dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
|
13 |
+
dataSubType ='train2014'
|
14 |
+
annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType)
|
15 |
+
quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType)
|
16 |
+
imgDir = '%s/Images/%s/%s/' %(dataDir, dataType, dataSubType)
|
17 |
+
|
18 |
+
# initialize VQA api for QA annotations
|
19 |
+
vqa=VQA(annFile, quesFile)
|
20 |
+
|
21 |
+
# load and display QA annotations for given question types
|
22 |
+
"""
|
23 |
+
All possible quesTypes for abstract and mscoco has been provided in respective text files in ../QuestionTypes/ folder.
|
24 |
+
"""
|
25 |
+
annIds = vqa.getQuesIds(quesTypes='how many');
|
26 |
+
anns = vqa.loadQA(annIds)
|
27 |
+
randomAnn = random.choice(anns)
|
28 |
+
vqa.showQA([randomAnn])
|
29 |
+
imgId = randomAnn['image_id']
|
30 |
+
imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
|
31 |
+
if os.path.isfile(imgDir + imgFilename):
|
32 |
+
I = io.imread(imgDir + imgFilename)
|
33 |
+
plt.imshow(I)
|
34 |
+
plt.axis('off')
|
35 |
+
plt.show()
|
36 |
+
|
37 |
+
# load and display QA annotations for given answer types
|
38 |
+
"""
|
39 |
+
ansTypes can be one of the following
|
40 |
+
yes/no
|
41 |
+
number
|
42 |
+
other
|
43 |
+
"""
|
44 |
+
annIds = vqa.getQuesIds(ansTypes='yes/no');
|
45 |
+
anns = vqa.loadQA(annIds)
|
46 |
+
randomAnn = random.choice(anns)
|
47 |
+
vqa.showQA([randomAnn])
|
48 |
+
imgId = randomAnn['image_id']
|
49 |
+
imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
|
50 |
+
if os.path.isfile(imgDir + imgFilename):
|
51 |
+
I = io.imread(imgDir + imgFilename)
|
52 |
+
plt.imshow(I)
|
53 |
+
plt.axis('off')
|
54 |
+
plt.show()
|
55 |
+
|
56 |
+
# load and display QA annotations for given images
|
57 |
+
"""
|
58 |
+
Usage: vqa.getImgIds(quesIds=[], quesTypes=[], ansTypes=[])
|
59 |
+
Above method can be used to retrieve imageIds for given question Ids or given question types or given answer types.
|
60 |
+
"""
|
61 |
+
ids = vqa.getImgIds()
|
62 |
+
annIds = vqa.getQuesIds(imgIds=random.sample(ids,5));
|
63 |
+
anns = vqa.loadQA(annIds)
|
64 |
+
randomAnn = random.choice(anns)
|
65 |
+
vqa.showQA([randomAnn])
|
66 |
+
imgId = randomAnn['image_id']
|
67 |
+
imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
|
68 |
+
if os.path.isfile(imgDir + imgFilename):
|
69 |
+
I = io.imread(imgDir + imgFilename)
|
70 |
+
plt.imshow(I)
|
71 |
+
plt.axis('off')
|
72 |
+
plt.show()
|
73 |
+
|
minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__author__ = 'aagrawal'
|
minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__author__ = 'aagrawal'
|
2 |
+
__version__ = '0.9'
|
3 |
+
|
4 |
+
# Interface for accessing the VQA dataset.
|
5 |
+
|
6 |
+
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
|
7 |
+
# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
|
8 |
+
|
9 |
+
# The following functions are defined:
|
10 |
+
# VQA - VQA class that loads VQA annotation file and prepares data structures.
|
11 |
+
# getQuesIds - Get question ids that satisfy given filter conditions.
|
12 |
+
# getImgIds - Get image ids that satisfy given filter conditions.
|
13 |
+
# loadQA - Load questions and answers with the specified question ids.
|
14 |
+
# showQA - Display the specified questions and answers.
|
15 |
+
# loadRes - Load result file and create result object.
|
16 |
+
|
17 |
+
# Help on each function can be accessed by: "help(COCO.function)"
|
18 |
+
|
19 |
+
import json
|
20 |
+
import datetime
|
21 |
+
import copy
|
22 |
+
|
23 |
+
|
24 |
+
class VQA:
|
25 |
+
def __init__(self, annotation_file=None, question_file=None):
|
26 |
+
"""
|
27 |
+
Constructor of VQA helper class for reading and visualizing questions and answers.
|
28 |
+
:param annotation_file (str): location of VQA annotation file
|
29 |
+
:return:
|
30 |
+
"""
|
31 |
+
# load dataset
|
32 |
+
self.dataset = {}
|
33 |
+
self.questions = {}
|
34 |
+
self.qa = {}
|
35 |
+
self.qqa = {}
|
36 |
+
self.imgToQA = {}
|
37 |
+
if not annotation_file == None and not question_file == None:
|
38 |
+
# print 'loading VQA annotations and questions into memory...'
|
39 |
+
time_t = datetime.datetime.utcnow()
|
40 |
+
dataset = json.load(open(annotation_file, 'r'))
|
41 |
+
questions = json.load(open(question_file, 'r'))
|
42 |
+
# print datetime.datetime.utcnow() - time_t
|
43 |
+
self.dataset = dataset
|
44 |
+
self.questions = questions
|
45 |
+
self.createIndex()
|
46 |
+
|
47 |
+
def createIndex(self):
|
48 |
+
imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']}
|
49 |
+
qa = {ann['question_id']: [] for ann in self.dataset['annotations']}
|
50 |
+
qqa = {ann['question_id']: [] for ann in self.dataset['annotations']}
|
51 |
+
for ann in self.dataset['annotations']:
|
52 |
+
imgToQA[ann['image_id']] += [ann]
|
53 |
+
qa[ann['question_id']] = ann
|
54 |
+
for ques in self.questions['questions']:
|
55 |
+
qqa[ques['question_id']] = ques
|
56 |
+
# print 'index created!'
|
57 |
+
|
58 |
+
# create class members
|
59 |
+
self.qa = qa
|
60 |
+
self.qqa = qqa
|
61 |
+
self.imgToQA = imgToQA
|
62 |
+
|
63 |
+
def info(self):
|
64 |
+
"""
|
65 |
+
Print information about the VQA annotation file.
|
66 |
+
:return:
|
67 |
+
"""
|
68 |
+
|
69 |
+
# for key, value in self.datset['info'].items():
|
70 |
+
# print '%s: %s'%(key, value)
|
71 |
+
|
72 |
+
def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
|
73 |
+
"""
|
74 |
+
Get question ids that satisfy given filter conditions. default skips that filter
|
75 |
+
:param imgIds (int array) : get question ids for given imgs
|
76 |
+
quesTypes (str array) : get question ids for given question types
|
77 |
+
ansTypes (str array) : get question ids for given answer types
|
78 |
+
:return: ids (int array) : integer array of question ids
|
79 |
+
"""
|
80 |
+
imgIds = imgIds if type(imgIds) == list else [imgIds]
|
81 |
+
quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
|
82 |
+
ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
|
83 |
+
|
84 |
+
if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
|
85 |
+
anns = self.dataset['annotations']
|
86 |
+
else:
|
87 |
+
if not len(imgIds) == 0:
|
88 |
+
anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA], [])
|
89 |
+
else:
|
90 |
+
anns = self.dataset['annotations']
|
91 |
+
anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
|
92 |
+
anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
|
93 |
+
ids = [ann['question_id'] for ann in anns]
|
94 |
+
return ids
|
95 |
+
|
96 |
+
def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
|
97 |
+
"""
|
98 |
+
Get image ids that satisfy given filter conditions. default skips that filter
|
99 |
+
:param quesIds (int array) : get image ids for given question ids
|
100 |
+
quesTypes (str array) : get image ids for given question types
|
101 |
+
ansTypes (str array) : get image ids for given answer types
|
102 |
+
:return: ids (int array) : integer array of image ids
|
103 |
+
"""
|
104 |
+
quesIds = quesIds if type(quesIds) == list else [quesIds]
|
105 |
+
quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
|
106 |
+
ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
|
107 |
+
|
108 |
+
if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
|
109 |
+
anns = self.dataset['annotations']
|
110 |
+
else:
|
111 |
+
if not len(quesIds) == 0:
|
112 |
+
anns = sum([self.qa[quesId] for quesId in quesIds if quesId in self.qa], [])
|
113 |
+
else:
|
114 |
+
anns = self.dataset['annotations']
|
115 |
+
anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
|
116 |
+
anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
|
117 |
+
ids = [ann['image_id'] for ann in anns]
|
118 |
+
return ids
|
119 |
+
|
120 |
+
def loadQA(self, ids=[]):
|
121 |
+
"""
|
122 |
+
Load questions and answers with the specified question ids.
|
123 |
+
:param ids (int array) : integer ids specifying question ids
|
124 |
+
:return: qa (object array) : loaded qa objects
|
125 |
+
"""
|
126 |
+
if type(ids) == list:
|
127 |
+
return [self.qa[id] for id in ids]
|
128 |
+
elif type(ids) == int:
|
129 |
+
return [self.qa[ids]]
|
130 |
+
|
131 |
+
def showQA(self, anns):
|
132 |
+
"""
|
133 |
+
Display the specified annotations.
|
134 |
+
:param anns (array of object): annotations to display
|
135 |
+
:return: None
|
136 |
+
"""
|
137 |
+
if len(anns) == 0:
|
138 |
+
return 0
|
139 |
+
for ann in anns:
|
140 |
+
quesId = ann['question_id']
|
141 |
+
print("Question: %s" % (self.qqa[quesId]['question']))
|
142 |
+
for ans in ann['answers']:
|
143 |
+
print("Answer %d: %s" % (ans['answer_id'], ans['answer']))
|
144 |
+
|
145 |
+
def loadRes(self, resFile, quesFile):
|
146 |
+
"""
|
147 |
+
Load result file and return a result object.
|
148 |
+
:param resFile (str) : file name of result file
|
149 |
+
:return: res (obj) : result api object
|
150 |
+
"""
|
151 |
+
res = VQA()
|
152 |
+
res.questions = json.load(open(quesFile))
|
153 |
+
res.dataset['info'] = copy.deepcopy(self.questions['info'])
|
154 |
+
res.dataset['task_type'] = copy.deepcopy(self.questions['task_type'])
|
155 |
+
res.dataset['data_type'] = copy.deepcopy(self.questions['data_type'])
|
156 |
+
res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype'])
|
157 |
+
res.dataset['license'] = copy.deepcopy(self.questions['license'])
|
158 |
+
|
159 |
+
# print 'Loading and preparing results... '
|
160 |
+
time_t = datetime.datetime.utcnow()
|
161 |
+
anns = json.load(open(resFile))
|
162 |
+
assert type(anns) == list, 'results is not an array of objects'
|
163 |
+
annsQuesIds = [ann['question_id'] for ann in anns]
|
164 |
+
assert set(annsQuesIds) == set(self.getQuesIds()), \
|
165 |
+
'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.'
|
166 |
+
for ann in anns:
|
167 |
+
quesId = ann['question_id']
|
168 |
+
if res.dataset['task_type'] == 'Multiple Choice':
|
169 |
+
assert ann['answer'] in self.qqa[quesId][
|
170 |
+
'multiple_choices'], 'predicted answer is not one of the multiple choices'
|
171 |
+
qaAnn = self.qa[quesId]
|
172 |
+
ann['image_id'] = qaAnn['image_id']
|
173 |
+
ann['question_type'] = qaAnn['question_type']
|
174 |
+
ann['answer_type'] = qaAnn['answer_type']
|
175 |
+
# print 'DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds())
|
176 |
+
|
177 |
+
res.dataset['annotations'] = anns
|
178 |
+
res.createIndex()
|
179 |
+
return res
|
minigpt4/common/vqa_tools/VQA/README.md
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Python API and Evaluation Code for v2.0 and v1.0 releases of the VQA dataset.
|
2 |
+
===================
|
3 |
+
## VQA v2.0 release ##
|
4 |
+
This release consists of
|
5 |
+
- Real
|
6 |
+
- 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from [MS COCO website] (http://mscoco.org/dataset/#download))
|
7 |
+
- 443,757 questions for training, 214,354 questions for validation and 447,793 questions for testing
|
8 |
+
- 4,437,570 answers for training and 2,143,540 answers for validation (10 per question)
|
9 |
+
|
10 |
+
There is only one type of task
|
11 |
+
- Open-ended task
|
12 |
+
|
13 |
+
## VQA v1.0 release ##
|
14 |
+
This release consists of
|
15 |
+
- Real
|
16 |
+
- 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from [MS COCO website] (http://mscoco.org/dataset/#download))
|
17 |
+
- 248,349 questions for training, 121,512 questions for validation and 244,302 questions for testing (3 per image)
|
18 |
+
- 2,483,490 answers for training and 1,215,120 answers for validation (10 per question)
|
19 |
+
- Abstract
|
20 |
+
- 20,000 training images, 10,000 validation images and 20,000 MS COCO testing images
|
21 |
+
- 60,000 questions for training, 30,000 questions for validation and 60,000 questions for testing (3 per image)
|
22 |
+
- 600,000 answers for training and 300,000 answers for validation (10 per question)
|
23 |
+
|
24 |
+
There are two types of tasks
|
25 |
+
- Open-ended task
|
26 |
+
- Multiple-choice task (18 choices per question)
|
27 |
+
|
28 |
+
## Requirements ##
|
29 |
+
- python 2.7
|
30 |
+
- scikit-image (visit [this page](http://scikit-image.org/docs/dev/install.html) for installation)
|
31 |
+
- matplotlib (visit [this page](http://matplotlib.org/users/installing.html) for installation)
|
32 |
+
|
33 |
+
## Files ##
|
34 |
+
./Questions
|
35 |
+
- For v2.0, download the question files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder.
|
36 |
+
- For v1.0, both real and abstract, question files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html).
|
37 |
+
- Question files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below
|
38 |
+
- [training question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Train_mscoco.zip)
|
39 |
+
- [validation question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Val_mscoco.zip)
|
40 |
+
- Question files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Questions_Train_mscoco.zip).
|
41 |
+
|
42 |
+
./Annotations
|
43 |
+
- For v2.0, download the annotations files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder.
|
44 |
+
- For v1.0, for both real and abstract, annotation files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html).
|
45 |
+
- Annotation files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below
|
46 |
+
- [training annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Train_mscoco.zip)
|
47 |
+
- [validation annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Val_mscoco.zip)
|
48 |
+
- Annotation files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Annotations_Train_mscoco.zip).
|
49 |
+
|
50 |
+
./Images
|
51 |
+
- For real, create a directory with name mscoco inside this directory. For each of train, val and test, create directories with names train2014, val2014 and test2015 respectively inside mscoco directory, download respective images from [MS COCO website](http://mscoco.org/dataset/#download) and place them in respective folders.
|
52 |
+
- For abstract, create a directory with name abstract_v002 inside this directory. For each of train, val and test, create directories with names train2015, val2015 and test2015 respectively inside abstract_v002 directory, download respective images from [VQA download page](http://www.visualqa.org/download.html) and place them in respective folders.
|
53 |
+
|
54 |
+
./PythonHelperTools
|
55 |
+
- This directory contains the Python API to read and visualize the VQA dataset
|
56 |
+
- vqaDemo.py (demo script)
|
57 |
+
- vqaTools (API to read and visualize data)
|
58 |
+
|
59 |
+
./PythonEvaluationTools
|
60 |
+
- This directory contains the Python evaluation code
|
61 |
+
- vqaEvalDemo.py (evaluation demo script)
|
62 |
+
- vqaEvaluation (evaluation code)
|
63 |
+
|
64 |
+
./Results
|
65 |
+
- OpenEnded_mscoco_train2014_fake_results.json (an example of a fake results file for v1.0 to run the demo)
|
66 |
+
- Visit [VQA evaluation page] (http://visualqa.org/evaluation) for more details.
|
67 |
+
|
68 |
+
./QuestionTypes
|
69 |
+
- This directory contains the following lists of question types for both real and abstract questions (question types are unchanged from v1.0 to v2.0). In a list, if there are question types of length n+k and length n with the same first n words, then the question type of length n does not include questions that belong to the question type of length n+k.
|
70 |
+
- mscoco_question_types.txt
|
71 |
+
- abstract_v002_question_types.txt
|
72 |
+
|
73 |
+
## References ##
|
74 |
+
- [VQA: Visual Question Answering](http://visualqa.org/)
|
75 |
+
- [Microsoft COCO](http://mscoco.org/)
|
76 |
+
|
77 |
+
## Developers ##
|
78 |
+
- Aishwarya Agrawal (Virginia Tech)
|
79 |
+
- Code for API is based on [MSCOCO API code](https://github.com/pdollar/coco).
|
80 |
+
- The format of the code for evaluation is based on [MSCOCO evaluation code](https://github.com/tylin/coco-caption).
|
minigpt4/common/vqa_tools/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
__author__ = "aagrawal"
|
minigpt4/common/vqa_tools/aokvqa/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright 2022 Allen Institute for Artificial Intelligence
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
minigpt4/common/vqa_tools/aokvqa/README.md
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# A-OKVQA
|
2 |
+
|
3 |
+
Official repository for **A-OKVQA: A Benchmark for Visual Question Answering using World Knowledge**.
|
4 |
+
|
5 |
+
Links: [[Paper]](https://arxiv.org/abs/2206.01718) [[Website]](http://a-okvqa.allenai.org) [[Leaderboard]](https://leaderboard.allenai.org/a-okvqa/submissions/public)
|
6 |
+
|
7 |
+
### Abstract
|
8 |
+
|
9 |
+
The Visual Question Answering (VQA) task aspires to provide a meaningful testbed for the development of AI models that can jointly reason over visual and natural language inputs. Despite a proliferation of VQA datasets, this goal is hindered by a set of common limitations. These include a reliance on relatively simplistic questions that are repetitive in both concepts and linguistic structure, little world knowledge needed outside of the paired image, and limited reasoning required to arrive at the correct answer. We introduce A-OKVQA, a crowdsourced dataset composed of a diverse set of about 25K questions requiring a broad base of commonsense and world knowledge to answer. In contrast to the existing knowledge-based VQA datasets, the questions generally cannot be answered by simply querying a knowledge base, and instead require some form of commonsense reasoning about the scene depicted in the image. We demonstrate the potential of this new dataset through a detailed analysis of its contents and baseline performance measurements over a variety of state-of-the-art vision–language models.
|
10 |
+
|
11 |
+
![dataset_web](https://user-images.githubusercontent.com/28768645/170799740-f0d9ea60-6aff-4322-98d5-cae8e05983f4.svg)
|
12 |
+
|
13 |
+
<hr>
|
14 |
+
|
15 |
+
#### Table of Contents
|
16 |
+
|
17 |
+
- [Getting started](#getting-started)
|
18 |
+
* [Downloading the dataset](#downloading-the-dataset)
|
19 |
+
- [Evaluation & Leaderboard](#evaluation)
|
20 |
+
- [Codebase](#codebase)
|
21 |
+
* [Preparing data](#preparing-data)
|
22 |
+
* [Models and Predictions](#models-and-predictions)
|
23 |
+
|
24 |
+
<hr>
|
25 |
+
|
26 |
+
## Getting started
|
27 |
+
|
28 |
+
```bash
|
29 |
+
git clone --single-branch --recurse-submodules https://github.com/allenai/aokvqa.git
|
30 |
+
|
31 |
+
cd aokvqa
|
32 |
+
export PYTHONPATH=.
|
33 |
+
|
34 |
+
conda env create --name aokvqa
|
35 |
+
conda activate aokvqa
|
36 |
+
```
|
37 |
+
|
38 |
+
### Downloading the dataset
|
39 |
+
|
40 |
+
```bash
|
41 |
+
export AOKVQA_DIR=./datasets/aokvqa/
|
42 |
+
mkdir -p ${AOKVQA_DIR}
|
43 |
+
|
44 |
+
curl -fsSL https://prior-datasets.s3.us-east-2.amazonaws.com/aokvqa/aokvqa_v1p0.tar.gz | tar xvz -C ${AOKVQA_DIR}
|
45 |
+
```
|
46 |
+
|
47 |
+
<details> <summary><b>Downloading COCO 2017</b></summary>
|
48 |
+
|
49 |
+
```bash
|
50 |
+
export COCO_DIR=./datasets/coco/
|
51 |
+
mkdir -p ${COCO_DIR}
|
52 |
+
|
53 |
+
for split in train val test; do
|
54 |
+
wget "http://images.cocodataset.org/zips/${split}2017.zip"
|
55 |
+
unzip "${split}2017.zip" -d ${COCO_DIR}; rm "${split}2017.zip"
|
56 |
+
done
|
57 |
+
|
58 |
+
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
|
59 |
+
unzip annotations_trainval2017.zip -d ${COCO_DIR}; rm annotations_trainval2017.zip
|
60 |
+
```
|
61 |
+
|
62 |
+
</details>
|
63 |
+
|
64 |
+
Loading our dataset is easy! Just grab our [load_aokvqa.py](https://github.com/allenai/aokvqa/blob/main/load_aokvqa.py) file and refer to the following code.
|
65 |
+
|
66 |
+
```python
|
67 |
+
import os
|
68 |
+
aokvqa_dir = os.getenv('AOKVQA_DIR')
|
69 |
+
|
70 |
+
from load_aokvqa import load_aokvqa, get_coco_path
|
71 |
+
train_dataset = load_aokvqa(aokvqa_dir, 'train') # also 'val' or 'test'
|
72 |
+
```
|
73 |
+
|
74 |
+
<details> <summary><b>Example dataset entry</b></summary>
|
75 |
+
|
76 |
+
```python
|
77 |
+
dataset_example = train_dataset[0]
|
78 |
+
|
79 |
+
print(dataset_example['question_id'])
|
80 |
+
# 22MexNkBPpdZGX6sxbxVBH
|
81 |
+
|
82 |
+
coco_dir = os.getenv('COCO_DIR')
|
83 |
+
image_path = get_coco_path('train', dataset_example['image_id'], coco_dir)
|
84 |
+
print(image_path)
|
85 |
+
# ./datasets/coco/train2017/000000299207.jpg
|
86 |
+
|
87 |
+
print(dataset_example['question'])
|
88 |
+
print(dataset_example['choices'])
|
89 |
+
# What is the man by the bags awaiting?
|
90 |
+
# ['skateboarder', 'train', 'delivery', 'cab']
|
91 |
+
|
92 |
+
correct_choice = dataset_example['choices'][ dataset_example['correct_choice_idx'] ]
|
93 |
+
# Corrrect: cab
|
94 |
+
|
95 |
+
print(dataset_example['rationales'][0])
|
96 |
+
# A train would not be on the street, he would not have luggage waiting for a delivery, and the skateboarder is there and not paying attention to him so a cab is the only possible answer.
|
97 |
+
```
|
98 |
+
|
99 |
+
</details>
|
100 |
+
|
101 |
+
## Evaluation
|
102 |
+
|
103 |
+
Please prepare `predictions_{split}.json` files (for `split: {val,test}`) in the format below. You may omit either `multiple_choice` or `direct_answer` field if you only want to evaluate one setting.
|
104 |
+
|
105 |
+
```python
|
106 |
+
{
|
107 |
+
'<question_id>' : {
|
108 |
+
'multiple_choice' : '<prediction>',
|
109 |
+
'direct_answer' : '<prediction>'
|
110 |
+
}
|
111 |
+
}
|
112 |
+
```
|
113 |
+
|
114 |
+
You can run evaluation on the validation set as follows.
|
115 |
+
|
116 |
+
```bash
|
117 |
+
python evaluation/eval_predictions.py --aokvqa-dir ${AOKVQA_DIR} --split val --preds ./predictions_val.json
|
118 |
+
```
|
119 |
+
|
120 |
+
### Leaderboard
|
121 |
+
|
122 |
+
You may submit `predictions_test.json` to the [leaderboard](https://leaderboard.allenai.org/a-okvqa/submissions/get-started).
|
123 |
+
|
124 |
+
## Codebase
|
125 |
+
|
126 |
+
We provide all code and pretrained models necessary to replicate our experiments for Large-Scale Pretrained Models (sec. 5.2) and Rationale Generation (sec. 5.3).
|
127 |
+
|
128 |
+
### Preparing data
|
129 |
+
|
130 |
+
```bash
|
131 |
+
export FEATURES_DIR=./features/
|
132 |
+
mkdir -p ${FEATURES_DIR}
|
133 |
+
```
|
134 |
+
|
135 |
+
You can compute CLIP features for our vocabulary and dataset. These are most commonly used by our other experiments.
|
136 |
+
|
137 |
+
```bash
|
138 |
+
python data_scripts/encode_vocab_clip.py --vocab ${AOKVQA_DIR}/large_vocab_train.csv --model-type ViT-B/32 --out ${FEATURES_DIR}/clip-ViT-B-32_large_vocab.pt
|
139 |
+
|
140 |
+
for split in train val test; do
|
141 |
+
python data_scripts/extract_clip_features.py --aokvqa-dir ${AOKVQA_DIR} --coco-dir ${COCO_DIR} --split ${split} --model-type ViT-B/32 --out ${FEATURES_DIR}/clip-ViT-B-32_${split}.pt
|
142 |
+
done
|
143 |
+
```
|
144 |
+
|
145 |
+
<details> <summary><b>For training ClipCap with a transformer mapping network</b></summary>
|
146 |
+
|
147 |
+
If you want to train our ClipCap models with the transformer mapping network (instead of an MLP, like we do), you'll also need to run `extract_clip_features.py` with `--model-type RN50x4`.
|
148 |
+
|
149 |
+
</details>
|
150 |
+
|
151 |
+
<details> <summary><b>For ResNet and BERT input features</b></summary>
|
152 |
+
|
153 |
+
Our ResNet and BERT classification experiments require these respective features instead of CLIP. To generate these, please run the following commands:
|
154 |
+
|
155 |
+
```bash
|
156 |
+
# ResNet
|
157 |
+
for split in train val test; do
|
158 |
+
python data_scripts/extract_resnet_features.py --aokvqa-dir ${AOKVQA_DIR} --coco-dir ${COCO_DIR} --split ${split} --out ${FEATURES_DIR}/resnet_${split}.pt
|
159 |
+
done
|
160 |
+
|
161 |
+
# BERT
|
162 |
+
for split in train val test; do
|
163 |
+
python data_scripts/extract_bert_features.py --aokvqa-dir ${AOKVQA_DIR} --split ${split} --out ${FEATURES_DIR}/bert_${split}.pt
|
164 |
+
done
|
165 |
+
```
|
166 |
+
|
167 |
+
</details>
|
168 |
+
|
169 |
+
### Models and Predictions
|
170 |
+
|
171 |
+
```bash
|
172 |
+
export LOG_DIR=./logs/
|
173 |
+
export PREDS_DIR=./predictions/
|
174 |
+
export PT_MODEL_DIR=./pretrained_models/
|
175 |
+
mkdir -p ${LOG_DIR} ${PREDS_DIR} ${PT_MODEL_DIR}
|
176 |
+
```
|
177 |
+
|
178 |
+
<details> <summary><b>Download our pretrained model weights</b></summary>
|
179 |
+
|
180 |
+
```bash
|
181 |
+
# Checkpoints for transfer learning experiments
|
182 |
+
curl -fsSL https://prior-model-weights.s3.us-east-2.amazonaws.com/aokvqa/transfer_exp_checkpoints.tar.gz | tar xvz -C ${PT_MODEL_DIR}/aokvqa_models
|
183 |
+
|
184 |
+
# Checkpoints for ClipCap models (generating answers and rationales)
|
185 |
+
curl -fsSL https://prior-model-weights.s3.us-east-2.amazonaws.com/aokvqa/clipcap_checkpoints.tar.gz | tar xvz -C ${PT_MODEL_DIR}/aokvqa_models
|
186 |
+
```
|
187 |
+
|
188 |
+
</details>
|
189 |
+
|
190 |
+
We have included instructions for replicating each of our experiments (see README.md files below).
|
191 |
+
|
192 |
+
All Python scripts should be run from the root of this repository. Please be sure to first run the installation and data preparation as directed above.
|
193 |
+
|
194 |
+
- [Heuristics](./heuristics/README.md)
|
195 |
+
- [Transfer Learning Experiments](./transfer_experiments/README.md)
|
196 |
+
- [Querying GPT-3](./gpt3/README.md)
|
197 |
+
- [ClipCap](https://github.com/allenai/aokvqa/blob/ClipCap/README.md)
|
198 |
+
- [Generating Captions & Rationales](https://github.com/allenai/aokvqa/blob/ClipCap/README.md)
|
199 |
+
|
200 |
+
For each experiment, we follow this prediction file naming scheme: `{model-name}_{split}-{setting}.json` (e.g. `random-weighted_val-mc.json` or `random-weighted_test-da.json`). As examples in these Readme files, we produce predictions on the validation set.
|
201 |
+
|
202 |
+
We unify predictions for each split before evaluation. (You can omit one of `--mc` or `--da` prediction file if you only want to evaluate one setting.)
|
203 |
+
|
204 |
+
```bash
|
205 |
+
python evaluation/prepare_predictions.py --aokvqa-dir ${AOKVQA_DIR} --split val --mc ./predictions_val-mc.json --da ./predictions_val-da.json --out ./predictions_val.json
|
206 |
+
# repeat for test split ...
|
207 |
+
```
|
minigpt4/common/vqa_tools/aokvqa/data_scripts/build_vocab.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
from collections import Counter
|
4 |
+
import pathlib
|
5 |
+
|
6 |
+
from load_aokvqa import load_aokvqa
|
7 |
+
|
8 |
+
|
9 |
+
parser = argparse.ArgumentParser()
|
10 |
+
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
|
11 |
+
parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file')
|
12 |
+
args = parser.parse_args()
|
13 |
+
|
14 |
+
|
15 |
+
# Build vocab from train set: correct choices + (direct answers appearing in >= 3 )
|
16 |
+
|
17 |
+
train_set = load_aokvqa(args.aokvqa_dir, 'train')
|
18 |
+
|
19 |
+
vocab = []
|
20 |
+
all_choices = Counter()
|
21 |
+
direct_answers = Counter()
|
22 |
+
|
23 |
+
for i in train_set:
|
24 |
+
vocab.append( i['choices'][i['correct_choice_idx']] )
|
25 |
+
all_choices.update(i['choices'])
|
26 |
+
direct_answers.update(set(i['direct_answers']))
|
27 |
+
vocab += [k for k,v in all_choices.items() if v >= 3]
|
28 |
+
vocab += [k for k,v in direct_answers.items() if v >= 3]
|
29 |
+
|
30 |
+
vocab = sorted(set(vocab))
|
31 |
+
print(f"Vocab size: {len(vocab)}")
|
32 |
+
|
33 |
+
# Save vocabulary Output
|
34 |
+
|
35 |
+
with open(args.output_file, 'w') as f:
|
36 |
+
for v in vocab:
|
37 |
+
print(v, file=f)
|
38 |
+
|
39 |
+
## Check validation set coverage
|
40 |
+
|
41 |
+
val_set = load_aokvqa(args.aokvqa_dir, 'val')
|
42 |
+
|
43 |
+
val_acc = [v['choices'][v['correct_choice_idx']] in vocab for v in val_set]
|
44 |
+
val_acc = sum(val_acc) / len(val_acc) * 100
|
45 |
+
print(f"Val set coverage: {val_acc:.2f}" )
|
minigpt4/common/vqa_tools/aokvqa/data_scripts/encode_vocab_clip.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from tqdm import tqdm
|
3 |
+
import argparse
|
4 |
+
import pathlib
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import clip
|
8 |
+
|
9 |
+
parser = argparse.ArgumentParser()
|
10 |
+
parser.add_argument('--vocab', type=pathlib.Path, required=True, dest='vocab_file')
|
11 |
+
parser.add_argument('--model-type', type=str, choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'], required=True, dest='model_type')
|
12 |
+
parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file')
|
13 |
+
args = parser.parse_args()
|
14 |
+
|
15 |
+
assert args.output_file.suffix == '.pt'
|
16 |
+
|
17 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
18 |
+
model, preprocess = clip.load(args.model_type, device=device)
|
19 |
+
|
20 |
+
with torch.no_grad():
|
21 |
+
a = open(args.vocab_file).read().splitlines()
|
22 |
+
mc_text = clip.tokenize(a).to(device)
|
23 |
+
mc_text_features = torch.stack([model.encode_text(mct.unsqueeze(0)).cpu() for mct in tqdm(mc_text)], dim=1)[0]
|
24 |
+
mc_text_features = mc_text_features.float()
|
25 |
+
model_name = args.model_type.replace('/', '-').replace('@', '-')
|
26 |
+
torch.save(mc_text_features, args.output_file)
|
minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_bert_features.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import pathlib
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from transformers import AutoTokenizer, AutoModel
|
8 |
+
|
9 |
+
from load_aokvqa import load_aokvqa
|
10 |
+
|
11 |
+
|
12 |
+
parser = argparse.ArgumentParser()
|
13 |
+
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
|
14 |
+
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
|
15 |
+
parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file')
|
16 |
+
args = parser.parse_args()
|
17 |
+
|
18 |
+
assert args.output_file.suffix == '.pt'
|
19 |
+
|
20 |
+
## Load dataset
|
21 |
+
|
22 |
+
dataset = load_aokvqa(args.aokvqa_dir, args.split)
|
23 |
+
|
24 |
+
## Load model
|
25 |
+
|
26 |
+
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens')
|
27 |
+
model = AutoModel.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens')
|
28 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
29 |
+
model = model.to(device)
|
30 |
+
model.eval()
|
31 |
+
|
32 |
+
def mean_pooling(model_output, attention_mask):
|
33 |
+
token_embeddings = model_output[0] # First element of model_output contains all token embeddings
|
34 |
+
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
35 |
+
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
36 |
+
|
37 |
+
## Encoding loop
|
38 |
+
|
39 |
+
with torch.no_grad():
|
40 |
+
embeddings = {}
|
41 |
+
|
42 |
+
for d in tqdm(dataset):
|
43 |
+
encoded_input = tokenizer([d['question']], padding=True, return_tensors='pt')
|
44 |
+
encoded_input = {k:v.to(device) for k,v in encoded_input.items()}
|
45 |
+
e = mean_pooling(model(**encoded_input), encoded_input['attention_mask'])
|
46 |
+
embeddings[d['question_id']] = {
|
47 |
+
'question' : e[0].cpu()
|
48 |
+
}
|
49 |
+
|
50 |
+
torch.save(embeddings, args.output_file)
|
minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_clip_features.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from PIL import Image
|
3 |
+
from tqdm import tqdm
|
4 |
+
import argparse
|
5 |
+
import pathlib
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import clip
|
9 |
+
|
10 |
+
from load_aokvqa import load_aokvqa, get_coco_path
|
11 |
+
|
12 |
+
|
13 |
+
parser = argparse.ArgumentParser()
|
14 |
+
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
|
15 |
+
parser.add_argument('--coco-dir', type=pathlib.Path, required=True, dest='coco_dir')
|
16 |
+
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
|
17 |
+
parser.add_argument('--model-type', type=str, choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'], required=True, dest='model_type')
|
18 |
+
parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file')
|
19 |
+
args = parser.parse_args()
|
20 |
+
|
21 |
+
assert args.output_file.suffix == '.pt'
|
22 |
+
|
23 |
+
## Load dataset
|
24 |
+
|
25 |
+
dataset = load_aokvqa(args.aokvqa_dir, args.split)
|
26 |
+
|
27 |
+
## Load model
|
28 |
+
|
29 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
30 |
+
model, preprocess = clip.load(args.model_type, device=device)
|
31 |
+
|
32 |
+
## Encoding loop
|
33 |
+
|
34 |
+
with torch.no_grad():
|
35 |
+
embeddings = {}
|
36 |
+
|
37 |
+
for d in tqdm(dataset):
|
38 |
+
q = d["question"]
|
39 |
+
q_text = clip.tokenize(q).to(device)
|
40 |
+
q_text_features = model.encode_text(q_text)
|
41 |
+
|
42 |
+
img = Image.open(get_coco_path(args.split, d['image_id'], args.coco_dir))
|
43 |
+
img = preprocess(img).unsqueeze(0).to(device)
|
44 |
+
image_features = model.encode_image(img)
|
45 |
+
|
46 |
+
embeddings[d['question_id']] = {
|
47 |
+
'question' : q_text_features[0].float().cpu(),
|
48 |
+
'image' : image_features[0].float().cpu(),
|
49 |
+
}
|
50 |
+
|
51 |
+
torch.save(embeddings, args.output_file)
|
minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_resnet_features.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import pathlib
|
4 |
+
from tqdm import tqdm
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from torchvision import models
|
10 |
+
from torchvision import transforms as T
|
11 |
+
|
12 |
+
from load_aokvqa import load_aokvqa, get_coco_path
|
13 |
+
|
14 |
+
|
15 |
+
parser = argparse.ArgumentParser()
|
16 |
+
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
|
17 |
+
parser.add_argument('--coco-dir', type=pathlib.Path, required=True, dest='coco_dir')
|
18 |
+
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
|
19 |
+
parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file')
|
20 |
+
args = parser.parse_args()
|
21 |
+
|
22 |
+
assert args.output_file.suffix == '.pt'
|
23 |
+
|
24 |
+
## Load dataset
|
25 |
+
|
26 |
+
dataset = load_aokvqa(args.aokvqa_dir, args.split)
|
27 |
+
|
28 |
+
## Load model
|
29 |
+
|
30 |
+
resnet_preprocess = T.Compose([
|
31 |
+
T.Resize(size=224, interpolation=T.InterpolationMode.BICUBIC),
|
32 |
+
T.CenterCrop(size=(224, 224)),
|
33 |
+
T.ToTensor(),
|
34 |
+
T.Normalize(
|
35 |
+
mean=[0.485, 0.456, 0.406],
|
36 |
+
std=[0.229, 0.224, 0.225]
|
37 |
+
)
|
38 |
+
])
|
39 |
+
|
40 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
41 |
+
|
42 |
+
resnet_model = models.resnet50(pretrained=True)
|
43 |
+
resnet_model = torch.nn.Sequential(
|
44 |
+
*list(resnet_model.children())[:-1],
|
45 |
+
nn.Flatten()
|
46 |
+
) # strip classification layer
|
47 |
+
resnet_model = resnet_model.to(device)
|
48 |
+
|
49 |
+
## Encoding loop
|
50 |
+
|
51 |
+
with torch.no_grad():
|
52 |
+
embeddings = {}
|
53 |
+
|
54 |
+
for d in tqdm(dataset):
|
55 |
+
img = Image.open(get_coco_path(args.split, d['image_id'], args.coco_dir)).convert('RGB')
|
56 |
+
resnet_input = resnet_preprocess(img).unsqueeze(0).to(device)
|
57 |
+
resnet_features = resnet_model(resnet_input)
|
58 |
+
embeddings[d['question_id']] = {
|
59 |
+
'image' : resnet_features[0].cpu()
|
60 |
+
}
|
61 |
+
|
62 |
+
torch.save(embeddings, args.output_file)
|
minigpt4/common/vqa_tools/aokvqa/environment.yml
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: aokvqa
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- nvidia
|
5 |
+
- huggingface
|
6 |
+
- conda-forge
|
7 |
+
- defaults
|
8 |
+
dependencies:
|
9 |
+
- python=3.7
|
10 |
+
- cudatoolkit=11.3
|
11 |
+
- numpy=1.21.6
|
12 |
+
- pytorch=1.11.0
|
13 |
+
- torchvision=0.12.0
|
14 |
+
- pytorch-lightning=1.6.3
|
15 |
+
- torchmetrics=0.8.1
|
16 |
+
- gdown=4.4.0
|
17 |
+
- pip=22.0.4
|
18 |
+
- pip:
|
19 |
+
- argparse==1.4.0
|
20 |
+
- Pillow==9.0.1
|
21 |
+
- tensorboard==2.9.0
|
22 |
+
- ftfy==6.1.1
|
23 |
+
- regex==2022.3.15
|
24 |
+
- tqdm==4.64.0
|
25 |
+
- clip @ git+https://github.com/openai/CLIP.git@b46f5ac7587d2e1862f8b7b1573179d80dcdd620
|
26 |
+
- openai==0.18.1
|
27 |
+
- nltk==3.7
|
28 |
+
- sacrebleu==2.0.0
|
29 |
+
- sacremoses==0.0.53
|
30 |
+
- sentence-transformers==2.2.0
|
31 |
+
- datasets==2.1.0
|
32 |
+
- tokenizers==0.10.3
|
33 |
+
- transformers==4.10.3
|
34 |
+
|
35 |
+
# Next: resolve conflict between sentence-transfomers and pytorch-lightning
|
36 |
+
# pip uninstall sentencepiece
|
minigpt4/common/vqa_tools/aokvqa/evaluation/eval_predictions.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import pathlib
|
3 |
+
import json
|
4 |
+
import glob
|
5 |
+
|
6 |
+
from load_aokvqa import load_aokvqa
|
7 |
+
|
8 |
+
|
9 |
+
def eval_aokvqa(dataset, preds, multiple_choice=False, strict=True):
|
10 |
+
|
11 |
+
if isinstance(dataset, list):
|
12 |
+
dataset = { dataset[i]['question_id'] : dataset[i] for i in range(len(dataset)) }
|
13 |
+
|
14 |
+
if multiple_choice is False:
|
15 |
+
dataset = {k:v for k,v in dataset.items() if v['difficult_direct_answer'] is False}
|
16 |
+
|
17 |
+
if strict:
|
18 |
+
dataset_qids = set(dataset.keys())
|
19 |
+
preds_qids = set(preds.keys())
|
20 |
+
assert dataset_qids.issubset(preds_qids)
|
21 |
+
|
22 |
+
# dataset = q_id (str) : dataset element (dict)
|
23 |
+
# preds = q_id (str) : prediction (str)
|
24 |
+
|
25 |
+
acc = []
|
26 |
+
|
27 |
+
for q in dataset.keys():
|
28 |
+
if q not in preds.keys():
|
29 |
+
acc.append(0.0)
|
30 |
+
continue
|
31 |
+
|
32 |
+
pred = preds[q]
|
33 |
+
choices = dataset[q]['choices']
|
34 |
+
direct_answers = dataset[q]['direct_answers']
|
35 |
+
|
36 |
+
## Multiple Choice setting
|
37 |
+
if multiple_choice:
|
38 |
+
if strict:
|
39 |
+
assert pred in choices, 'Prediction must be a valid choice'
|
40 |
+
correct_choice_idx = dataset[q]['correct_choice_idx']
|
41 |
+
acc.append( float(pred == choices[correct_choice_idx]) )
|
42 |
+
## Direct Answer setting
|
43 |
+
else:
|
44 |
+
num_match = sum([pred.lower() == da.lower() for da in direct_answers])
|
45 |
+
vqa_acc = min(1.0, num_match / 3.0)
|
46 |
+
acc.append(vqa_acc)
|
47 |
+
|
48 |
+
acc = sum(acc) / len(acc) * 100
|
49 |
+
|
50 |
+
return acc
|
51 |
+
|
52 |
+
|
53 |
+
if __name__ == '__main__':
|
54 |
+
parser = argparse.ArgumentParser()
|
55 |
+
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
|
56 |
+
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
|
57 |
+
parser.add_argument('--preds', type=str, required=True, dest='prediction_files')
|
58 |
+
args = parser.parse_args()
|
59 |
+
|
60 |
+
dataset = load_aokvqa(args.aokvqa_dir, args.split)
|
61 |
+
|
62 |
+
for prediction_file in glob.glob(args.prediction_files):
|
63 |
+
predictions = json.load(open(prediction_file, 'r'))
|
64 |
+
|
65 |
+
# Multiple choice
|
66 |
+
|
67 |
+
mc_predictions = {}
|
68 |
+
|
69 |
+
for q in predictions.keys():
|
70 |
+
if 'multiple_choice' in predictions[q].keys():
|
71 |
+
mc_predictions[q] = predictions[q]['multiple_choice']
|
72 |
+
|
73 |
+
if mc_predictions != {}:
|
74 |
+
mc_acc = eval_aokvqa(
|
75 |
+
dataset,
|
76 |
+
mc_predictions,
|
77 |
+
multiple_choice=True,
|
78 |
+
strict=False
|
79 |
+
)
|
80 |
+
print(prediction_file, 'MC', mc_acc)
|
81 |
+
|
82 |
+
# Direct Answer
|
83 |
+
|
84 |
+
da_predictions = {}
|
85 |
+
|
86 |
+
for q in predictions.keys():
|
87 |
+
if 'direct_answer' in predictions[q].keys():
|
88 |
+
da_predictions[q] = predictions[q]['direct_answer']
|
89 |
+
|
90 |
+
if da_predictions != {}:
|
91 |
+
da_acc = eval_aokvqa(
|
92 |
+
dataset,
|
93 |
+
da_predictions,
|
94 |
+
multiple_choice=False,
|
95 |
+
strict=False
|
96 |
+
)
|
97 |
+
print(prediction_file, 'DA', da_acc)
|
minigpt4/common/vqa_tools/aokvqa/evaluation/load_aokvqa.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
|
4 |
+
|
5 |
+
def load_aokvqa(aokvqa_dir, split, version='v1p0'):
|
6 |
+
assert split in ['train', 'val', 'test', 'test_w_ans']
|
7 |
+
dataset = json.load(open(
|
8 |
+
os.path.join(aokvqa_dir, f"aokvqa_{version}_{split}.json")
|
9 |
+
))
|
10 |
+
return dataset
|
11 |
+
|
12 |
+
def get_coco_path(split, image_id, coco_dir):
|
13 |
+
return os.path.join(coco_dir, f"{split}2017", f"{image_id:012}.jpg")
|
minigpt4/common/vqa_tools/aokvqa/evaluation/prepare_predictions.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import pathlib
|
3 |
+
import json
|
4 |
+
|
5 |
+
from load_aokvqa import load_aokvqa
|
6 |
+
|
7 |
+
|
8 |
+
if __name__ == '__main__':
|
9 |
+
parser = argparse.ArgumentParser()
|
10 |
+
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
|
11 |
+
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
|
12 |
+
parser.add_argument('--mc', type=argparse.FileType('r'), dest='mc_pred_file')
|
13 |
+
parser.add_argument('--da', type=argparse.FileType('r'), dest='da_pred_file')
|
14 |
+
parser.add_argument('--out', type=argparse.FileType('w'), dest='output_file')
|
15 |
+
args = parser.parse_args()
|
16 |
+
assert args.mc_pred_file or args.da_pred_file
|
17 |
+
|
18 |
+
dataset = load_aokvqa(args.aokvqa_dir, args.split)
|
19 |
+
mc_preds = json.load(args.mc_pred_file) if args.mc_pred_file else None
|
20 |
+
da_preds = json.load(args.da_pred_file) if args.da_pred_file else None
|
21 |
+
predictions = {}
|
22 |
+
|
23 |
+
for d in dataset:
|
24 |
+
q = d['question_id']
|
25 |
+
predictions[q] = {}
|
26 |
+
if mc_preds and q in mc_preds.keys():
|
27 |
+
predictions[q]['multiple_choice'] = mc_preds[q]
|
28 |
+
if da_preds and q in da_preds.keys():
|
29 |
+
predictions[q]['direct_answer'] = da_preds[q]
|
30 |
+
|
31 |
+
json.dump(predictions, args.output_file)
|
minigpt4/common/vqa_tools/aokvqa/evaluation/remap_predictions.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import pathlib
|
3 |
+
import json
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
from sentence_transformers import SentenceTransformer
|
7 |
+
from sentence_transformers.util import cos_sim
|
8 |
+
|
9 |
+
from load_aokvqa import load_aokvqa
|
10 |
+
|
11 |
+
|
12 |
+
def map_to_choices(dataset, predictions, device='cpu'):
|
13 |
+
if isinstance(dataset, list):
|
14 |
+
dataset = { dataset[i]['question_id'] : dataset[i] for i in range(len(dataset)) }
|
15 |
+
|
16 |
+
if all([p in dataset[q]['choices'] for q, p in predictions.items()]):
|
17 |
+
return predictions
|
18 |
+
|
19 |
+
model = SentenceTransformer('sentence-transformers/average_word_embeddings_glove.6B.300d')
|
20 |
+
model.to(device)
|
21 |
+
for q in tqdm(predictions.keys()):
|
22 |
+
choices = dataset[q]['choices']
|
23 |
+
if predictions[q] not in choices:
|
24 |
+
choice_embeddings = model.encode([predictions[q]] + choices, convert_to_tensor=True)
|
25 |
+
a_idx = cos_sim(choice_embeddings[0], choice_embeddings[1:]).argmax().item()
|
26 |
+
predictions[q] = choices[a_idx]
|
27 |
+
|
28 |
+
return predictions
|
29 |
+
|
30 |
+
|
31 |
+
if __name__ == '__main__':
|
32 |
+
parser = argparse.ArgumentParser()
|
33 |
+
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
|
34 |
+
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
|
35 |
+
parser.add_argument('--pred', type=argparse.FileType('r'), required=True, dest='prediction_file')
|
36 |
+
parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
|
37 |
+
args = parser.parse_args()
|
38 |
+
|
39 |
+
|
40 |
+
dataset = load_aokvqa(args.aokvqa_dir, args.split)
|
41 |
+
predictions = json.load(args.prediction_file)
|
42 |
+
predictions = map_to_choices(dataset, predictions)
|
43 |
+
|
44 |
+
json.dump(predictions, args.output_file)
|
minigpt4/common/vqa_tools/aokvqa/gpt3/README.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Querying GPT-3
|
2 |
+
|
3 |
+
To follow our experiments which use GPT-3, you must have access to the [OpenAI API](https://openai.com/api/) (at cost). Please retrieve your [organization](https://beta.openai.com/account/org-settings) and [API](https://beta.openai.com/account/api-keys) keys and set them in your environment variables.
|
4 |
+
|
5 |
+
```bash
|
6 |
+
export OPENAI_ORG=....
|
7 |
+
export OPENAI_API_KEY=...
|
8 |
+
```
|
9 |
+
|
10 |
+
For producing predictions for both DA and MC settings, run:
|
11 |
+
```bash
|
12 |
+
python gpt3/query_gpt3.py --aokvqa-dir ${AOKVQA_DIR} --split val --out ${PREDS_DIR}/gpt3_val-da.json
|
13 |
+
python remap_predictions.py --aokvqa-dir ${AOKVQA_DIR} --split val --pred ${PREDS_DIR}/gpt3_val-da.json --out ${PREDS_DIR}/gpt3_val-mc.json
|
14 |
+
```
|
minigpt4/common/vqa_tools/aokvqa/gpt3/caption_inputs.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import argparse
|
4 |
+
import pathlib
|
5 |
+
|
6 |
+
from load_aokvqa import load_aokvqa
|
7 |
+
|
8 |
+
|
9 |
+
parser = argparse.ArgumentParser()
|
10 |
+
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
|
11 |
+
parser.add_argument('--coco-dir', type=pathlib.Path, required=True, dest='coco_dir')
|
12 |
+
parser.add_argument('--split', type=str, choices=['train', 'val'], required=True)
|
13 |
+
parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
|
14 |
+
args = parser.parse_args()
|
15 |
+
|
16 |
+
aokvqa_set = load_aokvqa(args.aokvqa_dir, args.split)
|
17 |
+
|
18 |
+
coco_captions = json.load(open(os.path.join(args.coco_dir, 'annotations', f'captions_{args.split}2017.json')))['annotations']
|
19 |
+
coco_captions = {c['image_id'] : c['caption'] for c in coco_captions}
|
20 |
+
|
21 |
+
captions = { d['question_id'] : coco_captions[d['image_id']] for d in aokvqa_set }
|
22 |
+
|
23 |
+
json.dump(captions, args.output_file)
|
minigpt4/common/vqa_tools/aokvqa/gpt3/query_gpt3.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import json
|
4 |
+
from tqdm import tqdm
|
5 |
+
import argparse
|
6 |
+
import pathlib
|
7 |
+
|
8 |
+
import openai
|
9 |
+
openai.organization = os.getenv('OPENAI_ORG')
|
10 |
+
openai.api_key = os.getenv('OPENAI_API_KEY')
|
11 |
+
|
12 |
+
from load_aokvqa import load_aokvqa
|
13 |
+
|
14 |
+
|
15 |
+
random.seed(0)
|
16 |
+
|
17 |
+
|
18 |
+
def main():
|
19 |
+
parser = argparse.ArgumentParser()
|
20 |
+
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
|
21 |
+
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
|
22 |
+
parser.add_argument('--n', type=int, default=10, dest='num_examples')
|
23 |
+
parser.add_argument('--train-context', type=argparse.FileType('r'), dest='train_context_file')
|
24 |
+
parser.add_argument('--prefix', type=str, default='', dest='prompt_prefix')
|
25 |
+
parser.add_argument('--include-choices', action='store_true', dest='include_choices')
|
26 |
+
parser.add_argument('--context', type=argparse.FileType('r'), dest='context_file')
|
27 |
+
parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
|
28 |
+
args = parser.parse_args()
|
29 |
+
|
30 |
+
|
31 |
+
train_set = load_aokvqa(args.aokvqa_dir, 'train')
|
32 |
+
eval_set = load_aokvqa(args.aokvqa_dir, args.split)
|
33 |
+
|
34 |
+
train_context = {}
|
35 |
+
context = {}
|
36 |
+
if args.context_file is not None:
|
37 |
+
train_context = json.load(args.train_context_file)
|
38 |
+
context = json.load(args.context_file)
|
39 |
+
|
40 |
+
predictions = {}
|
41 |
+
|
42 |
+
for d in tqdm(eval_set):
|
43 |
+
q = d['question_id']
|
44 |
+
|
45 |
+
prompt = args.prompt_prefix
|
46 |
+
for e in random.sample(train_set, args.num_examples):
|
47 |
+
prompt += prompt_element(e,
|
48 |
+
context=train_context.get(q, None),
|
49 |
+
include_choices=args.include_choices,
|
50 |
+
answer=True
|
51 |
+
)
|
52 |
+
prompt += '\n\n'
|
53 |
+
|
54 |
+
prompt += prompt_element(d,
|
55 |
+
context=context.get(q, None),
|
56 |
+
include_choices=args.include_choices,
|
57 |
+
answer=False
|
58 |
+
)
|
59 |
+
|
60 |
+
response = openai.Completion.create(
|
61 |
+
engine="text-curie-001",
|
62 |
+
prompt=prompt,
|
63 |
+
temperature=0.0,
|
64 |
+
max_tokens=10,
|
65 |
+
)
|
66 |
+
|
67 |
+
predictions[q] = response.choices[0].text.strip()
|
68 |
+
|
69 |
+
json.dump(predictions, args.output_file)
|
70 |
+
|
71 |
+
|
72 |
+
def prompt_element(d, context=None, include_choices=False, answer=False):
|
73 |
+
return (f"Context: {context}\n" if context is not None else '') + \
|
74 |
+
f"Q: {d['question']}\n" + \
|
75 |
+
(f"Choices: {', '.join(d['choices'])}.\n" if include_choices else '') + \
|
76 |
+
f"A:" + (f" {d['choices'][d['correct_choice_idx']]}" if answer else '')
|
77 |
+
|
78 |
+
if __name__ == '__main__':
|
79 |
+
main()
|
minigpt4/common/vqa_tools/aokvqa/gpt3/rationale_inputs.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import argparse
|
3 |
+
import pathlib
|
4 |
+
|
5 |
+
from load_aokvqa import load_aokvqa
|
6 |
+
|
7 |
+
|
8 |
+
parser = argparse.ArgumentParser()
|
9 |
+
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
|
10 |
+
parser.add_argument('--split', type=str, choices=['train', 'val', 'test_w_ans'], required=True)
|
11 |
+
parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
|
12 |
+
args = parser.parse_args()
|
13 |
+
|
14 |
+
aokvqa_set = load_aokvqa(args.aokvqa_dir, args.split)
|
15 |
+
rationales = {d['question_id'] : d['rationales'][0] for d in aokvqa_set}
|
16 |
+
json.dump(rationales, args.output_file)
|