diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..15c3d9b632a5f47259569a0d31489e86e65c91fd 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +repo_imgs/sample_1.gif filter=lfs diff=lfs merge=lfs -text +repo_imgs/sample_2.gif filter=lfs diff=lfs merge=lfs -text +repo_imgs/sample_3.gif filter=lfs diff=lfs merge=lfs -text diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..9d5f2f9b2b609b8e8628051fea6a408de9d959dc --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,14 @@ +BSD 3-Clause License + +Copyright 2023 Deyao Zhu +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/LICENSE_Lavis.md b/LICENSE_Lavis.md new file mode 100644 index 0000000000000000000000000000000000000000..9ba97919e5b9568c8b9c42ea85251f01049a220e --- /dev/null +++ b/LICENSE_Lavis.md @@ -0,0 +1,14 @@ +BSD 3-Clause License + +Copyright (c) 2022 Salesforce, Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/datasets/training_datasets/video_text_data/video_instruct_100/download_script.py b/datasets/training_datasets/video_text_data/video_instruct_100/download_script.py new file mode 100644 index 0000000000000000000000000000000000000000..d6da7d2fc72bb16ff038b9158e69e0e74e193486 --- /dev/null +++ b/datasets/training_datasets/video_text_data/video_instruct_100/download_script.py @@ -0,0 +1,94 @@ +import json +from tqdm import tqdm +from pytubefix import YouTube + +import xml.etree.ElementTree as ET +import os + +with open ('VideoInstruct100K.json','r') as f : + data=json.load(f) + +# Usage +existed_video_id={} +for video_name in os.listdir('videos'): + video_id = video_name.split('.')[0] + existed_video_id[video_id]=True + + + +def download_video_with_subtitles(video_id): + # Create a YouTube object. + yt = YouTube(f'https://www.youtube.com/watch?v={video_id}') + + video_filename = f"{video_id}.mp4" + video_downloaded=False + try : + # Get the video stream with the highest resolution and download the video. + stream = yt.streams.get_highest_resolution() + stream.download(output_path='videos', filename=video_filename) + video_downloaded=True + except Exception as e: + print(f"Error downloading video {video_id}: {str(e)}") + video_downloaded=False + if not video_downloaded: + return False,False + + # Get the video's available captions (subtitles). + captions = yt.captions.all() + + # Download the captions if available in xml format. 
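+    # Caption track codes vary per video (e.g. 'en', 'en-US', or 'a.en' for auto-generated
+    # English), so the substring check on 'en' below keeps any English track; the XML file
+    # saved under subtitles_xml/ is converted to WebVTT by convert_xml_vtt later in this script.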
+ caption_downloaded = False + for caption in captions: + caption_code = caption.code + # select only english captions + if 'en' in caption_code: + caption.download(title=f"{video_id}", output_path='subtitles_xml',srt=False) + caption_downloaded = True + return video_downloaded,caption_downloaded +def convert_xml_vtt(xml_path, vtt_path): + # Parse the XML subtitle file + tree = ET.parse(xml_path) + root = tree.getroot() + + # Initialize a list to store VTT subtitle entries + vtt_subtitle = [] + + # Function to convert time in milliseconds to WebVTT format + def ms_to_vtt_time(milliseconds): + seconds, milliseconds = divmod(milliseconds, 1000) + minutes, seconds = divmod(seconds, 60) + return f"{minutes:02d}:{seconds:02d}.{milliseconds:03d}" + + # Iterate through subtitle elements + toggle = True + for p in root.findall(".//p"): + if toggle: + start_time = int(p.get("t")) + subtitle_text = " ".join(s.text.strip() for s in p.findall(".//s")) + # duration = int(p.get("d")) if p.get("d") is not None else 0 + if not toggle: + end_time = int(p.get("t")) + # Format and append the VTT entry to the list + vtt_subtitle.append(f"{ms_to_vtt_time(start_time)} --> {ms_to_vtt_time(end_time)}\n{subtitle_text}\n") + toggle = not toggle + # Join the VTT entries into a single string + vtt_content = "WEBVTT\n\n" + "\n".join(vtt_subtitle) + + # Save the VTT content to a file + with open(vtt_path, "w", encoding="utf-8") as vtt_file: + vtt_file.write(vtt_content) +import os +os.makedirs('videos', exist_ok=True) +os.makedirs('subtitles_vtt', exist_ok=True) +os.makedirs('subtitles_xml', exist_ok=True) +for video_path in tqdm(data,desc='Downloading videos') : + video_id=video_path.split('/')[-1].split('.')[0] + if existed_video_id.get(video_id,False): + continue + video_downloaded,caption_downloaded=download_video_with_subtitles(video_id) + if caption_downloaded: + # convert xml to vtt + xml_file_path=f'subtitles_xml/{video_id} (a.en).xml' + convert_xml_vtt(xml_file_path,f'subtitles_vtt/{video_id}.vtt') + + diff --git a/demo_job.sh b/demo_job.sh new file mode 100644 index 0000000000000000000000000000000000000000..7228b8611b7263295e77a95923b5ae3a933fcdb4 --- /dev/null +++ b/demo_job.sh @@ -0,0 +1,21 @@ +#!/bin/bash +#SBATCH --partition=batch +#SBATCH --job-name=video_demo_llama2 +#SBATCH --output=video_demo_llama2.out +#SBATCH --error=video_demo_llama2.err +#SBATCH --time=0-10:30:00 +#SBATCH --mem=100G +#SBATCH --gres=gpu:a100:1 +#SBATCH --nodes=1 + +# Choose the model to test +# Mistral +# ckpt="checkpoints/video_mistral_checkpoint_last.pth" +# config="test_configs/mistral_test_config.yaml" + +# Llama2 +ckpt="checkpoints/video_llama_checkpoint_last.pth" +config="test_configs/llama2_test_config.yaml" + + +python minigpt4_video_demo.py --cfg-path $config --ckpt $ckpt diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..402f7126ca468f2c927c552b09d503789a491651 --- /dev/null +++ b/environment.yml @@ -0,0 +1,331 @@ +name: minigpt4_video_test_v100 +channels: + - conda-forge +dependencies: + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=2_gnu + - archspec=0.2.2=pyhd8ed1ab_0 + - boltons=23.1.1=pyhd8ed1ab_0 + - brotli-python=1.1.0=py39h3d6467e_1 + - bzip2=1.0.8=hd590300_5 + - c-ares=1.25.0=hd590300_0 + - ca-certificates=2024.2.2=hbcca054_0 + - certifi=2024.2.2=pyhd8ed1ab_0 + - cffi=1.16.0=py39h7a31438_0 + - charset-normalizer=3.3.2=pyhd8ed1ab_0 + - colorama=0.4.6=pyhd8ed1ab_0 + - conda=23.11.0=py39hf3d152e_1 + - conda-libmamba-solver=23.12.0=pyhd8ed1ab_0 
+ - conda-package-handling=2.2.0=pyh38be061_0 + - conda-package-streaming=0.9.0=pyhd8ed1ab_0 + - cudatoolkit=11.8.0=h4ba93d1_12 + - cudatoolkit-dev=11.7.0=h1de0b5d_6 + - distro=1.9.0=pyhd8ed1ab_0 + - faiss=1.7.4=py39cuda112h460e57a_0_cuda + - fmt=10.1.1=h00ab1b0_1 + - freetype=2.12.1=h267a509_2 + - gmp=6.1.2=hf484d3e_1000 + - gnutls=3.5.19=h2a4e5f8_1 + - icu=73.2=h59595ed_0 + - idna=3.6=pyhd8ed1ab_0 + - jsonpatch=1.33=pyhd8ed1ab_0 + - jsonpointer=2.4=py39hf3d152e_3 + - keyutils=1.6.1=h166bdaf_0 + - krb5=1.21.2=h659d440_0 + - ld_impl_linux-64=2.40=h41732ed_0 + - libarchive=3.7.2=h2aa1ff5_1 + - libblas=3.9.0=20_linux64_openblas + - libcblas=3.9.0=20_linux64_openblas + - libcurl=8.5.0=hca28451_0 + - libedit=3.1.20191231=he28a2e2_2 + - libev=4.33=hd590300_2 + - libfaiss=1.7.4=cuda112hb18a002_0_cuda + - libfaiss-avx2=1.7.4=cuda112h1234567_0_cuda + - libffi=3.4.2=h7f98852_5 + - libgcc-ng=13.2.0=h807b86a_3 + - libgfortran-ng=13.2.0=h69a702a_3 + - libgfortran5=13.2.0=ha4646dd_3 + - libgomp=13.2.0=h807b86a_3 + - libiconv=1.17=hd590300_2 + - liblapack=3.9.0=20_linux64_openblas + - libmamba=1.5.6=had39da4_0 + - libmambapy=1.5.6=py39h10defb6_0 + - libnghttp2=1.58.0=h47da74e_1 + - libnsl=2.0.1=hd590300_0 + - libopenblas=0.3.25=pthreads_h413a1c8_0 + - libpng=1.6.39=h753d276_0 + - libsolv=0.7.27=hfc55251_0 + - libsqlite=3.44.2=h2797004_0 + - libssh2=1.11.0=h0841786_0 + - libstdcxx-ng=13.2.0=h7e041cc_3 + - libuuid=2.38.1=h0b41bf4_0 + - libxcrypt=4.4.36=hd590300_1 + - libxml2=2.12.3=h232c23b_0 + - libzlib=1.2.13=hd590300_5 + - lz4-c=1.9.4=hcb278e6_0 + - lzo=2.10=h516909a_1000 + - menuinst=2.0.1=py39hf3d152e_0 + - ncurses=6.4=h59595ed_2 + - nettle=3.3=0 + - numpy=1.26.3=py39h474f0d3_0 + - openh264=1.8.0=hdbcaa40_1000 + - openssl=3.2.1=hd590300_0 + - packaging=23.2=pyhd8ed1ab_0 + - pip=23.3.2=pyhd8ed1ab_0 + - platformdirs=4.1.0=pyhd8ed1ab_0 + - pluggy=1.3.0=pyhd8ed1ab_0 + - pybind11-abi=4=hd8ed1ab_3 + - pycosat=0.6.6=py39hd1e30aa_0 + - pycparser=2.21=pyhd8ed1ab_0 + - pysocks=1.7.1=pyha2e5f31_6 + - python=3.9.18=h0755675_1_cpython + - python_abi=3.9=4_cp39 + - readline=8.2=h8228510_1 + - reproc=14.2.4.post0=hd590300_1 + - reproc-cpp=14.2.4.post0=h59595ed_1 + - requests=2.31.0=pyhd8ed1ab_0 + - ruamel.yaml=0.18.5=py39hd1e30aa_0 + - ruamel.yaml.clib=0.2.7=py39hd1e30aa_2 + - tk=8.6.13=noxft_h4845f30_101 + - tqdm=4.66.1=pyhd8ed1ab_0 + - urllib3=2.1.0=pyhd8ed1ab_0 + - wheel=0.42.0=pyhd8ed1ab_0 + - x264=1!152.20180717=h14c3975_1001 + - xz=5.2.6=h166bdaf_0 + - yaml-cpp=0.8.0=h59595ed_0 + - zlib=1.2.13=hd590300_5 + - zstandard=0.22.0=py39h6e5214e_0 + - zstd=1.5.5=hfc55251_0 + - pip: + - accelerate==0.25.0 + - aiofiles==23.2.1 + - aiohttp==3.9.1 + - aiosignal==1.3.1 + - altair==5.2.0 + - annotated-types==0.6.0 + - antlr4-python3-runtime==4.9.3 + - anyio==4.2.0 + - appdirs==1.4.4 + - asgiref==3.7.2 + - async-timeout==4.0.3 + - attrs==23.2.0 + - backoff==2.2.1 + - bcrypt==4.1.2 + - beautifulsoup4==4.12.2 + - bitarray==2.9.2 + - bitsandbytes==0.42.0 + - bleach==6.1.0 + - blinker==1.7.0 + - braceexpand==0.1.7 + - build==1.0.3 + - cachetools==5.3.2 + - chardet==5.2.0 + - chroma-hnswlib==0.7.3 + - chromadb==0.4.22 + - click==8.1.7 + - cmake==3.25.0 + - colbert-ai==0.2.18 + - coloredlogs==15.0.1 + - contourpy==1.2.0 + - cycler==0.12.1 + - datasets==2.17.0 + - decorator==4.4.2 + - decord==0.6.0 + - deprecated==1.2.14 + - dill==0.3.8 + - docker-pycreds==0.4.0 + - docopt==0.6.2 + - einops==0.7.0 + - exceptiongroup==1.2.0 + - faiss-gpu==1.7.2 + - fastapi==0.108.0 + - ffmpeg==1.4 + - ffmpeg-python==0.2.0 + - ffmpy==0.3.1 + - 
filelock==3.13.1 + - flash-attn==2.5.4 + - flask==3.0.2 + - flatbuffers==23.5.26 + - fonttools==4.47.0 + - frozenlist==1.4.1 + - fsspec==2023.10.0 + - ftfy==6.1.3 + - future==0.18.3 + - gdown==4.7.1 + - git-python==1.0.3 + - gitdb==4.0.11 + - gitpython==3.1.40 + - google-auth==2.26.1 + - googleapis-common-protos==1.62.0 + - gradio + - gradio-client + - h11==0.14.0 + - h5py==3.10.0 + - httpcore==1.0.2 + - httptools==0.6.1 + - httpx==0.26.0 + - huggingface-hub==0.21.1 + - humanfriendly==10.0 + - imageio==2.33.1 + - imageio-ffmpeg==0.4.9 + - importlib-metadata==6.11.0 + - importlib-resources==6.1.1 + - inquirerpy==0.3.4 + - iopath==0.1.10 + - itsdangerous==2.1.2 + - jinja2==3.1.2 + - joblib==1.3.2 + - jsonschema==4.20.0 + - jsonschema-specifications==2023.12.1 + - kaggle==1.6.0 + - kiwisolver==1.4.5 + - kubernetes==29.0.0 + - lazy-loader==0.3 + - lit==15.0.7 + - llvmlite==0.41.1 + - markdown-it-py==3.0.0 + - matplotlib==3.8.2 + - mdurl==0.1.2 + - mmh3==4.1.0 + - monotonic==1.6 + - more-itertools==10.1.0 + - moviepy==1.0.3 + - mpmath==1.3.0 + - multidict==6.0.4 + - multiprocess==0.70.16 + - mutagen==1.47.0 + - networkx==3.2.1 + - ninja==1.11.1.1 + - nltk==3.8.1 + - numba==0.58.1 + - nvidia-cublas-cu11==11.10.3.66 + - nvidia-cublas-cu12==12.1.3.1 + - nvidia-cuda-cupti-cu12==12.1.105 + - nvidia-cuda-nvrtc-cu11==11.7.99 + - nvidia-cuda-nvrtc-cu12==12.1.105 + - nvidia-cuda-runtime-cu11==11.7.99 + - nvidia-cuda-runtime-cu12==12.1.105 + - nvidia-cudnn-cu11==8.5.0.96 + - nvidia-cudnn-cu12==8.9.2.26 + - nvidia-cufft-cu12==11.0.2.54 + - nvidia-curand-cu12==10.3.2.106 + - nvidia-cusolver-cu12==11.4.5.107 + - nvidia-cusparse-cu12==12.1.0.106 + - nvidia-nccl-cu12==2.18.1 + - nvidia-nvjitlink-cu12==12.3.101 + - nvidia-nvtx-cu12==12.1.105 + - omegaconf==2.3.0 + - onnxruntime==1.16.3 + - openai==0.28.0 + - openai-whisper==20231117 + - opencv-python==4.7.0.72 + - opentelemetry-api==1.22.0 + - opentelemetry-exporter-otlp-proto-common==1.22.0 + - opentelemetry-exporter-otlp-proto-grpc==1.22.0 + - opentelemetry-instrumentation==0.43b0 + - opentelemetry-instrumentation-asgi==0.43b0 + - opentelemetry-instrumentation-fastapi==0.43b0 + - opentelemetry-proto==1.22.0 + - opentelemetry-sdk==1.22.0 + - opentelemetry-semantic-conventions==0.43b0 + - opentelemetry-util-http==0.43b0 + - orjson==3.9.10 + - overrides==7.4.0 + - pandas==2.0.0 + - pathtools==0.1.2 + - peft==0.2.0 + - pfzy==0.3.4 + - pillow==10.2.0 + - plotly==5.18.0 + - portalocker==2.8.2 + - posthog==3.3.0 + - proglog==0.1.10 + - progressbar2==4.3.2 + - prompt-toolkit==3.0.43 + - protobuf==4.25.1 + - psutil==5.9.7 + - pulsar-client==3.4.0 + - pyarrow==15.0.0 + - pyarrow-hotfix==0.6 + - pyasn1==0.5.1 + - pyasn1-modules==0.3.0 + - pycocoevalcap==1.2 + - pycocotools==2.0.6 + - pycryptodomex==3.19.1 + - pydantic==2.5.3 + - pydantic-core==2.14.6 + - pydub==0.25.1 + - pygments==2.17.2 + - pyparsing==3.1.1 + - pypika==0.48.9 + - pyproject-hooks==1.0.0 + - pysrt==1.1.2 + - python-dateutil==2.8.2 + - python-dotenv==1.0.0 + - python-multipart==0.0.6 + - python-slugify==8.0.1 + - python-utils==3.8.1 + - pytubefix + - pytz==2023.3.post1 + - pyyaml==6.0.1 + - referencing==0.32.0 + - regex==2023.12.25 + - rich==13.7.0 + - rouge==1.0.1 + - rpds-py==0.16.2 + - rsa==4.9 + - safetensors==0.4.1 + - scikit-image==0.22.0 + - scikit-learn==1.3.2 + - scipy==1.11.4 + - seaborn==0.13.1 + - semantic-version==2.10.0 + - sentence-transformers==2.2.2 + - sentencepiece==0.1.97 + - sentry-sdk==1.39.1 + - setproctitle==1.3.3 + - setuptools==69.0.3 + - shellingham==1.5.4 + - six==1.16.0 + - 
smmap==5.0.1 + - sniffio==1.3.0 + - soundfile==0.12.1 + - soupsieve==2.5 + - starlette==0.32.0.post1 + - sympy==1.12 + - tenacity==8.2.3 + - text-unidecode==1.3 + - threadpoolctl==3.2.0 + - tifffile==2023.12.9 + - tiktoken==0.5.2 + - timm==0.6.13 + - tokenizers==0.15.2 + - tomli==2.0.1 + - tomlkit==0.12.0 + - toolz==0.12.0 + - torch==2.0.1 + - torchaudio==2.0.2 + - torchvision==0.15.2 + - transformers==4.37.2 + - triton==2.0.0 + - typer==0.9.0 + - typing-extensions==4.9.0 + - tzdata==2023.4 + - ujson==5.9.0 + - uvicorn==0.25.0 + - uvloop==0.19.0 + - visual-genome==1.1.1 + - wandb==0.14.2 + - watchfiles==0.21.0 + - wcwidth==0.2.13 + - webdataset==0.2.48 + - webencodings==0.5.1 + - websocket-client==1.7.0 + - websockets + - webvtt-py==0.4.6 + - wrapt==1.16.0 + - xxhash==3.4.1 + - yarl==1.9.4 + - youtube-dl==2021.12.17 + - yt-dlp + - zipp diff --git a/eval_video.py b/eval_video.py new file mode 100644 index 0000000000000000000000000000000000000000..566d2f2efd0c49948a9d256da7a9df66c6634e6a --- /dev/null +++ b/eval_video.py @@ -0,0 +1,221 @@ +import os +import json +from tqdm import tqdm +from torch.utils.data import DataLoader +from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser +from minigpt4.conversation.conversation import CONV_VISION +from minigpt4.processors.blip_processors import Blip2ImageTrainProcessor,BlipCaptionProcessor +from minigpt4.datasets.datasets.video_datasets import VideoChatGPTEvalDataset,VideoChatGPTEval_consistancy,Video_validation_Dataset,TVQAEVAL,TVQAEVAL_Long + +parser = eval_parser() +parser.add_argument("--dataset", type=str, default='msvd', help="dataset to evaluate") +parser.add_argument("--add_subtitles",action='store_true',help="whether to add subtitles to the video") +parser.add_argument("--name", type=str, default='3_datasets', help="evaluation name") +parser.add_argument("--batch_size", type=int, default=1, help="batch size") +parser.add_argument("--start", type=int, default=0, help="start from video number") +parser.add_argument("--end", type=int, default=10000000, help="end at video number") +args = parser.parse_args() + +print(args.ckpt) +print(args.name) +print(args.cfg_path) +if "test_configs/mistral_test_config.yaml" == args.cfg_path: + llm_name="mistral" +else: + llm_name="llama2" +print("using captions",args.add_subtitles) + +model, vis_processor = init_model(args) +conv_temp = CONV_VISION.copy() +conv_temp.system = "" +if args.dataset == 'video_chatgpt_generic': + ann_path="datasets/evaluation_datasets/videochatgpt_benchmark/generic_qa.json" + videos_path="/ibex/project/c2090/datasets/VideoInstruct100K/test_videos/Test_Videos" + subtitles_path="/home/ataallka/minigpt_video/minigpt_multi_img/inference_subtitles" + videos_features_path="/ibex/project/c2106/kirolos/videos_features/evaluation/benchmark/generic" + annotations_keys=['Q','A','video_name'] + data = VideoChatGPTEvalDataset(vis_processor, videos_path, ann_path,subtitles_path,annotations_keys,videos_features_path, add_subtitles=args.add_subtitles,llm_name=llm_name) +elif args.dataset == 'video_chatgpt_temporal': + ann_path="datasets/evaluation_datasets/videochatgpt_benchmark/temporal_qa.json" + videos_path="/ibex/project/c2090/datasets/VideoInstruct100K/test_videos/Test_Videos" + subtitles_path="/home/ataallka/minigpt_video/minigpt_multi_img/inference_subtitles" + videos_features_path="/ibex/project/c2106/kirolos/videos_features/evaluation/benchmark/temporal" + annotations_keys=['Q','A','video_name'] + data = VideoChatGPTEvalDataset(vis_processor, videos_path, 
ann_path,subtitles_path,annotations_keys,videos_features_path, add_subtitles=args.add_subtitles,llm_name=llm_name) +elif args.dataset == 'video_chatgpt_consistency': + ann_path="datasets/evaluation_datasets/videochatgpt_benchmark/consistency_qa.json" + videos_path="/ibex/project/c2090/datasets/VideoInstruct100K/test_videos/Test_Videos" + subtitles_path="/home/ataallka/minigpt_video/minigpt_multi_img/inference_subtitles" + annotations_keys=[['Q1','Q2'],'A','video_name'] + data = VideoChatGPTEval_consistancy(vis_processor, videos_path, ann_path,subtitles_path,annotations_keys, add_subtitles=args.add_subtitles,llm_name=llm_name) + +elif args.dataset == 'msrvtt': + ann_path="datasets/evaluation_datasets/msrvtt/val_qa_edited.json" + videos_path="/ibex/project/c2090/datasets/VideoInstruct100K/test_videos/MSRVTT/videos/all" + subtitles_path="/home/ataallka/minigpt_video/minigpt_multi_img/inference_subtitles" + videos_features_path="/ibex/project/c2106/kirolos/videos_features/evaluation/msrvtt" + annotations_keys=['question','answer','video_id'] + data = VideoChatGPTEvalDataset(vis_processor, videos_path, ann_path,subtitles_path,annotations_keys,videos_features_path, add_subtitles=args.add_subtitles,llm_name=llm_name) + +elif args.dataset == 'msvd': + ann_path="datasets/evaluation_datasets/msvd/val_qa_edited.json" + videos_path="/ibex/project/c2090/datasets/VideoInstruct100K/test_videos/MSVD-QA/videos" + subtitles_path="/home/ataallka/minigpt_video/minigpt_multi_img/inference_subtitles" + videos_features_path="/ibex/project/c2106/kirolos/videos_features/evaluation/msvd" + annotations_keys=['question','answer','video_id'] + data = VideoChatGPTEvalDataset(vis_processor, videos_path, ann_path,subtitles_path,annotations_keys,videos_features_path, add_subtitles=args.add_subtitles,llm_name=llm_name) +elif args.dataset == 'activitynet': + ann_path="datasets/evaluation_datasets/activityNet/test_qa.json" + videos_path="/ibex/project/c2090/datasets/VideoInstruct100K/test_videos/Activity_net/Activity_net_videos" + subtitles_path="/home/ataallka/minigpt_video/minigpt_multi_img/inference_subtitles/" + videos_features_path="/ibex/project/c2106/kirolos/videos_features/evaluation/activity_net" + annotations_keys=['question','answer','video_id'] + data = VideoChatGPTEvalDataset(vis_processor, videos_path, ann_path,subtitles_path,annotations_keys,videos_features_path, add_subtitles=args.add_subtitles,llm_name=llm_name) +elif args.dataset == 'tgif': + ann_path="datasets/evaluation_datasets/tgif/Test_frameqa_question.json" + videos_path="/ibex/project/c2090/datasets/VideoInstruct100K/test_videos/TGIF/mp4s" + subtitles_path="/home/ataallka/minigpt_video/minigpt_multi_img/inference_subtitles" + videos_features_path="/ibex/project/c2106/kirolos/videos_features/evaluation/tgif" + annotations_keys=['question','answer','gif_name'] + # annotations_keys=['question','description','gif_name'] + data = VideoChatGPTEvalDataset(vis_processor, videos_path, ann_path,subtitles_path,annotations_keys,videos_features_path, add_subtitles=False,llm_name=llm_name) +elif args.dataset == 'tvqa': + # TVQA dataset + ann_path="datasets/evaluation_datasets/tvqa_short/tvqa_val.json" + videos_path= "/ibex/project/c2090/datasets/TVR_dataset/videos/video_files/frames_hq/" + subtitles_path="/ibex/project/c2090/datasets/TVR_dataset/TVRetrieval/data/tvqa_preprocessed_subtitles.json" + videos_features_path="/ibex/project/c2106/kirolos/videos_features/evaluation/tvqa" + data = TVQAEVAL(vis_processor, videos_path, 
ann_path,subtitles_path,videos_features_path,add_subtitles=args.add_subtitles,llm_name=llm_name) + +eval_dataloader = DataLoader(data, batch_size=args.batch_size, shuffle=False) + +minigpt4_predict = [] +sub="subtitles" if args.add_subtitles else "no_subtitles" +if args.start == 0 and args.end == 10000000: + save_path = f'results/{args.name}_{args.dataset}_{sub}.json' +else: + print("start from video number",args.start) + print("end at video number",args.end) + save_path = f'results/{args.name}_{args.dataset}_{sub}_{args.start}_{args.end}.json' + +os.makedirs("results", exist_ok=True) +c=0 +pred_result = {} +gt_result = {} +if args.dataset == 'video_chatgpt_consistency': + for images, texts_1,texts_2, gt_answers, lengths,videos_ids in tqdm(eval_dataloader,desc=f"Eval {args.dataset}"): + if args.start<= c <args.end : + ... + if c >= args.end : + break + c+=1 + +elif args.dataset == 'tvr': + for images, texts, gt_answers, lengths,videos_ids in tqdm(eval_dataloader,desc=f"Eval {args.dataset}"): + if args.start<= c <args.end : + ... + if c >= args.end : + break + c+=1 +elif args.dataset == 'ego_schema' or args.dataset == 'tvqa' or args.dataset == 'tvqa_long_videos': + for images, texts, gt_answers, lengths,videos_ids in tqdm(eval_dataloader,desc=f"Eval {args.dataset}"): + if args.start<= c <args.end : + ... + if c >= args.end : + break + c+=1 +else: + for images, texts, gt_answers, lengths,videos_ids in tqdm(eval_dataloader,desc=f"Eval {args.dataset}"): + if args.start<= c <args.end : + ... + if c >= args.end : + break + c+=1 + +with open(save_path, 'w') as f: + json.dump(minigpt4_predict, f) +print("saved results to",save_path) +# save results +# bleu_save_path = f'results/{args.name}_{args.dataset}_bleu.json' +# cider_save_path = f'results/{args.name}_{args.dataset}_cider.json' +# chatgpt_eval_save_path = f'results/{args.name}_{args.dataset}_chatgpt_eval.json' +# bleu_results=eval_bleu(minigpt4_predict) +# with open(bleu_save_path, 'w') as f: +# json.dump(bleu_results, f) +# print("bleu_results",bleu_results) +# cider_results=eval_cider(pred_result,gt_result) +# with open(cider_save_path, 'w') as f: +# json.dump(cider_results, f) +# print("mean_cider_scores:",cider_results['mean_cider_scores']) + +# chatgpt_results=chat_gpt_eval(pred_result,gt_result) + +# with open(chatgpt_eval_save_path, 'w') as f: +# json.dump(chatgpt_results, f) +# print("avg_chatgpt_score",chatgpt_results['avg_chatgpt_score']) +# print(chatgpt_results) + + diff --git a/jobs_video/eval/choose_best_ckpt/choose_best_ckpt.py b/jobs_video/eval/choose_best_ckpt/choose_best_ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..4040d0122ac93184ec199372d1d9e39395e58c78 --- /dev/null +++ b/jobs_video/eval/choose_best_ckpt/choose_best_ckpt.py @@ -0,0 +1,14 @@ +import os +import shutil +ckpt_dir = 'ckpt_dir' +print(f'number of ckpts: {len(os.listdir(ckpt_dir))}') +for ckpt in sorted(os.listdir(ckpt_dir)): + if not ckpt.endswith('.pth'): + continue + ckpt_path = os.path.join(ckpt_dir,ckpt) + job_name="cmd_webvid_video_instruct_"+ckpt.split(".")[0] + # submit a job with this ckpt file + os.system(f'sbatch ./evalualtion_ckpt.sh {ckpt_path} {job_name}') + # print(f'sbatch ./evalualtion_ckpt.sh {ckpt_path} {job_name}') + # print(f'job {job_name} submitted') + # break diff --git a/jobs_video/eval/choose_best_ckpt/evalualtion_ckpt.sh b/jobs_video/eval/choose_best_ckpt/evalualtion_ckpt.sh new file mode 100644 index 0000000000000000000000000000000000000000..5dde7b55e7fe0b00ad29a30cdf977e72a837f54f --- /dev/null +++ b/jobs_video/eval/choose_best_ckpt/evalualtion_ckpt.sh @@ -0,0 +1,17 @@ +#!/bin/bash +#SBATCH
--partition=batch +#SBATCH --job-name=val%j +#SBATCH --output=val%j.out +#SBATCH --error=val%j.err +#SBATCH --time=0-10:00:00 +#SBATCH --mem=100G +#SBATCH --gres=gpu:a100:1 +#SBATCH --nodes=1 +## run the application: +NAME=$2 # Name of the experiment +DATASET="dataset_name" # available datasets: tvqa, msrvtt, msvd, activitynet, tgif, video_chatgpt_generic, video_chatgpt_temporal, video_chatgpt_consistency +BATCH_SIZE=2 # batch size +CKPT_PATH=$1 # path to the checkpoint +cfg_path="test_configs/mistral_test_config.yaml" # path to the config file +cd ../../../ +python eval_video.py --dataset $DATASET --batch_size $BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --cfg-path=$cfg_path --add_subtitles \ No newline at end of file diff --git a/jobs_video/eval/llama2_evalualtion.sh b/jobs_video/eval/llama2_evalualtion.sh new file mode 100644 index 0000000000000000000000000000000000000000..d981416e848c67f8daa4f2d72a8cda952b740189 --- /dev/null +++ b/jobs_video/eval/llama2_evalualtion.sh @@ -0,0 +1,37 @@ +#!/bin/bash +#SBATCH --partition=batch +#SBATCH --job-name=llama2_best%j +#SBATCH --output=llama2_best%j.out +#SBATCH --error=llama2_best%j.err +#SBATCH --time=0-23:00:00 +#SBATCH --mem=100G +#SBATCH --gres=gpu:a100:1 +#SBATCH --nodes=1 +## run the application: +NAME="llama2_best" # Name of the experiment +DATASET="tvqa" # available datasets: tvqa, msrvtt, msvd, activitynet, tgif, video_chatgpt_generic, video_chatgpt_temporal, video_chatgpt_consistency +BATCH_SIZE=8 +CKPT_PATH="checkpoints/video_llama_checkpoint_last.pth" # path to the checkpoint +cfg_path="test_configs/llama2_test_config.yaml" # path to the config file +# if the number of samples is large, you can specify the start and end index to evaluate on several machines +# pass the start and end index as arguments +start=$1 # start index +end=$2 # end index +# if start and end are not provided, then use the whole dataset +if [ -z "$start" ] +then + start=0 +fi +if [ -z "$end" ] +then + end=10000000 +fi +echo "Start: $start" +echo "End: $end" + +cd ../../ +# without subtitles +python eval_video.py --dataset $DATASET --batch_size $BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --cfg-path=$cfg_path --start $start --end $end + +# with subtitles +# python eval_video.py --dataset $DATASET --batch_size $BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --cfg-path=$cfg_path --add_subtitles --start $start --end $end \ No newline at end of file diff --git a/jobs_video/eval/mistral_evalualtion.sh b/jobs_video/eval/mistral_evalualtion.sh new file mode 100644 index 0000000000000000000000000000000000000000..272e29cc928e202e18e7cdce688bb98c3a9ed9e3 --- /dev/null +++ b/jobs_video/eval/mistral_evalualtion.sh @@ -0,0 +1,39 @@ +#!/bin/bash +#SBATCH --partition=batch +#SBATCH --mail-user=kirolos.ataallah@kaust.edu.sa +#SBATCH --mail-type=ALL +#SBATCH --job-name=mistral_best%j +#SBATCH --output=mistral_best%j.out +#SBATCH --error=mistral_best%j.err +#SBATCH --time=0-23:00:00 +#SBATCH --mem=100G +#SBATCH --gres=gpu:a100:1 +#SBATCH --nodes=1 +## run the application: +NAME="mistral_best" # Name of the experiment +DATASET="tvqa" # available datasets: tvqa, msrvtt, msvd, activitynet, tgif, video_chatgpt_generic, video_chatgpt_temporal, video_chatgpt_consistency +BATCH_SIZE=4 # batch size on A100: 2 with subtitles, 4 without subtitles +CKPT_PATH="checkpoints/video_mistral_checkpoint_best.pth" # path to the checkpoint +cfg_path="test_configs/mistral_test_config.yaml" # path to the config file +# if the number of samples is large, you can specify the start and end index to evaluate
on several machines +# pass the start and end index as arguments +start=$1 # start index +end=$2 # end index +# if start and end are not provided, then use the whole dataset +if [ -z "$start" ] +then + start=0 +fi +if [ -z "$end" ] +then + end=10000000 +fi +echo "Start: $start" +echo "End: $end" + +cd ../../ +# without subtitles +python eval_video.py --dataset $DATASET --batch_size $BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --cfg-path=$cfg_path --start $start --end $end + +# with subtitles +# python eval_video.py --dataset $DATASET --batch_size $BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --cfg-path=$cfg_path --add_subtitles --start $start --end $end \ No newline at end of file diff --git a/jobs_video/eval/submit_job.py b/jobs_video/eval/submit_job.py new file mode 100644 index 0000000000000000000000000000000000000000..a32773cddc5518a28fce130dbbc2ee601ae65c21 --- /dev/null +++ b/jobs_video/eval/submit_job.py @@ -0,0 +1,19 @@ +import os +import shutil +import sys + +start=0 +end=7800 +step=800 + +# Mistral +for i in range(start,end,step): + cmd=f'sbatch ./mistral_evalualtion.sh {i} {i+step}' + # print(cmd) + os.system(cmd) + +# Llama 2 +# for i in range(start,end,step): +# cmd=f'sbatch ./llama2_evalualtion.sh {i} {i+step}' +# # print(cmd) +# os.system(cmd) \ No newline at end of file diff --git a/jobs_video/train/stage_2_llama2.sh b/jobs_video/train/stage_2_llama2.sh new file mode 100644 index 0000000000000000000000000000000000000000..6e06bc8d082d1aef9c3ea9c711c9b9f25f1466c7 --- /dev/null +++ b/jobs_video/train/stage_2_llama2.sh @@ -0,0 +1,23 @@ +#!/bin/bash +#SBATCH --partition=batch +#SBATCH --job-name=test +#SBATCH --output=test.out +#SBATCH --error=test.err +#SBATCH --time=23:00:00 +#SBATCH --mem=110G +#SBATCH --gres=gpu:a100:4 +#SBATCH --cpus-per-task=16 +## run the application: +job_name=test # Name of the experiment +cfg_path="train_configs/224_v2_llama2_video_stage_2.yaml" # path to the config file +number_of_gpus=1 # number of gpus +# cd ../../ + +read LOWERPORT UPPERPORT < /proc/sys/net/ipv4/ip_local_port_range +while : +do + PORT="`shuf -i $LOWERPORT-$UPPERPORT -n 1`" + ss -lpn | grep -q ":$PORT " || break +done +echo "Port is $PORT" +torchrun --master-port ${PORT} --nproc-per-node $number_of_gpus train.py --job_name ${job_name} --cfg-path ${cfg_path} \ No newline at end of file diff --git a/jobs_video/train/stage_2_mistral.sh b/jobs_video/train/stage_2_mistral.sh new file mode 100644 index 0000000000000000000000000000000000000000..acbebfb8c383a01ace15e71085a39653f2f9d31a --- /dev/null +++ b/jobs_video/train/stage_2_mistral.sh @@ -0,0 +1,23 @@ +#!/bin/bash +#SBATCH --partition=batch +#SBATCH --job-name=test +#SBATCH --output=test.out +#SBATCH --error=test.err +#SBATCH --time=23:00:00 +#SBATCH --mem=110G +#SBATCH --gres=gpu:a100:4 +#SBATCH --cpus-per-task=16 +## run the application: +job_name=test # Name of the experiment +cfg_path="train_configs/224_v2_mistral_video_stage_2.yaml" # path to the config file +number_of_gpus=1 # number of gpus +# cd ../../ + +read LOWERPORT UPPERPORT < /proc/sys/net/ipv4/ip_local_port_range +while : +do + PORT="`shuf -i $LOWERPORT-$UPPERPORT -n 1`" + ss -lpn | grep -q ":$PORT " || break +done +echo "Port is $PORT" +torchrun --master-port ${PORT} --nproc-per-node $number_of_gpus train.py --job_name ${job_name} --cfg-path ${cfg_path} \ No newline at end of file diff --git a/jobs_video/train/stage_3_llama2.sh b/jobs_video/train/stage_3_llama2.sh new file mode 100644 index 0000000000000000000000000000000000000000..5b4cce9a1420a829886dfa84d8d0ea7f2d6e1eaf
--- /dev/null +++ b/jobs_video/train/stage_3_llama2.sh @@ -0,0 +1,23 @@ +#!/bin/bash +#SBATCH --partition=batch +#SBATCH --job-name=test +#SBATCH --output=test.out +#SBATCH --error=test.err +#SBATCH --time=23:00:00 +#SBATCH --mem=110G +#SBATCH --gres=gpu:a100:4 +#SBATCH --cpus-per-task=16 +## run the application: +job_name="test" # Name of the experiment +cfg_path="train_configs/224_v2_llama2_video_stage_3.yaml" # path to the config file +number_of_gpus=1 # number of gpus +# cd ../../ + +read LOWERPORT UPPERPORT < /proc/sys/net/ipv4/ip_local_port_range +while : +do + PORT="`shuf -i $LOWERPORT-$UPPERPORT -n 1`" + ss -lpn | grep -q ":$PORT " || break +done +echo "Port is $PORT" +torchrun --master-port ${PORT} --nproc-per-node $number_of_gpus train.py --job_name ${job_name} --cfg-path ${cfg_path} \ No newline at end of file diff --git a/jobs_video/train/stage_3_mistral.sh b/jobs_video/train/stage_3_mistral.sh new file mode 100644 index 0000000000000000000000000000000000000000..44391e887c9505d6487710986d524855eaa472de --- /dev/null +++ b/jobs_video/train/stage_3_mistral.sh @@ -0,0 +1,23 @@ +#!/bin/bash +#SBATCH --partition=batch +#SBATCH --job-name=test +#SBATCH --output=test.out +#SBATCH --error=test.err +#SBATCH --time=23:00:00 +#SBATCH --mem=110G +#SBATCH --gres=gpu:a100:4 +#SBATCH --cpus-per-task=16 +## run the application: +job_name="test" # Name of the experiment +cfg_path="train_configs/224_v2_mistral_video_stage_3.yaml" # path to the config file +number_of_gpus=1 # number of gpus +# cd ../../ + +read LOWERPORT UPPERPORT < /proc/sys/net/ipv4/ip_local_port_range +while : +do + PORT="`shuf -i $LOWERPORT-$UPPERPORT -n 1`" + ss -lpn | grep -q ":$PORT " || break +done +echo "Port is $PORT" +torchrun --master-port ${PORT} --nproc-per-node $number_of_gpus train.py --job_name ${job_name} --cfg-path ${cfg_path} \ No newline at end of file diff --git a/minigpt4/__init__.py b/minigpt4/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb31f42f9107a0b748b878deb1c5768019d62b32 --- /dev/null +++ b/minigpt4/__init__.py @@ -0,0 +1,31 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +import sys + +from omegaconf import OmegaConf + +from minigpt4.common.registry import registry + +from minigpt4.datasets.builders import * +from minigpt4.models import * +from minigpt4.processors import * +from minigpt4.tasks import * + + +root_dir = os.path.dirname(os.path.abspath(__file__)) +default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml")) + +registry.register_path("library_root", root_dir) +repo_root = os.path.join(root_dir, "..") +registry.register_path("repo_root", repo_root) +cache_root = os.path.join(repo_root, default_cfg.env.cache_root) +registry.register_path("cache_root", cache_root) + +registry.register("MAX_INT", sys.maxsize) +registry.register("SPLIT_NAMES", ["train", "val", "test"]) diff --git a/minigpt4/common/__init__.py b/minigpt4/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/minigpt4/common/config.py b/minigpt4/common/config.py new file mode 100644 index 0000000000000000000000000000000000000000..0d092a352ff8719573e5c0e3d9584f62442fd9df --- /dev/null +++ b/minigpt4/common/config.py @@ -0,0 +1,474 @@ +""" + Copyright (c) 2022, salesforce.com, inc. 
+ All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import json +from typing import Dict + +from omegaconf import OmegaConf +from minigpt4.common.registry import registry + + +class Config: + def __init__(self, args): + self.config = {} + + self.args = args + + # Register the config and configuration for setup + registry.register("configuration", self) + + user_config = self._build_opt_list(self.args.options) + + config = OmegaConf.load(self.args.cfg_path) + + runner_config = self.build_runner_config(config) + model_config = self.build_model_config(config, **user_config) + dataset_config = self.build_dataset_config(config) + + # Validate the user-provided runner configuration + # model and dataset configuration are supposed to be validated by the respective classes + # [TODO] validate the model/dataset configuration + # self._validate_runner_config(runner_config) + + # Override the default configuration with user options. + self.config = OmegaConf.merge( + runner_config, model_config, dataset_config, user_config + ) + + def _validate_runner_config(self, runner_config): + """ + This method validates the configuration, such that + 1) all the user specified options are valid; + 2) no type mismatches between the user specified options and the config. + """ + runner_config_validator = create_runner_config_validator() + runner_config_validator.validate(runner_config) + + def _build_opt_list(self, opts): + opts_dot_list = self._convert_to_dot_list(opts) + return OmegaConf.from_dotlist(opts_dot_list) + + @staticmethod + def build_model_config(config, **kwargs): + model = config.get("model", None) + assert model is not None, "Missing model configuration file." + + model_cls = registry.get_model_class(model.arch) + assert model_cls is not None, f"Model '{model.arch}' has not been registered." + + model_type = kwargs.get("model.model_type", None) + if not model_type: + model_type = model.get("model_type", None) + # else use the model type selected by user. + + assert model_type is not None, "Missing model_type." + + print("--------------") + print("model arch",model.arch) + print("model cls",model_cls) + + model_config_path = model_cls.default_config_path(model_type=model_type) + + model_config = OmegaConf.create() + # hierarchy override, customized config > default config + model_config = OmegaConf.merge( + model_config, + OmegaConf.load(model_config_path), + {"model": config["model"]}, + ) + + return model_config + + @staticmethod + def build_runner_config(config): + return {"run": config.run} + + @staticmethod + def build_dataset_config(config): + datasets = config.get("datasets", None) + if datasets is None: + raise KeyError( + "Expecting 'datasets' as the root key for dataset configuration." 
+ ) + + dataset_config = OmegaConf.create() + + for dataset_name in datasets: + + print("dataset name", dataset_name) + builder_cls = registry.get_builder_class(dataset_name) + + dataset_config_type = datasets[dataset_name].get("type", "default") + dataset_config_path = builder_cls.default_config_path( + type=dataset_config_type + ) + + # hierarchy override, customized config > default config + dataset_config = OmegaConf.merge( + dataset_config, + OmegaConf.load(dataset_config_path), + {"datasets": {dataset_name: config["datasets"][dataset_name]}}, + ) + + return dataset_config + + def _convert_to_dot_list(self, opts): + if opts is None: + opts = [] + + if len(opts) == 0: + return opts + + has_equal = opts[0].find("=") != -1 + + if has_equal: + return opts + + return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])] + + def get_config(self): + return self.config + + @property + def run_cfg(self): + return self.config.run + + @property + def datasets_cfg(self): + return self.config.datasets + + @property + def model_cfg(self): + return self.config.model + + def pretty_print(self): + logging.info("\n===== Running Parameters =====") + logging.info(self._convert_node_to_json(self.config.run)) + + logging.info("\n====== Dataset Attributes ======") + datasets = self.config.datasets + + for dataset in datasets: + if dataset in self.config.datasets: + logging.info(f"\n======== {dataset} =======") + dataset_config = self.config.datasets[dataset] + logging.info(self._convert_node_to_json(dataset_config)) + else: + logging.warning(f"No dataset named '{dataset}' in config. Skipping") + + logging.info(f"\n====== Model Attributes ======") + logging.info(self._convert_node_to_json(self.config.model)) + + def _convert_node_to_json(self, node): + container = OmegaConf.to_container(node, resolve=True) + return json.dumps(container, indent=4, sort_keys=True) + + def to_dict(self): + return OmegaConf.to_container(self.config) + + +def node_to_dict(node): + return OmegaConf.to_container(node) + + +class ConfigValidator: + """ + This is a preliminary implementation to centralize and validate the configuration. + May be altered in the future. + + A helper class to validate configurations from yaml file. + + This serves the following purposes: + 1. Ensure all the options in the yaml are defined, raise error if not. + 2. when type mismatches are found, the validator will raise an error. + 3. a central place to store and display helpful messages for supported configurations. + + """ + + class _Argument: + def __init__(self, name, choices=None, type=None, help=None): + self.name = name + self.val = None + self.choices = choices + self.type = type + self.help = help + + def __str__(self): + s = f"{self.name}={self.val}" + if self.type is not None: + s += f", ({self.type})" + if self.choices is not None: + s += f", choices: {self.choices}" + if self.help is not None: + s += f", ({self.help})" + return s + + def __init__(self, description): + self.description = description + + self.arguments = dict() + + self.parsed_args = None + + def __getitem__(self, key): + assert self.parsed_args is not None, "No arguments parsed yet." + + return self.parsed_args[key] + + def __str__(self) -> str: + return self.format_help() + + def add_argument(self, *args, **kwargs): + """ + Assume the first argument is the name of the argument. + """ + self.arguments[args[0]] = self._Argument(*args, **kwargs) + + def validate(self, config=None): + """ + Convert yaml config (dict-like) to list, required by argparse. 
+ """ + for k, v in config.items(): + assert ( + k in self.arguments + ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}.""" + + if self.arguments[k].type is not None: + try: + self.arguments[k].val = self.arguments[k].type(v) + except ValueError: + raise ValueError(f"{k} is not a valid {self.arguments[k].type}.") + + if self.arguments[k].choices is not None: + assert ( + v in self.arguments[k].choices + ), f"""{k} must be one of {self.arguments[k].choices}.""" + + return config + + def format_arguments(self): + return str([f"{k}" for k in sorted(self.arguments.keys())]) + + def format_help(self): + # description + key-value pair string for each argument + help_msg = str(self.description) + return help_msg + ", available arguments: " + self.format_arguments() + + def print_help(self): + # display help message + print(self.format_help()) + + +def create_runner_config_validator(): + validator = ConfigValidator(description="Runner configurations") + + validator.add_argument( + "runner", + type=str, + choices=["runner_base", "runner_iter"], + help="""Runner to use. The "runner_base" uses epoch-based training while iter-based + runner runs based on iters. Default: runner_base""", + ) + # add argumetns for training dataset ratios + validator.add_argument( + "train_dataset_ratios", + type=Dict[str, float], + help="""Ratios of training dataset. This is used in iteration-based runner. + Do not support for epoch-based runner because how to define an epoch becomes tricky. + Default: None""", + ) + validator.add_argument( + "max_iters", + type=float, + help="Maximum number of iterations to run.", + ) + validator.add_argument( + "max_epoch", + type=int, + help="Maximum number of epochs to run.", + ) + # add arguments for iters_per_inner_epoch + validator.add_argument( + "iters_per_inner_epoch", + type=float, + help="Number of iterations per inner epoch. This is required when runner is runner_iter.", + ) + lr_scheds_choices = registry.list_lr_schedulers() + validator.add_argument( + "lr_sched", + type=str, + choices=lr_scheds_choices, + help="Learning rate scheduler to use, from {}".format(lr_scheds_choices), + ) + task_choices = registry.list_tasks() + validator.add_argument( + "task", + type=str, + choices=task_choices, + help="Task to use, from {}".format(task_choices), + ) + # add arguments for init_lr + validator.add_argument( + "init_lr", + type=float, + help="Initial learning rate. This will be the learning rate after warmup and before decay.", + ) + # add arguments for min_lr + validator.add_argument( + "min_lr", + type=float, + help="Minimum learning rate (after decay).", + ) + # add arguments for warmup_lr + validator.add_argument( + "warmup_lr", + type=float, + help="Starting learning rate for warmup.", + ) + # add arguments for learning rate decay rate + validator.add_argument( + "lr_decay_rate", + type=float, + help="Learning rate decay rate. 
Required if using a decaying learning rate scheduler.", + ) + # add arguments for weight decay + validator.add_argument( + "weight_decay", + type=float, + help="Weight decay rate.", + ) + # add arguments for training batch size + validator.add_argument( + "batch_size_train", + type=int, + help="Training batch size.", + ) + # add arguments for evaluation batch size + validator.add_argument( + "batch_size_eval", + type=int, + help="Evaluation batch size, including validation and testing.", + ) + # add arguments for number of workers for data loading + validator.add_argument( + "num_workers", + help="Number of workers for data loading.", + ) + # add arguments for warm up steps + validator.add_argument( + "warmup_steps", + type=int, + help="Number of warmup steps. Required if a warmup schedule is used.", + ) + # add arguments for random seed + validator.add_argument( + "seed", + type=int, + help="Random seed.", + ) + # add arguments for output directory + validator.add_argument( + "output_dir", + type=str, + help="Output directory to save checkpoints and logs.", + ) + # add arguments for whether only use evaluation + validator.add_argument( + "evaluate", + help="Whether to only evaluate the model. If true, training will not be performed.", + ) + # add arguments for splits used for training, e.g. ["train", "val"] + validator.add_argument( + "train_splits", + type=list, + help="Splits to use for training.", + ) + # add arguments for splits used for validation, e.g. ["val"] + validator.add_argument( + "valid_splits", + type=list, + help="Splits to use for validation. If not provided, will skip the validation.", + ) + # add arguments for splits used for testing, e.g. ["test"] + validator.add_argument( + "test_splits", + type=list, + help="Splits to use for testing. If not provided, will skip the testing.", + ) + # add arguments for accumulating gradient for iterations + validator.add_argument( + "accum_grad_iters", + type=int, + help="Number of iterations to accumulate gradient for.", + ) + + # ====== distributed training ====== + validator.add_argument( + "device", + type=str, + choices=["cpu", "cuda"], + help="Device to use. 
Support 'cuda' or 'cpu' as for now.", + ) + validator.add_argument( + "world_size", + type=int, + help="Number of processes participating in the job.", + ) + validator.add_argument("dist_url", type=str) + validator.add_argument("distributed", type=bool) + # add arguments to opt using distributed sampler during evaluation or not + validator.add_argument( + "use_dist_eval_sampler", + type=bool, + help="Whether to use distributed sampler during evaluation or not.", + ) + + # ====== task specific ====== + # generation task specific arguments + # add arguments for maximal length of text output + validator.add_argument( + "max_len", + type=int, + help="Maximal length of text output.", + ) + # add arguments for minimal length of text output + validator.add_argument( + "min_len", + type=int, + help="Minimal length of text output.", + ) + # add arguments number of beams + validator.add_argument( + "num_beams", + type=int, + help="Number of beams used for beam search.", + ) + + # vqa task specific arguments + # add arguments for number of answer candidates + validator.add_argument( + "num_ans_candidates", + type=int, + help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""", + ) + # add arguments for inference method + validator.add_argument( + "inference_method", + type=str, + choices=["genearte", "rank"], + help="""Inference method to use for question answering. If rank, requires a answer list.""", + ) + + # ====== model specific ====== + validator.add_argument( + "k_test", + type=int, + help="Number of top k most similar samples from ITC/VTC selection to be tested.", + ) + + return validator diff --git a/minigpt4/common/dist_utils.py b/minigpt4/common/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8022023f9b37852187bdfd788b7db16bd47599f7 --- /dev/null +++ b/minigpt4/common/dist_utils.py @@ -0,0 +1,146 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. 
+ SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import datetime +import functools +import os + +import torch +import torch.distributed as dist +import timm.models.hub as timm_hub + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def init_distributed_mode(args): + if args.distributed is False: + print("Not using distributed mode") + args.rank = 0 + return + + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) + args.gpu = args.rank % torch.cuda.device_count() + else: + print("Not using distributed mode") + args.distributed = False + args.rank = 0 + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print( + "| distributed init (rank {}, world {}): {}".format( + args.rank, args.world_size, args.dist_url + ), + flush=True, + ) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + timeout=datetime.timedelta( + days=365 + ), # allow auto-downloading and de-compressing + ) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +def get_dist_info(): + if torch.__version__ < "1.0": + initialized = dist._initialized + else: + initialized = dist.is_initialized() + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: # non-distributed training + rank = 0 + world_size = 1 + return rank, world_size + + +def main_process(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper + + +def download_cached_file(url, check_hash=True, progress=False): + """ + Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again. + If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded. 
+ """ + + def get_cached_file_path(): + # a hack to sync the file path across processes + parts = torch.hub.urlparse(url) + filename = os.path.basename(parts.path) + cached_file = os.path.join(timm_hub.get_cache_dir(), filename) + + return cached_file + + if is_main_process(): + timm_hub.download_cached_file(url, check_hash, progress) + + if is_dist_avail_and_initialized(): + dist.barrier() + + return get_cached_file_path() diff --git a/minigpt4/common/eval_utils.py b/minigpt4/common/eval_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0450873c6bed49146c062de780cd6f2d6c38b9e0 --- /dev/null +++ b/minigpt4/common/eval_utils.py @@ -0,0 +1,224 @@ +import argparse +import numpy as np +from nltk.translate.bleu_score import sentence_bleu +import sys +sys.path.append('/home/ataallka/minigpt_video/minigpt_multi_img') +from minigpt4.common.registry import registry +from minigpt4.common.config import Config + +# imports modules for registration +from minigpt4.datasets.builders import * +from minigpt4.models import * +from minigpt4.processors import * +# from minigpt4.runners import * +from minigpt4.tasks import * +from pycocoevalcap.cider.cider import Cider +import os +import openai +from tqdm import tqdm +import json +import ast +import time + +def eval_parser(): + parser = argparse.ArgumentParser(description="Demo") + parser.add_argument("--cfg-path", help="path to configuration file.",default="test_configs/llama2_test_config.yaml") + parser.add_argument("--ckpt", type=str,default='checkpoints/video_llama_checkpoint_last.pth', help="path to checkpoint") + parser.add_argument("--eval_opt", type=str, default='all', help="path to configuration file.") + parser.add_argument("--max_new_tokens", type=int, default=512, help="max number of generated tokens") + parser.add_argument("--lora_r", type=int, default=64, help="lora rank of the model") + parser.add_argument("--lora_alpha", type=int, default=16, help="lora alpha") + parser.add_argument( + "--options", + nargs="+", + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file (deprecate), " + "change to --cfg-options instead.", + ) + return parser + + +def prepare_texts(texts, conv_temp, template='', lengths=None): + convs = [conv_temp.copy() for _ in range(len(texts))] + if lengths is None: + [conv.append_message(conv.roles[0], '{} {}'.format(template, text)) for conv, text in zip(convs, texts)] + else: + templates = [template * length for length in lengths] + [conv.append_message(conv.roles[0], '{} {}'.format(template, text)) for template, conv, text in zip(templates, convs, texts)] + [conv.append_message(conv.roles[1], None) for conv in convs] + texts = [conv.get_prompt() for conv in convs] + return texts + + +def init_model(args): + print('Initialization Model') + cfg = Config(args) + cfg.model_cfg.ckpt = args.ckpt + cfg.model_cfg.lora_r = args.lora_r + cfg.model_cfg.lora_alpha = args.lora_alpha + + model_config = cfg.model_cfg + model_config.low_resource = True + model_cls = registry.get_model_class(model_config.arch) + model = model_cls.from_config(model_config).to('cuda:0') + +# import pudb; pudb.set_trace() + key = list(cfg.datasets_cfg.keys())[0] + vis_processor_cfg = cfg.datasets_cfg.get(key).vis_processor.train + print(vis_processor_cfg) + vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) + print('Initialization Finished') + return model, vis_processor + +def computeIoU(bbox1, bbox2): + x1, y1, x2, 
y2 = bbox1 + x3, y3, x4, y4 = bbox2 + intersection_x1 = max(x1, x3) + intersection_y1 = max(y1, y3) + intersection_x2 = min(x2, x4) + intersection_y2 = min(y2, y4) + intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1) + bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1) + bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1) + union_area = bbox1_area + bbox2_area - intersection_area + iou = intersection_area / union_area + return iou + +def eval_bleu(results): + bleus1,bleus2,bleus3,bleus4 = [],[],[],[] + for result in tqdm (results,desc="bleu_eval"): + gt = result['gt'] + pred = result['pred'] + bleus1.append(sentence_bleu([gt.split()], pred.split(), weights=(1,0,0,0))) + bleus2.append(sentence_bleu([gt.split()], pred.split(), weights=(0.5,0.5,0,0))) + bleus3.append(sentence_bleu([gt.split()], pred.split(), weights=(0.33,0.33,0.33,0))) + bleus4.append(sentence_bleu([gt.split()], pred.split())) + # print(np.mean(bleus1),np.mean(bleus2),np.mean(bleus3),np.mean(bleus4),flush=True) + return {'bleu1':np.mean(bleus1),'bleu2':np.mean(bleus2),'bleu3':np.mean(bleus3),'bleu4':np.mean(bleus4)} + +# Create a Cider object +cider_scorer = Cider() +def eval_cider(pred_result,gt_result): + # Compute CIDEr scores + mean_cider_scores, cider_scores = cider_scorer.compute_score(gt_result, pred_result) + cider_scores_dict={} + for score,pred_vid_id,gt_vid_id in tqdm(zip(cider_scores.tolist(),pred_result,gt_result),desc="cider_eval") : + assert pred_vid_id==gt_vid_id + cider_scores_dict[pred_vid_id] = score + return {'mean_cider_scores':mean_cider_scores,'cider_scores':cider_scores_dict} + + +openai.api_key_path = "/home/ataallka/chatgpt_api.txt" + + +def chat_gpt_eval(results,output_path): + trial=0 + gpt_results=[] + avg_chatgpt_score=0 + existed_files={} + # read previous results from output path + for file in os.listdir(output_path): + if file.endswith(".json"): + with open(f'{output_path}/{file}') as json_file: + data = json.load(json_file) + gpt_results.append(data[0]) + avg_chatgpt_score+=float(data[0]['chatgpt_score']) + existed_files[data[0]['video_name']]=True + length_output_path=len(os.listdir(output_path)) + while len (results)!= length_output_path: + for res in tqdm(results,desc="chatgpt_eval"): + if existed_files.get(res['video_name'],False): + continue + video_name=res['video_name'] + sentence_1=res['A'] + sentence_2=res['pred'] + try: + # prompt=f"given these 2 sentences the first one is the ground truth text and the second sentence is the generated text ,give me a score from 0 to 1 to evaluate how much they are similar to each other, and have the same context and related to each other to evaluate the quality of this generated text.the output should be only the score float number without any additional information\nfirst sentence: {sentence_1}\nsecond sentence: {sentence_2}\nscore:" + prompt=f"given these 2 sentences the first one is the ground truth descrption of a video and the second sentence is the generated text from a video summarization model,give it a score from 0 to 5 to evaluate the model summarization performance.the output should be only the score number without any additional information\nfirst sentence: {sentence_1}\nsecond sentence: {sentence_2}\nscore:" + response = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": prompt + }], + ) + res['chatgpt_score']=response.choices[0].message['content'] + out={'video_name':video_name,'chatgpt_score':response.choices[0].message['content']} + 
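+                # the per-video JSON written below lets a later call of chat_gpt_eval skip this
+                # video, since existed_files is rebuilt from output_path at the top of the function.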
gpt_results.append(out) + # save each video result in a json file + with open(f'{output_path}/{video_name}.json', 'w') as f: + json.dump([out], f) + avg_chatgpt_score+=float(response.choices[0].message['content']) + except Exception as e: + print("chat gpt error",e) + print ("Finished chat gpt evaluation in trial",trial) + trial+=1 + length_output_path=len(os.listdir(output_path)) + return results,avg_chatgpt_score/len(results) +def GPT4_answer(question, answer,pred): + try: + # Compute the correctness score + completion = openai.ChatCompletion.create( + # model="gpt-3.5-turbo", + model='gpt-4', + messages=[ + { + "role": "system", + "content": + "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. " + "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:" + "------" + "##INSTRUCTIONS: " + "- Focus on the meaningful match between the predicted answer and the correct answer.\n" + "- Consider synonyms or paraphrases as valid matches.\n" + "- Evaluate the correctness of the prediction compared to the answer." + }, + { + "role": "user", + "content": + "Please evaluate the following video-based question-answer pair:\n\n" + f"Question: {question}\n" + f"Correct Answer: {answer}\n" + f"Predicted Answer: {pred}\n\n" + "Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. " + "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING." + "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " + "For example, your response should look like this: {'pred': 'yes', 'score': 4.8}." + } + ] + ) + # Convert response to a Python dictionary. 
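+        # The prompt above instructs GPT-4 to answer with a Python dict literal such as
+        # {'pred': 'yes', 'score': 4}, so the reply is parsed with ast.literal_eval below;
+        # anything that is not a valid literal raises and is handled by the except branch,
+        # which prints the error and returns None.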
+ response_message = completion["choices"][0]["message"]["content"] + response_dict = ast.literal_eval(response_message) + return response_dict + except Exception as e: + print(f"Error : {e}") + return None +def GPT4_evaluation(val_result): + scores=[] + yes_count=0 + no_count=0 + for res in val_result: + gpt_response=GPT4_answer(res['Q'],res['A'],res['pred']) + if gpt_response is None: + continue + try: + scores.append(float(gpt_response['score'])) + if 'yes' in gpt_response['pred'].lower(): + yes_count+=1 + elif 'no' in gpt_response['pred'].lower(): + no_count+=1 + except: + continue + avg_score=sum(scores)/len(scores) + accuracy=(yes_count/(yes_count+no_count))*100 + print(f"chatgpt score: {avg_score} accuracy: {accuracy}") + return avg_score,accuracy + +# with open('results/ckpt_15_res89_res32_Video_validation_Dataset_subtitles.json','r') as f: +# results = json.load(f) +# t1=time.time() +# avg_score,accuracy=GPT4_evaluation(results) +# print(f"chatgpt score: {avg_score} accuracy: {accuracy}") +# print(f"Time taken: {time.time()-t1}") \ No newline at end of file diff --git a/minigpt4/common/gradcam.py b/minigpt4/common/gradcam.py new file mode 100644 index 0000000000000000000000000000000000000000..d53a5254d4b319eaf2cbfbd081b0ca8e38c5c7a0 --- /dev/null +++ b/minigpt4/common/gradcam.py @@ -0,0 +1,24 @@ +import numpy as np +from matplotlib import pyplot as plt +from scipy.ndimage import filters +from skimage import transform as skimage_transform + + +def getAttMap(img, attMap, blur=True, overlap=True): + attMap -= attMap.min() + if attMap.max() > 0: + attMap /= attMap.max() + attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant") + if blur: + attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2])) + attMap -= attMap.min() + attMap /= attMap.max() + cmap = plt.get_cmap("jet") + attMapV = cmap(attMap) + attMapV = np.delete(attMapV, 3, 2) + if overlap: + attMap = ( + 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img + + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV + ) + return attMap diff --git a/minigpt4/common/logger.py b/minigpt4/common/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..9a5a727213c6478606a154172830cdc43aae6f5a --- /dev/null +++ b/minigpt4/common/logger.py @@ -0,0 +1,195 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import datetime +import logging +import time +from collections import defaultdict, deque + +import torch +import torch.distributed as dist + +from minigpt4.common import dist_utils + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! 
+ """ + if not dist_utils.is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError( + "'{}' object has no attribute '{}'".format(type(self).__name__, attr) + ) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {}".format(name, str(meter))) + return self.delimiter.join(loss_str) + + def global_avg(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {:.4f}".format(name, meter.global_avg)) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + log_msg = [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + if torch.cuda.is_available(): + log_msg.append("max mem: {memory:.0f}") + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print( + "{} Total time: {} ({:.4f} s / it)".format( + header, total_time_str, total_time / len(iterable) + ) + ) + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def setup_logger(): + 
logging.basicConfig( + level=logging.INFO if dist_utils.is_main_process() else logging.WARN, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[logging.StreamHandler()], + ) diff --git a/minigpt4/common/optims.py b/minigpt4/common/optims.py new file mode 100644 index 0000000000000000000000000000000000000000..270e66bf36afb768b44aff595d5dea415ddb6e9f --- /dev/null +++ b/minigpt4/common/optims.py @@ -0,0 +1,119 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import math + +from minigpt4.common.registry import registry + + +@registry.register_lr_scheduler("linear_warmup_step_lr") +class LinearWarmupStepLRScheduler: + def __init__( + self, + optimizer, + max_epoch, + min_lr, + init_lr, + decay_rate=1, + warmup_start_lr=-1, + warmup_steps=0, + **kwargs + ): + self.optimizer = optimizer + + self.max_epoch = max_epoch + self.min_lr = min_lr + + self.decay_rate = decay_rate + + self.init_lr = init_lr + self.warmup_steps = warmup_steps + self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr + + def step(self, cur_epoch, cur_step): + if cur_epoch == 0: + warmup_lr_schedule( + step=cur_step, + optimizer=self.optimizer, + max_step=self.warmup_steps, + init_lr=self.warmup_start_lr, + max_lr=self.init_lr, + ) + else: + step_lr_schedule( + epoch=cur_epoch, + optimizer=self.optimizer, + init_lr=self.init_lr, + min_lr=self.min_lr, + decay_rate=self.decay_rate, + ) + + +@registry.register_lr_scheduler("linear_warmup_cosine_lr") +class LinearWarmupCosineLRScheduler: + def __init__( + self, + optimizer, + max_epoch, + iters_per_epoch, + min_lr, + init_lr, + warmup_steps=0, + warmup_start_lr=-1, + **kwargs + ): + self.optimizer = optimizer + + self.max_epoch = max_epoch + self.iters_per_epoch = iters_per_epoch + self.min_lr = min_lr + + self.init_lr = init_lr + self.warmup_steps = warmup_steps + self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr + + def step(self, cur_epoch, cur_step): + total_cur_step = cur_epoch * self.iters_per_epoch + cur_step + if total_cur_step < self.warmup_steps: + warmup_lr_schedule( + step=total_cur_step, + optimizer=self.optimizer, + max_step=self.warmup_steps, + init_lr=self.warmup_start_lr, + max_lr=self.init_lr, + ) + else: + cosine_lr_schedule( + epoch=total_cur_step, + optimizer=self.optimizer, + max_epoch=self.max_epoch * self.iters_per_epoch, + init_lr=self.init_lr, + min_lr=self.min_lr, + ) + + +def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): + """Decay the learning rate""" + lr = (init_lr - min_lr) * 0.5 * ( + 1.0 + math.cos(math.pi * epoch / max_epoch) + ) + min_lr + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + +def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): + """Warmup the learning rate""" + lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1)) + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + +def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): + """Decay the learning rate""" + lr = max(min_lr, init_lr * (decay_rate**epoch)) + for param_group in optimizer.param_groups: + param_group["lr"] = lr diff --git a/minigpt4/common/registry.py b/minigpt4/common/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..c95309756088f9d99f8e4f3b9678027c203f9cc3 --- /dev/null +++ 
b/minigpt4/common/registry.py @@ -0,0 +1,330 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + + +class Registry: + mapping = { + "builder_name_mapping": {}, + "task_name_mapping": {}, + "processor_name_mapping": {}, + "model_name_mapping": {}, + "lr_scheduler_name_mapping": {}, + "runner_name_mapping": {}, + "state": {}, + "paths": {}, + } + + @classmethod + def register_builder(cls, name): + r"""Register a dataset builder to registry with key 'name' + + Args: + name: Key with which the builder will be registered. + + Usage: + + from minigpt4.common.registry import registry + from minigpt4.datasets.base_dataset_builder import BaseDatasetBuilder + """ + + def wrap(builder_cls): + from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder + + assert issubclass( + builder_cls, BaseDatasetBuilder + ), "All builders must inherit BaseDatasetBuilder class, found {}".format( + builder_cls + ) + if name in cls.mapping["builder_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["builder_name_mapping"][name] + ) + ) + cls.mapping["builder_name_mapping"][name] = builder_cls + return builder_cls + + return wrap + + @classmethod + def register_task(cls, name): + r"""Register a task to registry with key 'name' + + Args: + name: Key with which the task will be registered. + + Usage: + + from minigpt4.common.registry import registry + """ + + def wrap(task_cls): + from minigpt4.tasks.base_task import BaseTask + + assert issubclass( + task_cls, BaseTask + ), "All tasks must inherit BaseTask class" + if name in cls.mapping["task_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["task_name_mapping"][name] + ) + ) + cls.mapping["task_name_mapping"][name] = task_cls + return task_cls + + return wrap + + @classmethod + def register_model(cls, name): + r"""Register a task to registry with key 'name' + + Args: + name: Key with which the task will be registered. + + Usage: + + from minigpt4.common.registry import registry + """ + + def wrap(model_cls): + # from minigpt4.models import BaseModel + + # assert issubclass( + # model_cls, BaseModel + # ), "All models must inherit BaseModel class" + + if name in cls.mapping["model_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["model_name_mapping"][name] + ) + ) + cls.mapping["model_name_mapping"][name] = model_cls + return model_cls + + return wrap + + @classmethod + def register_processor(cls, name): + r"""Register a processor to registry with key 'name' + + Args: + name: Key with which the task will be registered. + + Usage: + + from minigpt4.common.registry import registry + """ + + def wrap(processor_cls): + from minigpt4.processors import BaseProcessor + + assert issubclass( + processor_cls, BaseProcessor + ), "All processors must inherit BaseProcessor class" + if name in cls.mapping["processor_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["processor_name_mapping"][name] + ) + ) + cls.mapping["processor_name_mapping"][name] = processor_cls + return processor_cls + + return wrap + + @classmethod + def register_lr_scheduler(cls, name): + r"""Register a model to registry with key 'name' + + Args: + name: Key with which the task will be registered. 
+ + Usage: + + from minigpt4.common.registry import registry + """ + + def wrap(lr_sched_cls): + if name in cls.mapping["lr_scheduler_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["lr_scheduler_name_mapping"][name] + ) + ) + cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls + return lr_sched_cls + + return wrap + + @classmethod + def register_runner(cls, name): + r"""Register a model to registry with key 'name' + + Args: + name: Key with which the task will be registered. + + Usage: + + from minigpt4.common.registry import registry + """ + + def wrap(runner_cls): + if name in cls.mapping["runner_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["runner_name_mapping"][name] + ) + ) + cls.mapping["runner_name_mapping"][name] = runner_cls + return runner_cls + + return wrap + + @classmethod + def register_path(cls, name, path): + r"""Register a path to registry with key 'name' + + Args: + name: Key with which the path will be registered. + + Usage: + + from minigpt4.common.registry import registry + """ + assert isinstance(path, str), "All path must be str." + if name in cls.mapping["paths"]: + raise KeyError("Name '{}' already registered.".format(name)) + cls.mapping["paths"][name] = path + + @classmethod + def register(cls, name, obj): + r"""Register an item to registry with key 'name' + + Args: + name: Key with which the item will be registered. + + Usage:: + + from minigpt4.common.registry import registry + + registry.register("config", {}) + """ + path = name.split(".") + current = cls.mapping["state"] + + for part in path[:-1]: + if part not in current: + current[part] = {} + current = current[part] + + current[path[-1]] = obj + + # @classmethod + # def get_trainer_class(cls, name): + # return cls.mapping["trainer_name_mapping"].get(name, None) + + @classmethod + def get_builder_class(cls, name): + return cls.mapping["builder_name_mapping"].get(name, None) + + @classmethod + def get_model_class(cls, name): + return cls.mapping["model_name_mapping"].get(name, None) + + @classmethod + def get_task_class(cls, name): + return cls.mapping["task_name_mapping"].get(name, None) + + @classmethod + def get_processor_class(cls, name): + return cls.mapping["processor_name_mapping"].get(name, None) + + @classmethod + def get_lr_scheduler_class(cls, name): + return cls.mapping["lr_scheduler_name_mapping"].get(name, None) + + @classmethod + def get_runner_class(cls, name): + return cls.mapping["runner_name_mapping"].get(name, None) + + @classmethod + def list_runners(cls): + return sorted(cls.mapping["runner_name_mapping"].keys()) + + @classmethod + def list_models(cls): + return sorted(cls.mapping["model_name_mapping"].keys()) + + @classmethod + def list_tasks(cls): + return sorted(cls.mapping["task_name_mapping"].keys()) + + @classmethod + def list_processors(cls): + return sorted(cls.mapping["processor_name_mapping"].keys()) + + @classmethod + def list_lr_schedulers(cls): + return sorted(cls.mapping["lr_scheduler_name_mapping"].keys()) + + @classmethod + def list_datasets(cls): + return sorted(cls.mapping["builder_name_mapping"].keys()) + + @classmethod + def get_path(cls, name): + return cls.mapping["paths"].get(name, None) + + @classmethod + def get(cls, name, default=None, no_warning=False): + r"""Get an item from registry with key 'name' + + Args: + name (string): Key whose value needs to be retrieved. 
+ default: If passed and key is not in registry, default value will + be returned with a warning. Default: None + no_warning (bool): If passed as True, warning when key doesn't exist + will not be generated. Useful for MMF's + internal operations. Default: False + """ + original_name = name + name = name.split(".") + value = cls.mapping["state"] + for subname in name: + value = value.get(subname, default) + if value is default: + break + + if ( + "writer" in cls.mapping["state"] + and value == default + and no_warning is False + ): + cls.mapping["state"]["writer"].warning( + "Key {} is not present in registry, returning default value " + "of {}".format(original_name, default) + ) + return value + + @classmethod + def unregister(cls, name): + r"""Remove an item from registry with key 'name' + + Args: + name: Key which needs to be removed. + Usage:: + + from mmf.common.registry import registry + + config = registry.unregister("config") + """ + return cls.mapping["state"].pop(name, None) + + +registry = Registry() diff --git a/minigpt4/common/utils.py b/minigpt4/common/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..09516aa7f1ab99f21f41d7596e355f878ec6245f --- /dev/null +++ b/minigpt4/common/utils.py @@ -0,0 +1,424 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import io +import json +import logging +import os +import pickle +import re +import shutil +import urllib +import urllib.error +import urllib.request +from typing import Optional +from urllib.parse import urlparse + +import numpy as np +import pandas as pd +import yaml +from iopath.common.download import download +from iopath.common.file_io import file_lock, g_pathmgr +from minigpt4.common.registry import registry +from torch.utils.model_zoo import tqdm +from torchvision.datasets.utils import ( + check_integrity, + download_file_from_google_drive, + extract_archive, +) + + +def now(): + from datetime import datetime + + return datetime.now().strftime("%Y%m%d%H%M") + + +def is_url(url_or_filename): + parsed = urlparse(url_or_filename) + return parsed.scheme in ("http", "https") + + +def get_cache_path(rel_path): + return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path)) + + +def get_abs_path(rel_path): + return os.path.join(registry.get_path("library_root"), rel_path) + + +def load_json(filename): + with open(filename, "r") as f: + return json.load(f) + + +# The following are adapted from torchvision and vissl +# torchvision: https://github.com/pytorch/vision +# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py + + +def makedir(dir_path): + """ + Create the directory if it does not exist. 
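+    Returns True if the directory already exists or was created successfully, False otherwise.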
+ """ + is_success = False + try: + if not g_pathmgr.exists(dir_path): + g_pathmgr.mkdirs(dir_path) + is_success = True + except BaseException: + print(f"Error creating directory: {dir_path}") + return is_success + + +def get_redirected_url(url: str): + """ + Given a URL, returns the URL it redirects to or the + original URL in case of no indirection + """ + import requests + + with requests.Session() as session: + with session.get(url, stream=True, allow_redirects=True) as response: + if response.history: + return response.url + else: + return url + + +def to_google_drive_download_url(view_url: str) -> str: + """ + Utility function to transform a view URL of google drive + to a download URL for google drive + Example input: + https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view + Example output: + https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp + """ + splits = view_url.split("/") + assert splits[-1] == "view" + file_id = splits[-2] + return f"https://drive.google.com/uc?export=download&id={file_id}" + + +def download_google_drive_url(url: str, output_path: str, output_file_name: str): + """ + Download a file from google drive + Downloading an URL from google drive requires confirmation when + the file of the size is too big (google drive notifies that + anti-viral checks cannot be performed on such files) + """ + import requests + + with requests.Session() as session: + + # First get the confirmation token and append it to the URL + with session.get(url, stream=True, allow_redirects=True) as response: + for k, v in response.cookies.items(): + if k.startswith("download_warning"): + url = url + "&confirm=" + v + + # Then download the content of the file + with session.get(url, stream=True, verify=True) as response: + makedir(output_path) + path = os.path.join(output_path, output_file_name) + total_size = int(response.headers.get("Content-length", 0)) + with open(path, "wb") as file: + from tqdm import tqdm + + with tqdm(total=total_size) as progress_bar: + for block in response.iter_content( + chunk_size=io.DEFAULT_BUFFER_SIZE + ): + file.write(block) + progress_bar.update(len(block)) + + +def _get_google_drive_file_id(url: str) -> Optional[str]: + parts = urlparse(url) + + if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None: + return None + + match = re.match(r"/file/d/(?P[^/]*)", parts.path) + if match is None: + return None + + return match.group("id") + + +def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None: + with open(filename, "wb") as fh: + with urllib.request.urlopen( + urllib.request.Request(url, headers={"User-Agent": "vissl"}) + ) as response: + with tqdm(total=response.length) as pbar: + for chunk in iter(lambda: response.read(chunk_size), ""): + if not chunk: + break + pbar.update(chunk_size) + fh.write(chunk) + + +def download_url( + url: str, + root: str, + filename: Optional[str] = None, + md5: Optional[str] = None, +) -> None: + """Download a file from a url and place it in root. + Args: + url (str): URL to download file from + root (str): Directory to place downloaded file in + filename (str, optional): Name to save the file under. + If None, use the basename of the URL. + md5 (str, optional): MD5 checksum of the download. 
If None, do not check + """ + root = os.path.expanduser(root) + if not filename: + filename = os.path.basename(url) + fpath = os.path.join(root, filename) + + makedir(root) + + # check if file is already present locally + if check_integrity(fpath, md5): + print("Using downloaded and verified file: " + fpath) + return + + # expand redirect chain if needed + url = get_redirected_url(url) + + # check if file is located on Google Drive + file_id = _get_google_drive_file_id(url) + if file_id is not None: + return download_file_from_google_drive(file_id, root, filename, md5) + + # download the file + try: + print("Downloading " + url + " to " + fpath) + _urlretrieve(url, fpath) + except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined] + if url[:5] == "https": + url = url.replace("https:", "http:") + print( + "Failed download. Trying https -> http instead." + " Downloading " + url + " to " + fpath + ) + _urlretrieve(url, fpath) + else: + raise e + + # check integrity of downloaded file + if not check_integrity(fpath, md5): + raise RuntimeError("File not found or corrupted.") + + +def download_and_extract_archive( + url: str, + download_root: str, + extract_root: Optional[str] = None, + filename: Optional[str] = None, + md5: Optional[str] = None, + remove_finished: bool = False, +) -> None: + download_root = os.path.expanduser(download_root) + if extract_root is None: + extract_root = download_root + if not filename: + filename = os.path.basename(url) + + download_url(url, download_root, filename, md5) + + archive = os.path.join(download_root, filename) + print("Extracting {} to {}".format(archive, extract_root)) + extract_archive(archive, extract_root, remove_finished) + + +def cache_url(url: str, cache_dir: str) -> str: + """ + This implementation downloads the remote resource and caches it locally. + The resource will only be downloaded if not previously requested. + """ + parsed_url = urlparse(url) + dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/"))) + makedir(dirname) + filename = url.split("/")[-1] + cached = os.path.join(dirname, filename) + with file_lock(cached): + if not os.path.isfile(cached): + logging.info(f"Downloading {url} to {cached} ...") + cached = download(url, dirname, filename=filename) + logging.info(f"URL {url} cached in {cached}") + return cached + + +# TODO (prigoyal): convert this into RAII-style API +def create_file_symlink(file1, file2): + """ + Simply create the symlinks for a given file1 to file2. + Useful during model checkpointing to symlinks to the + latest successful checkpoint. + """ + try: + if g_pathmgr.exists(file2): + g_pathmgr.rm(file2) + g_pathmgr.symlink(file1, file2) + except Exception as e: + logging.info(f"Could NOT create symlink. Error: {e}") + + +def save_file(data, filename, append_to_json=True, verbose=True): + """ + Common i/o utility to handle saving data to various file formats. + Supported: + .pkl, .pickle, .npy, .json + Specifically for .json, users have the option to either append (default) + or rewrite by passing in Boolean value to append_to_json. 
+ """ + if verbose: + logging.info(f"Saving data to file: {filename}") + file_ext = os.path.splitext(filename)[1] + if file_ext in [".pkl", ".pickle"]: + with g_pathmgr.open(filename, "wb") as fopen: + pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL) + elif file_ext == ".npy": + with g_pathmgr.open(filename, "wb") as fopen: + np.save(fopen, data) + elif file_ext == ".json": + if append_to_json: + with g_pathmgr.open(filename, "a") as fopen: + fopen.write(json.dumps(data, sort_keys=True) + "\n") + fopen.flush() + else: + with g_pathmgr.open(filename, "w") as fopen: + fopen.write(json.dumps(data, sort_keys=True) + "\n") + fopen.flush() + elif file_ext == ".yaml": + with g_pathmgr.open(filename, "w") as fopen: + dump = yaml.dump(data) + fopen.write(dump) + fopen.flush() + else: + raise Exception(f"Saving {file_ext} is not supported yet") + + if verbose: + logging.info(f"Saved data to file: {filename}") + + +def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False): + """ + Common i/o utility to handle loading data from various file formats. + Supported: + .pkl, .pickle, .npy, .json + For the npy files, we support reading the files in mmap_mode. + If the mmap_mode of reading is not successful, we load data without the + mmap_mode. + """ + if verbose: + logging.info(f"Loading data from file: {filename}") + + file_ext = os.path.splitext(filename)[1] + if file_ext == ".txt": + with g_pathmgr.open(filename, "r") as fopen: + data = fopen.readlines() + elif file_ext in [".pkl", ".pickle"]: + with g_pathmgr.open(filename, "rb") as fopen: + data = pickle.load(fopen, encoding="latin1") + elif file_ext == ".npy": + if mmap_mode: + try: + with g_pathmgr.open(filename, "rb") as fopen: + data = np.load( + fopen, + allow_pickle=allow_pickle, + encoding="latin1", + mmap_mode=mmap_mode, + ) + except ValueError as e: + logging.info( + f"Could not mmap {filename}: {e}. Trying without g_pathmgr" + ) + data = np.load( + filename, + allow_pickle=allow_pickle, + encoding="latin1", + mmap_mode=mmap_mode, + ) + logging.info("Successfully loaded without g_pathmgr") + except Exception: + logging.info("Could not mmap without g_pathmgr. Trying without mmap") + with g_pathmgr.open(filename, "rb") as fopen: + data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1") + else: + with g_pathmgr.open(filename, "rb") as fopen: + data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1") + elif file_ext == ".json": + with g_pathmgr.open(filename, "r") as fopen: + data = json.load(fopen) + elif file_ext == ".yaml": + with g_pathmgr.open(filename, "r") as fopen: + data = yaml.load(fopen, Loader=yaml.FullLoader) + elif file_ext == ".csv": + with g_pathmgr.open(filename, "r") as fopen: + data = pd.read_csv(fopen) + else: + raise Exception(f"Reading from {file_ext} is not supported yet") + return data + + +def abspath(resource_path: str): + """ + Make a path absolute, but take into account prefixes like + "http://" or "manifold://" + """ + regex = re.compile(r"^\w+://") + if regex.match(resource_path) is None: + return os.path.abspath(resource_path) + else: + return resource_path + + +def makedir(dir_path): + """ + Create the directory if it does not exist. + """ + is_success = False + try: + if not g_pathmgr.exists(dir_path): + g_pathmgr.mkdirs(dir_path) + is_success = True + except BaseException: + logging.info(f"Error creating directory: {dir_path}") + return is_success + + +def is_url(input_url): + """ + Check if an input string is a url. 
look for http(s):// and ignoring the case + """ + is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None + return is_url + + +def cleanup_dir(dir): + """ + Utility for deleting a directory. Useful for cleaning the storage space + that contains various training artifacts like checkpoints, data etc. + """ + if os.path.exists(dir): + logging.info(f"Deleting directory: {dir}") + shutil.rmtree(dir) + logging.info(f"Deleted contents of directory: {dir}") + + +def get_file_size(filename): + """ + Given a file, get the size of file in MB + """ + size_in_mb = os.path.getsize(filename) / float(1024**2) + return size_in_mb diff --git a/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py b/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py new file mode 100644 index 0000000000000000000000000000000000000000..07ca21d805684d71593c8d738798822411bdecc6 --- /dev/null +++ b/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py @@ -0,0 +1,89 @@ +# coding: utf-8 + +import sys +dataDir = '../../VQA' +sys.path.insert(0, '%s/PythonHelperTools/vqaTools' %(dataDir)) +from vqa import VQA +from vqaEvaluation.vqaEval import VQAEval +import matplotlib.pyplot as plt +import skimage.io as io +import json +import random +import os + +# set up file names and paths +versionType ='v2_' # this should be '' when using VQA v2.0 dataset +taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0 +dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0. +dataSubType ='train2014' +annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType) +quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType) +imgDir ='%s/Images/%s/%s/' %(dataDir, dataType, dataSubType) +resultType ='fake' +fileTypes = ['results', 'accuracy', 'evalQA', 'evalQuesType', 'evalAnsType'] + +# An example result json file has been provided in './Results' folder. 
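+# A result file is a JSON list with one entry per question id in the annotation file,
+# e.g. [{"question_id": <int>, "answer": "<str>"}, ...]. A minimal sketch for writing such a
+# file (assuming `vqa` and `resFile` below have already been set up) would be:
+#   fake_results = [{"question_id": qid, "answer": "yes"} for qid in vqa.getQuesIds()]
+#   json.dump(fake_results, open(resFile, "w"))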
+ +[resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = ['%s/Results/%s%s_%s_%s_%s_%s.json'%(dataDir, versionType, taskType, dataType, dataSubType, \ +resultType, fileType) for fileType in fileTypes] + +# create vqa object and vqaRes object +vqa = VQA(annFile, quesFile) +vqaRes = vqa.loadRes(resFile, quesFile) + +# create vqaEval object by taking vqa and vqaRes +vqaEval = VQAEval(vqa, vqaRes, n=2) #n is precision of accuracy (number of places after decimal), default is 2 + +# evaluate results +""" +If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function +By default it uses all the question ids in annotation file +""" +vqaEval.evaluate() + +# print accuracies +print "\n" +print "Overall Accuracy is: %.02f\n" %(vqaEval.accuracy['overall']) +print "Per Question Type Accuracy is the following:" +for quesType in vqaEval.accuracy['perQuestionType']: + print "%s : %.02f" %(quesType, vqaEval.accuracy['perQuestionType'][quesType]) +print "\n" +print "Per Answer Type Accuracy is the following:" +for ansType in vqaEval.accuracy['perAnswerType']: + print "%s : %.02f" %(ansType, vqaEval.accuracy['perAnswerType'][ansType]) +print "\n" +# demo how to use evalQA to retrieve low score result +evals = [quesId for quesId in vqaEval.evalQA if vqaEval.evalQA[quesId]<35] #35 is per question percentage accuracy +if len(evals) > 0: + print 'ground truth answers' + randomEval = random.choice(evals) + randomAnn = vqa.loadQA(randomEval) + vqa.showQA(randomAnn) + + print '\n' + print 'generated answer (accuracy %.02f)'%(vqaEval.evalQA[randomEval]) + ann = vqaRes.loadQA(randomEval)[0] + print "Answer: %s\n" %(ann['answer']) + + imgId = randomAnn[0]['image_id'] + imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg' + if os.path.isfile(imgDir + imgFilename): + I = io.imread(imgDir + imgFilename) + plt.imshow(I) + plt.axis('off') + plt.show() + +# plot accuracy for various question types +plt.bar(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].values(), align='center') +plt.xticks(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].keys(), rotation='0',fontsize=10) +plt.title('Per Question Type Accuracy', fontsize=10) +plt.xlabel('Question Types', fontsize=10) +plt.ylabel('Accuracy', fontsize=10) +plt.show() + +# save evaluation results to ./Results folder +json.dump(vqaEval.accuracy, open(accuracyFile, 'w')) +json.dump(vqaEval.evalQA, open(evalQAFile, 'w')) +json.dump(vqaEval.evalQuesType, open(evalQuesTypeFile, 'w')) +json.dump(vqaEval.evalAnsType, open(evalAnsTypeFile, 'w')) + diff --git a/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py b/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..148424d7391f6c8e8070f6dd20f02e2ddb1899cc --- /dev/null +++ b/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py @@ -0,0 +1 @@ +author='aagrawal' diff --git a/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/vqaEval.py b/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/vqaEval.py new file mode 100644 index 0000000000000000000000000000000000000000..8a656044433b08c3b3a7610e0d4f701c9f3f752a --- /dev/null +++ b/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/vqaEval.py @@ -0,0 +1,192 @@ +# coding=utf-8 + +__author__='aagrawal' + +import re +# This code is based on 
the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: +# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py). +import sys + + +class VQAEval: + def __init__(self, vqa, vqaRes, n=2): + self.n = n + self.accuracy = {} + self.evalQA = {} + self.evalQuesType = {} + self.evalAnsType = {} + self.vqa = vqa + self.vqaRes = vqaRes + self.params = {'question_id': vqa.getQuesIds()} + self.contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't", \ + "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", \ + "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've", \ + "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", \ + "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", \ + "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", \ + "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", \ + "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", \ + "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", \ + "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll", \ + "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've", \ + "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", \ + "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've", \ + "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", \ + "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", \ + "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", \ + "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've", \ + "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", \ + "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", \ + "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", \ + "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", \ + "youll": "you'll", "youre": "you're", "youve": "you've"} + self.manualMap = { 'none': '0', + 'zero': '0', + 'one': '1', + 'two': '2', + 'three': '3', + 'four': '4', + 'five': '5', + 'six': '6', + 'seven': '7', + 'eight': '8', + 'nine': '9', + 'ten': '10' + } + self.articles = ['a', + 'an', + 'the' + ] + + + self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)") + self.commaStrip = re.compile("(\d)(\,)(\d)") + 
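+        # commaStrip matches a comma between two digits (e.g. "1,000"); when such a comma is
+        # present, processPunctuation() deletes punctuation outright instead of replacing it
+        # with a space. periodStrip removes periods that are not immediately followed by a
+        # digit, so decimal answers like "3.5" keep their decimal point.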
self.punct = [';', r"/", '[', ']', '"', '{', '}', + '(', ')', '=', '+', '\\', '_', '-', + '>', '<', '@', '`', ',', '?', '!'] + + + def evaluate(self, quesIds=None): + if quesIds == None: + quesIds = [quesId for quesId in self.params['question_id']] + gts = {} + res = {} + for quesId in quesIds: + gts[quesId] = self.vqa.qa[quesId] + res[quesId] = self.vqaRes.qa[quesId] + + # ================================================= + # Compute accuracy + # ================================================= + accQA = [] + accQuesType = {} + accAnsType = {} + # print "computing accuracy" + step = 0 + for quesId in quesIds: + for ansDic in gts[quesId]['answers']: + ansDic['answer'] = ansDic['answer'].replace('\n', ' ') + ansDic['answer'] = ansDic['answer'].replace('\t', ' ') + ansDic['answer'] = ansDic['answer'].strip() + resAns = res[quesId]['answer'] + resAns = resAns.replace('\n', ' ') + resAns = resAns.replace('\t', ' ') + resAns = resAns.strip() + gtAcc = [] + gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']] + + if len(set(gtAnswers)) > 1: + for ansDic in gts[quesId]['answers']: + ansDic['answer'] = self.processPunctuation(ansDic['answer']) + ansDic['answer'] = self.processDigitArticle(ansDic['answer']) + resAns = self.processPunctuation(resAns) + resAns = self.processDigitArticle(resAns) + + for gtAnsDatum in gts[quesId]['answers']: + otherGTAns = [item for item in gts[quesId]['answers'] if item!=gtAnsDatum] + matchingAns = [item for item in otherGTAns if item['answer'].lower()==resAns.lower()] + acc = min(1, float(len(matchingAns))/3) + gtAcc.append(acc) + quesType = gts[quesId]['question_type'] + ansType = gts[quesId]['answer_type'] + avgGTAcc = float(sum(gtAcc))/len(gtAcc) + accQA.append(avgGTAcc) + if quesType not in accQuesType: + accQuesType[quesType] = [] + accQuesType[quesType].append(avgGTAcc) + if ansType not in accAnsType: + accAnsType[ansType] = [] + accAnsType[ansType].append(avgGTAcc) + self.setEvalQA(quesId, avgGTAcc) + self.setEvalQuesType(quesId, quesType, avgGTAcc) + self.setEvalAnsType(quesId, ansType, avgGTAcc) + if step%100 == 0: + self.updateProgress(step/float(len(quesIds))) + step = step + 1 + + self.setAccuracy(accQA, accQuesType, accAnsType) + # print "Done computing accuracy" + + def processPunctuation(self, inText): + outText = inText + for p in self.punct: + if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) != None): + outText = outText.replace(p, '') + else: + outText = outText.replace(p, ' ') + outText = self.periodStrip.sub("", + outText, + re.UNICODE) + return outText + + def processDigitArticle(self, inText): + outText = [] + tempText = inText.lower().split() + for word in tempText: + word = self.manualMap.setdefault(word, word) + if word not in self.articles: + outText.append(word) + else: + pass + for wordId, word in enumerate(outText): + if word in self.contractions: + outText[wordId] = self.contractions[word] + outText = ' '.join(outText) + return outText + + def setAccuracy(self, accQA, accQuesType, accAnsType): + self.accuracy['overall'] = round(100*float(sum(accQA))/len(accQA), self.n) + self.accuracy['perQuestionType'] = {quesType: round(100*float(sum(accQuesType[quesType]))/len(accQuesType[quesType]), self.n) for quesType in accQuesType} + self.accuracy['perAnswerType'] = {ansType: round(100*float(sum(accAnsType[ansType]))/len(accAnsType[ansType]), self.n) for ansType in accAnsType} + + def setEvalQA(self, quesId, acc): + self.evalQA[quesId] = round(100*acc, self.n) + + def setEvalQuesType(self, quesId, 
quesType, acc): + if quesType not in self.evalQuesType: + self.evalQuesType[quesType] = {} + self.evalQuesType[quesType][quesId] = round(100*acc, self.n) + + def setEvalAnsType(self, quesId, ansType, acc): + if ansType not in self.evalAnsType: + self.evalAnsType[ansType] = {} + self.evalAnsType[ansType][quesId] = round(100*acc, self.n) + + def updateProgress(self, progress): + barLength = 20 + status = "" + if isinstance(progress, int): + progress = float(progress) + if not isinstance(progress, float): + progress = 0 + status = "error: progress var must be float\r\n" + if progress < 0: + progress = 0 + status = "Halt...\r\n" + if progress >= 1: + progress = 1 + status = "Done...\r\n" + block = int(round(barLength*progress)) + text = "\rFinshed Percent: [{0}] {1}% {2}".format( "#"*block + "-"*(barLength-block), int(progress*100), status) + sys.stdout.write(text) + sys.stdout.flush() diff --git a/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py b/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py new file mode 100644 index 0000000000000000000000000000000000000000..406b59642a7c2c208b87b0222a299e48a5831eb1 --- /dev/null +++ b/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py @@ -0,0 +1,73 @@ +# coding: utf-8 + +from vqaTools.vqa import VQA +import random +import skimage.io as io +import matplotlib.pyplot as plt +import os + +dataDir ='../../VQA' +versionType ='v2_' # this should be '' when using VQA v2.0 dataset +taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0 +dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0. +dataSubType ='train2014' +annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType) +quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType) +imgDir = '%s/Images/%s/%s/' %(dataDir, dataType, dataSubType) + +# initialize VQA api for QA annotations +vqa=VQA(annFile, quesFile) + +# load and display QA annotations for given question types +""" +All possible quesTypes for abstract and mscoco has been provided in respective text files in ../QuestionTypes/ folder. +""" +annIds = vqa.getQuesIds(quesTypes='how many'); +anns = vqa.loadQA(annIds) +randomAnn = random.choice(anns) +vqa.showQA([randomAnn]) +imgId = randomAnn['image_id'] +imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg' +if os.path.isfile(imgDir + imgFilename): + I = io.imread(imgDir + imgFilename) + plt.imshow(I) + plt.axis('off') + plt.show() + +# load and display QA annotations for given answer types +""" +ansTypes can be one of the following +yes/no +number +other +""" +annIds = vqa.getQuesIds(ansTypes='yes/no'); +anns = vqa.loadQA(annIds) +randomAnn = random.choice(anns) +vqa.showQA([randomAnn]) +imgId = randomAnn['image_id'] +imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg' +if os.path.isfile(imgDir + imgFilename): + I = io.imread(imgDir + imgFilename) + plt.imshow(I) + plt.axis('off') + plt.show() + +# load and display QA annotations for given images +""" +Usage: vqa.getImgIds(quesIds=[], quesTypes=[], ansTypes=[]) +Above method can be used to retrieve imageIds for given question Ids or given question types or given answer types. 
+""" +ids = vqa.getImgIds() +annIds = vqa.getQuesIds(imgIds=random.sample(ids,5)); +anns = vqa.loadQA(annIds) +randomAnn = random.choice(anns) +vqa.showQA([randomAnn]) +imgId = randomAnn['image_id'] +imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg' +if os.path.isfile(imgDir + imgFilename): + I = io.imread(imgDir + imgFilename) + plt.imshow(I) + plt.axis('off') + plt.show() + diff --git a/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py b/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..072d8d90cd261c19c62fa4624ca22471fe72abfd --- /dev/null +++ b/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py @@ -0,0 +1 @@ +__author__ = 'aagrawal' diff --git a/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py b/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..4f769619fc64ce150d1a462d91ea29282f08104a --- /dev/null +++ b/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py @@ -0,0 +1,179 @@ +__author__ = 'aagrawal' +__version__ = '0.9' + +# Interface for accessing the VQA dataset. + +# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: +# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py). + +# The following functions are defined: +# VQA - VQA class that loads VQA annotation file and prepares data structures. +# getQuesIds - Get question ids that satisfy given filter conditions. +# getImgIds - Get image ids that satisfy given filter conditions. +# loadQA - Load questions and answers with the specified question ids. +# showQA - Display the specified questions and answers. +# loadRes - Load result file and create result object. + +# Help on each function can be accessed by: "help(COCO.function)" + +import json +import datetime +import copy + + +class VQA: + def __init__(self, annotation_file=None, question_file=None): + """ + Constructor of VQA helper class for reading and visualizing questions and answers. + :param annotation_file (str): location of VQA annotation file + :return: + """ + # load dataset + self.dataset = {} + self.questions = {} + self.qa = {} + self.qqa = {} + self.imgToQA = {} + if not annotation_file == None and not question_file == None: + # print 'loading VQA annotations and questions into memory...' + time_t = datetime.datetime.utcnow() + dataset = json.load(open(annotation_file, 'r')) + questions = json.load(open(question_file, 'r')) + # print datetime.datetime.utcnow() - time_t + self.dataset = dataset + self.questions = questions + self.createIndex() + + def createIndex(self): + imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']} + qa = {ann['question_id']: [] for ann in self.dataset['annotations']} + qqa = {ann['question_id']: [] for ann in self.dataset['annotations']} + for ann in self.dataset['annotations']: + imgToQA[ann['image_id']] += [ann] + qa[ann['question_id']] = ann + for ques in self.questions['questions']: + qqa[ques['question_id']] = ques + # print 'index created!' + + # create class members + self.qa = qa + self.qqa = qqa + self.imgToQA = imgToQA + + def info(self): + """ + Print information about the VQA annotation file. 
+ :return: + """ + + # for key, value in self.datset['info'].items(): + # print '%s: %s'%(key, value) + + def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]): + """ + Get question ids that satisfy given filter conditions. default skips that filter + :param imgIds (int array) : get question ids for given imgs + quesTypes (str array) : get question ids for given question types + ansTypes (str array) : get question ids for given answer types + :return: ids (int array) : integer array of question ids + """ + imgIds = imgIds if type(imgIds) == list else [imgIds] + quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] + ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] + + if len(imgIds) == len(quesTypes) == len(ansTypes) == 0: + anns = self.dataset['annotations'] + else: + if not len(imgIds) == 0: + anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA], []) + else: + anns = self.dataset['annotations'] + anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes] + anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes] + ids = [ann['question_id'] for ann in anns] + return ids + + def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]): + """ + Get image ids that satisfy given filter conditions. default skips that filter + :param quesIds (int array) : get image ids for given question ids + quesTypes (str array) : get image ids for given question types + ansTypes (str array) : get image ids for given answer types + :return: ids (int array) : integer array of image ids + """ + quesIds = quesIds if type(quesIds) == list else [quesIds] + quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] + ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] + + if len(quesIds) == len(quesTypes) == len(ansTypes) == 0: + anns = self.dataset['annotations'] + else: + if not len(quesIds) == 0: + anns = sum([self.qa[quesId] for quesId in quesIds if quesId in self.qa], []) + else: + anns = self.dataset['annotations'] + anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes] + anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes] + ids = [ann['image_id'] for ann in anns] + return ids + + def loadQA(self, ids=[]): + """ + Load questions and answers with the specified question ids. + :param ids (int array) : integer ids specifying question ids + :return: qa (object array) : loaded qa objects + """ + if type(ids) == list: + return [self.qa[id] for id in ids] + elif type(ids) == int: + return [self.qa[ids]] + + def showQA(self, anns): + """ + Display the specified annotations. + :param anns (array of object): annotations to display + :return: None + """ + if len(anns) == 0: + return 0 + for ann in anns: + quesId = ann['question_id'] + print("Question: %s" % (self.qqa[quesId]['question'])) + for ans in ann['answers']: + print("Answer %d: %s" % (ans['answer_id'], ans['answer'])) + + def loadRes(self, resFile, quesFile): + """ + Load result file and return a result object. 
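+        The result file must be a JSON list containing exactly one entry per question id
+        in the annotation file; this is asserted below before the results are indexed.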
+ :param resFile (str) : file name of result file + :return: res (obj) : result api object + """ + res = VQA() + res.questions = json.load(open(quesFile)) + res.dataset['info'] = copy.deepcopy(self.questions['info']) + res.dataset['task_type'] = copy.deepcopy(self.questions['task_type']) + res.dataset['data_type'] = copy.deepcopy(self.questions['data_type']) + res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype']) + res.dataset['license'] = copy.deepcopy(self.questions['license']) + + # print 'Loading and preparing results... ' + time_t = datetime.datetime.utcnow() + anns = json.load(open(resFile)) + assert type(anns) == list, 'results is not an array of objects' + annsQuesIds = [ann['question_id'] for ann in anns] + assert set(annsQuesIds) == set(self.getQuesIds()), \ + 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.' + for ann in anns: + quesId = ann['question_id'] + if res.dataset['task_type'] == 'Multiple Choice': + assert ann['answer'] in self.qqa[quesId][ + 'multiple_choices'], 'predicted answer is not one of the multiple choices' + qaAnn = self.qa[quesId] + ann['image_id'] = qaAnn['image_id'] + ann['question_type'] = qaAnn['question_type'] + ann['answer_type'] = qaAnn['answer_type'] + # print 'DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds()) + + res.dataset['annotations'] = anns + res.createIndex() + return res diff --git a/minigpt4/common/vqa_tools/VQA/README.md b/minigpt4/common/vqa_tools/VQA/README.md new file mode 100644 index 0000000000000000000000000000000000000000..439d59d4d7c761423ab7016ab8768105b2df6c35 --- /dev/null +++ b/minigpt4/common/vqa_tools/VQA/README.md @@ -0,0 +1,80 @@ +Python API and Evaluation Code for v2.0 and v1.0 releases of the VQA dataset. 
+=================== +## VQA v2.0 release ## +This release consists of +- Real + - 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from [MS COCO website] (http://mscoco.org/dataset/#download)) + - 443,757 questions for training, 214,354 questions for validation and 447,793 questions for testing + - 4,437,570 answers for training and 2,143,540 answers for validation (10 per question) + +There is only one type of task +- Open-ended task + +## VQA v1.0 release ## +This release consists of +- Real + - 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from [MS COCO website] (http://mscoco.org/dataset/#download)) + - 248,349 questions for training, 121,512 questions for validation and 244,302 questions for testing (3 per image) + - 2,483,490 answers for training and 1,215,120 answers for validation (10 per question) +- Abstract + - 20,000 training images, 10,000 validation images and 20,000 MS COCO testing images + - 60,000 questions for training, 30,000 questions for validation and 60,000 questions for testing (3 per image) + - 600,000 answers for training and 300,000 answers for validation (10 per question) + +There are two types of tasks +- Open-ended task +- Multiple-choice task (18 choices per question) + +## Requirements ## +- python 2.7 +- scikit-image (visit [this page](http://scikit-image.org/docs/dev/install.html) for installation) +- matplotlib (visit [this page](http://matplotlib.org/users/installing.html) for installation) + +## Files ## +./Questions +- For v2.0, download the question files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder. +- For v1.0, both real and abstract, question files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html). +- Question files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below + - [training question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Train_mscoco.zip) + - [validation question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Val_mscoco.zip) +- Question files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Questions_Train_mscoco.zip). + +./Annotations +- For v2.0, download the annotations files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder. +- For v1.0, for both real and abstract, annotation files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html). +- Annotation files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below + - [training annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Train_mscoco.zip) + - [validation annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Val_mscoco.zip) +- Annotation files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Annotations_Train_mscoco.zip). + +./Images +- For real, create a directory with name mscoco inside this directory. 
For each of train, val and test, create directories with names train2014, val2014 and test2015 respectively inside mscoco directory, download respective images from [MS COCO website](http://mscoco.org/dataset/#download) and place them in respective folders. +- For abstract, create a directory with name abstract_v002 inside this directory. For each of train, val and test, create directories with names train2015, val2015 and test2015 respectively inside abstract_v002 directory, download respective images from [VQA download page](http://www.visualqa.org/download.html) and place them in respective folders. + +./PythonHelperTools +- This directory contains the Python API to read and visualize the VQA dataset +- vqaDemo.py (demo script) +- vqaTools (API to read and visualize data) + +./PythonEvaluationTools +- This directory contains the Python evaluation code +- vqaEvalDemo.py (evaluation demo script) +- vqaEvaluation (evaluation code) + +./Results +- OpenEnded_mscoco_train2014_fake_results.json (an example of a fake results file for v1.0 to run the demo) +- Visit [VQA evaluation page] (http://visualqa.org/evaluation) for more details. + +./QuestionTypes +- This directory contains the following lists of question types for both real and abstract questions (question types are unchanged from v1.0 to v2.0). In a list, if there are question types of length n+k and length n with the same first n words, then the question type of length n does not include questions that belong to the question type of length n+k. +- mscoco_question_types.txt +- abstract_v002_question_types.txt + +## References ## +- [VQA: Visual Question Answering](http://visualqa.org/) +- [Microsoft COCO](http://mscoco.org/) + +## Developers ## +- Aishwarya Agrawal (Virginia Tech) +- Code for API is based on [MSCOCO API code](https://github.com/pdollar/coco). +- The format of the code for evaluation is based on [MSCOCO evaluation code](https://github.com/tylin/coco-caption). diff --git a/minigpt4/common/vqa_tools/__init__.py b/minigpt4/common/vqa_tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9b98da85428159ad0dcfab7685c080848ecf8c7b --- /dev/null +++ b/minigpt4/common/vqa_tools/__init__.py @@ -0,0 +1,8 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +__author__ = "aagrawal" diff --git a/minigpt4/common/vqa_tools/aokvqa/LICENSE b/minigpt4/common/vqa_tools/aokvqa/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..663d6758473aa081e00a05f6cccef39487dd49ba --- /dev/null +++ b/minigpt4/common/vqa_tools/aokvqa/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2022 Allen Institute for Artificial Intelligence + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/minigpt4/common/vqa_tools/aokvqa/README.md b/minigpt4/common/vqa_tools/aokvqa/README.md new file mode 100644 index 0000000000000000000000000000000000000000..21caefaa477e812181412127c333b38062220a59 --- /dev/null +++ b/minigpt4/common/vqa_tools/aokvqa/README.md @@ -0,0 +1,207 @@ +# A-OKVQA + +Official repository for **A-OKVQA: A Benchmark for Visual Question Answering using World Knowledge**. + +Links: [[Paper]](https://arxiv.org/abs/2206.01718) [[Website]](http://a-okvqa.allenai.org) [[Leaderboard]](https://leaderboard.allenai.org/a-okvqa/submissions/public) + +### Abstract + +The Visual Question Answering (VQA) task aspires to provide a meaningful testbed for the development of AI models that can jointly reason over visual and natural language inputs. Despite a proliferation of VQA datasets, this goal is hindered by a set of common limitations. These include a reliance on relatively simplistic questions that are repetitive in both concepts and linguistic structure, little world knowledge needed outside of the paired image, and limited reasoning required to arrive at the correct answer. We introduce A-OKVQA, a crowdsourced dataset composed of a diverse set of about 25K questions requiring a broad base of commonsense and world knowledge to answer. In contrast to the existing knowledge-based VQA datasets, the questions generally cannot be answered by simply querying a knowledge base, and instead require some form of commonsense reasoning about the scene depicted in the image. We demonstrate the potential of this new dataset through a detailed analysis of its contents and baseline performance measurements over a variety of state-of-the-art vision–language models. + +![dataset_web](https://user-images.githubusercontent.com/28768645/170799740-f0d9ea60-6aff-4322-98d5-cae8e05983f4.svg) + +
+ +#### Table of Contents + +- [Getting started](#getting-started) + * [Downloading the dataset](#downloading-the-dataset) +- [Evaluation & Leaderboard](#evaluation) +- [Codebase](#codebase) + * [Preparing data](#preparing-data) + * [Models and Predictions](#models-and-predictions) + +
+ +## Getting started + +```bash +git clone --single-branch --recurse-submodules https://github.com/allenai/aokvqa.git + +cd aokvqa +export PYTHONPATH=. + +conda env create --name aokvqa +conda activate aokvqa +``` + +### Downloading the dataset + +```bash +export AOKVQA_DIR=./datasets/aokvqa/ +mkdir -p ${AOKVQA_DIR} + +curl -fsSL https://prior-datasets.s3.us-east-2.amazonaws.com/aokvqa/aokvqa_v1p0.tar.gz | tar xvz -C ${AOKVQA_DIR} +``` + +
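+After extraction, `${AOKVQA_DIR}` should contain one annotation file per split, named `aokvqa_v1p0_{split}.json` (the naming convention that `load_aokvqa.py` expects). A quick, purely illustrative sanity check; the directory path below just mirrors the value exported above:
+
+```python
+import os
+
+aokvqa_dir = './datasets/aokvqa/'  # same location as ${AOKVQA_DIR} above
+for split in ['train', 'val', 'test']:
+    path = os.path.join(aokvqa_dir, f'aokvqa_v1p0_{split}.json')
+    print(path, 'found' if os.path.isfile(path) else 'MISSING')
+```
+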
+#### Downloading COCO 2017
+
+```bash
+export COCO_DIR=./datasets/coco/
+mkdir -p ${COCO_DIR}
+
+for split in train val test; do
+    wget "http://images.cocodataset.org/zips/${split}2017.zip"
+    unzip "${split}2017.zip" -d ${COCO_DIR}; rm "${split}2017.zip"
+done
+
+wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
+unzip annotations_trainval2017.zip -d ${COCO_DIR}; rm annotations_trainval2017.zip
+```
+
+ +Loading our dataset is easy! Just grab our [load_aokvqa.py](https://github.com/allenai/aokvqa/blob/main/load_aokvqa.py) file and refer to the following code. + +```python +import os +aokvqa_dir = os.getenv('AOKVQA_DIR') + +from load_aokvqa import load_aokvqa, get_coco_path +train_dataset = load_aokvqa(aokvqa_dir, 'train') # also 'val' or 'test' +``` + +
+#### Example dataset entry
+
+```python
+dataset_example = train_dataset[0]
+
+print(dataset_example['question_id'])
+# 22MexNkBPpdZGX6sxbxVBH
+
+coco_dir = os.getenv('COCO_DIR')
+image_path = get_coco_path('train', dataset_example['image_id'], coco_dir)
+print(image_path)
+# ./datasets/coco/train2017/000000299207.jpg
+
+print(dataset_example['question'])
+print(dataset_example['choices'])
+# What is the man by the bags awaiting?
+# ['skateboarder', 'train', 'delivery', 'cab']
+
+correct_choice = dataset_example['choices'][ dataset_example['correct_choice_idx'] ]
+# Correct: cab
+
+print(dataset_example['rationales'][0])
+# A train would not be on the street, he would not have luggage waiting for a delivery, and the skateboarder is there and not paying attention to him so a cab is the only possible answer.
+```
+
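+Entries also carry the fields used by the direct-answer evaluation described in the next section: `direct_answers` (a list of free-form annotator answers) and a `difficult_direct_answer` flag. Continuing the sketch above:
+
+```python
+print(dataset_example['direct_answers'])
+# list of free-form answers, compared case-insensitively at evaluation time
+
+print(dataset_example['difficult_direct_answer'])
+# questions flagged True are excluded from the direct-answer metric
+```
+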
+
+## Evaluation
+
+Please prepare `predictions_{split}.json` files (for `split: {val,test}`) in the format below. You may omit either the `multiple_choice` or the `direct_answer` field if you only want to evaluate one setting.
+
+```python
+{
+    '<question_id>' : {
+        'multiple_choice' : '<prediction>',
+        'direct_answer' : '<prediction>'
+    }
+}
+```
+
+You can run evaluation on the validation set as follows.
+
+```bash
+python evaluation/eval_predictions.py --aokvqa-dir ${AOKVQA_DIR} --split val --preds ./predictions_val.json
+```
+
+### Leaderboard
+
+You may submit `predictions_test.json` to the [leaderboard](https://leaderboard.allenai.org/a-okvqa/submissions/get-started).
+
+## Codebase
+
+We provide all code and pretrained models necessary to replicate our experiments for Large-Scale Pretrained Models (sec. 5.2) and Rationale Generation (sec. 5.3).
+
+### Preparing data
+
+```bash
+export FEATURES_DIR=./features/
+mkdir -p ${FEATURES_DIR}
+```
+
+You can compute CLIP features for our vocabulary and dataset. These are most commonly used by our other experiments.
+
+```bash
+python data_scripts/encode_vocab_clip.py --vocab ${AOKVQA_DIR}/large_vocab_train.csv --model-type ViT-B/32 --out ${FEATURES_DIR}/clip-ViT-B-32_large_vocab.pt
+
+for split in train val test; do
+    python data_scripts/extract_clip_features.py --aokvqa-dir ${AOKVQA_DIR} --coco-dir ${COCO_DIR} --split ${split} --model-type ViT-B/32 --out ${FEATURES_DIR}/clip-ViT-B-32_${split}.pt
+done
+```
+
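+Each resulting `.pt` file is a plain dictionary keyed by question id; for the CLIP features, every entry holds a `question` and an `image` tensor, which is the structure later loaded by e.g. `transfer_experiments/predict.py`. A short illustrative snippet (paths assume the commands above):
+
+```python
+import torch
+
+feats = torch.load('./features/clip-ViT-B-32_val.pt')  # {question_id: {'question': tensor, 'image': tensor}}
+qid = next(iter(feats))
+print(feats[qid]['question'].shape, feats[qid]['image'].shape)
+```
+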
+#### For training ClipCap with a transformer mapping network
+
+If you want to train our ClipCap models with the transformer mapping network (instead of an MLP, like we do), you'll also need to run `extract_clip_features.py` with `--model-type RN50x4`.
+
+ +
+#### For ResNet and BERT input features
+
+Our ResNet and BERT classification experiments require these respective features instead of CLIP. To generate these, please run the following commands:
+
+```bash
+# ResNet
+for split in train val test; do
+    python data_scripts/extract_resnet_features.py --aokvqa-dir ${AOKVQA_DIR} --coco-dir ${COCO_DIR} --split ${split} --out ${FEATURES_DIR}/resnet_${split}.pt
+done
+
+# BERT
+for split in train val test; do
+    python data_scripts/extract_bert_features.py --aokvqa-dir ${AOKVQA_DIR} --split ${split} --out ${FEATURES_DIR}/bert_${split}.pt
+done
+```
+
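+These `.pt` files are again dictionaries keyed by question id; each ResNet entry holds only an `image` tensor and each BERT entry only a `question` tensor (see `data_scripts/extract_resnet_features.py` and `data_scripts/extract_bert_features.py`). An illustrative check, assuming the output paths above:
+
+```python
+import torch
+
+resnet_feats = torch.load('./features/resnet_val.pt')  # {question_id: {'image': tensor}}
+bert_feats = torch.load('./features/bert_val.pt')      # {question_id: {'question': tensor}}
+qid = next(iter(resnet_feats))
+print(resnet_feats[qid]['image'].shape, bert_feats[qid]['question'].shape)
+```
+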
+ +### Models and Predictions + +```bash +export LOG_DIR=./logs/ +export PREDS_DIR=./predictions/ +export PT_MODEL_DIR=./pretrained_models/ +mkdir -p ${LOG_DIR} ${PREDS_DIR} ${PT_MODEL_DIR} +``` + +
+#### Download our pretrained model weights
+
+```bash
+# Checkpoints for transfer learning experiments
+curl -fsSL https://prior-model-weights.s3.us-east-2.amazonaws.com/aokvqa/transfer_exp_checkpoints.tar.gz | tar xvz -C ${PT_MODEL_DIR}/aokvqa_models
+
+# Checkpoints for ClipCap models (generating answers and rationales)
+curl -fsSL https://prior-model-weights.s3.us-east-2.amazonaws.com/aokvqa/clipcap_checkpoints.tar.gz | tar xvz -C ${PT_MODEL_DIR}/aokvqa_models
+```
+
+ +We have included instructions for replicating each of our experiments (see README.md files below). + +All Python scripts should be run from the root of this repository. Please be sure to first run the installation and data preparation as directed above. + +- [Heuristics](./heuristics/README.md) +- [Transfer Learning Experiments](./transfer_experiments/README.md) +- [Querying GPT-3](./gpt3/README.md) +- [ClipCap](https://github.com/allenai/aokvqa/blob/ClipCap/README.md) +- [Generating Captions & Rationales](https://github.com/allenai/aokvqa/blob/ClipCap/README.md) + +For each experiment, we follow this prediction file naming scheme: `{model-name}_{split}-{setting}.json` (e.g. `random-weighted_val-mc.json` or `random-weighted_test-da.json`). As examples in these Readme files, we produce predictions on the validation set. + +We unify predictions for each split before evaluation. (You can omit one of `--mc` or `--da` prediction file if you only want to evaluate one setting.) + +```bash +python evaluation/prepare_predictions.py --aokvqa-dir ${AOKVQA_DIR} --split val --mc ./predictions_val-mc.json --da ./predictions_val-da.json --out ./predictions_val.json +# repeat for test split ... +``` diff --git a/minigpt4/common/vqa_tools/aokvqa/data_scripts/build_vocab.py b/minigpt4/common/vqa_tools/aokvqa/data_scripts/build_vocab.py new file mode 100644 index 0000000000000000000000000000000000000000..2c446867c75f102dce322767f8acba0e9ac4d9eb --- /dev/null +++ b/minigpt4/common/vqa_tools/aokvqa/data_scripts/build_vocab.py @@ -0,0 +1,45 @@ +import os +import argparse +from collections import Counter +import pathlib + +from load_aokvqa import load_aokvqa + + +parser = argparse.ArgumentParser() +parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') +parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file') +args = parser.parse_args() + + +# Build vocab from train set: correct choices + (direct answers appearing in >= 3 ) + +train_set = load_aokvqa(args.aokvqa_dir, 'train') + +vocab = [] +all_choices = Counter() +direct_answers = Counter() + +for i in train_set: + vocab.append( i['choices'][i['correct_choice_idx']] ) + all_choices.update(i['choices']) + direct_answers.update(set(i['direct_answers'])) +vocab += [k for k,v in all_choices.items() if v >= 3] +vocab += [k for k,v in direct_answers.items() if v >= 3] + +vocab = sorted(set(vocab)) +print(f"Vocab size: {len(vocab)}") + +# Save vocabulary Output + +with open(args.output_file, 'w') as f: + for v in vocab: + print(v, file=f) + +## Check validation set coverage + +val_set = load_aokvqa(args.aokvqa_dir, 'val') + +val_acc = [v['choices'][v['correct_choice_idx']] in vocab for v in val_set] +val_acc = sum(val_acc) / len(val_acc) * 100 +print(f"Val set coverage: {val_acc:.2f}" ) diff --git a/minigpt4/common/vqa_tools/aokvqa/data_scripts/encode_vocab_clip.py b/minigpt4/common/vqa_tools/aokvqa/data_scripts/encode_vocab_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..1dce7604d02edca32bf8a0b36e2966bdadb1527a --- /dev/null +++ b/minigpt4/common/vqa_tools/aokvqa/data_scripts/encode_vocab_clip.py @@ -0,0 +1,26 @@ +import json +from tqdm import tqdm +import argparse +import pathlib + +import torch +import clip + +parser = argparse.ArgumentParser() +parser.add_argument('--vocab', type=pathlib.Path, required=True, dest='vocab_file') +parser.add_argument('--model-type', type=str, choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 
'ViT-L/14@336px'], required=True, dest='model_type') +parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file') +args = parser.parse_args() + +assert args.output_file.suffix == '.pt' + +device = "cuda" if torch.cuda.is_available() else "cpu" +model, preprocess = clip.load(args.model_type, device=device) + +with torch.no_grad(): + a = open(args.vocab_file).read().splitlines() + mc_text = clip.tokenize(a).to(device) + mc_text_features = torch.stack([model.encode_text(mct.unsqueeze(0)).cpu() for mct in tqdm(mc_text)], dim=1)[0] + mc_text_features = mc_text_features.float() + model_name = args.model_type.replace('/', '-').replace('@', '-') + torch.save(mc_text_features, args.output_file) diff --git a/minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_bert_features.py b/minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_bert_features.py new file mode 100644 index 0000000000000000000000000000000000000000..60cd40f501f591bd1939d7c85ec2d345b6d8e29f --- /dev/null +++ b/minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_bert_features.py @@ -0,0 +1,50 @@ +import os +import argparse +import pathlib +from tqdm import tqdm + +import torch +from transformers import AutoTokenizer, AutoModel + +from load_aokvqa import load_aokvqa + + +parser = argparse.ArgumentParser() +parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') +parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True) +parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file') +args = parser.parse_args() + +assert args.output_file.suffix == '.pt' + +## Load dataset + +dataset = load_aokvqa(args.aokvqa_dir, args.split) + +## Load model + +tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens') +model = AutoModel.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens') +device = "cuda" if torch.cuda.is_available() else "cpu" +model = model.to(device) +model.eval() + +def mean_pooling(model_output, attention_mask): + token_embeddings = model_output[0] # First element of model_output contains all token embeddings + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) + +## Encoding loop + +with torch.no_grad(): + embeddings = {} + + for d in tqdm(dataset): + encoded_input = tokenizer([d['question']], padding=True, return_tensors='pt') + encoded_input = {k:v.to(device) for k,v in encoded_input.items()} + e = mean_pooling(model(**encoded_input), encoded_input['attention_mask']) + embeddings[d['question_id']] = { + 'question' : e[0].cpu() + } + + torch.save(embeddings, args.output_file) diff --git a/minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_clip_features.py b/minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_clip_features.py new file mode 100644 index 0000000000000000000000000000000000000000..20d0455e76fb7285c5ef838cdfde1a0000bdcb63 --- /dev/null +++ b/minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_clip_features.py @@ -0,0 +1,51 @@ +import os +from PIL import Image +from tqdm import tqdm +import argparse +import pathlib + +import torch +import clip + +from load_aokvqa import load_aokvqa, get_coco_path + + +parser = argparse.ArgumentParser() +parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') +parser.add_argument('--coco-dir', type=pathlib.Path, required=True, 
dest='coco_dir') +parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True) +parser.add_argument('--model-type', type=str, choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'], required=True, dest='model_type') +parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file') +args = parser.parse_args() + +assert args.output_file.suffix == '.pt' + +## Load dataset + +dataset = load_aokvqa(args.aokvqa_dir, args.split) + +## Load model + +device = "cuda" if torch.cuda.is_available() else "cpu" +model, preprocess = clip.load(args.model_type, device=device) + +## Encoding loop + +with torch.no_grad(): + embeddings = {} + + for d in tqdm(dataset): + q = d["question"] + q_text = clip.tokenize(q).to(device) + q_text_features = model.encode_text(q_text) + + img = Image.open(get_coco_path(args.split, d['image_id'], args.coco_dir)) + img = preprocess(img).unsqueeze(0).to(device) + image_features = model.encode_image(img) + + embeddings[d['question_id']] = { + 'question' : q_text_features[0].float().cpu(), + 'image' : image_features[0].float().cpu(), + } + + torch.save(embeddings, args.output_file) diff --git a/minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_resnet_features.py b/minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_resnet_features.py new file mode 100644 index 0000000000000000000000000000000000000000..0d7277bfd12801545f1b052d9120f09d7ae0cdb9 --- /dev/null +++ b/minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_resnet_features.py @@ -0,0 +1,62 @@ +import os +import argparse +import pathlib +from tqdm import tqdm +from PIL import Image + +import torch +import torch.nn as nn +from torchvision import models +from torchvision import transforms as T + +from load_aokvqa import load_aokvqa, get_coco_path + + +parser = argparse.ArgumentParser() +parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') +parser.add_argument('--coco-dir', type=pathlib.Path, required=True, dest='coco_dir') +parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True) +parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file') +args = parser.parse_args() + +assert args.output_file.suffix == '.pt' + +## Load dataset + +dataset = load_aokvqa(args.aokvqa_dir, args.split) + +## Load model + +resnet_preprocess = T.Compose([ + T.Resize(size=224, interpolation=T.InterpolationMode.BICUBIC), + T.CenterCrop(size=(224, 224)), + T.ToTensor(), + T.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ) +]) + +device = "cuda" if torch.cuda.is_available() else "cpu" + +resnet_model = models.resnet50(pretrained=True) +resnet_model = torch.nn.Sequential( + *list(resnet_model.children())[:-1], + nn.Flatten() +) # strip classification layer +resnet_model = resnet_model.to(device) + +## Encoding loop + +with torch.no_grad(): + embeddings = {} + + for d in tqdm(dataset): + img = Image.open(get_coco_path(args.split, d['image_id'], args.coco_dir)).convert('RGB') + resnet_input = resnet_preprocess(img).unsqueeze(0).to(device) + resnet_features = resnet_model(resnet_input) + embeddings[d['question_id']] = { + 'image' : resnet_features[0].cpu() + } + + torch.save(embeddings, args.output_file) diff --git a/minigpt4/common/vqa_tools/aokvqa/environment.yml b/minigpt4/common/vqa_tools/aokvqa/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..58284ec46731e1bc68856c13b9f6101d34c03439 --- 
/dev/null +++ b/minigpt4/common/vqa_tools/aokvqa/environment.yml @@ -0,0 +1,36 @@ +name: aokvqa +channels: + - pytorch + - nvidia + - huggingface + - conda-forge + - defaults +dependencies: + - python=3.7 + - cudatoolkit=11.3 + - numpy=1.21.6 + - pytorch=1.11.0 + - torchvision=0.12.0 + - pytorch-lightning=1.6.3 + - torchmetrics=0.8.1 + - gdown=4.4.0 + - pip=22.0.4 + - pip: + - argparse==1.4.0 + - Pillow==9.0.1 + - tensorboard==2.9.0 + - ftfy==6.1.1 + - regex==2022.3.15 + - tqdm==4.64.0 + - clip @ git+https://github.com/openai/CLIP.git@b46f5ac7587d2e1862f8b7b1573179d80dcdd620 + - openai==0.18.1 + - nltk==3.7 + - sacrebleu==2.0.0 + - sacremoses==0.0.53 + - sentence-transformers==2.2.0 + - datasets==2.1.0 + - tokenizers==0.10.3 + - transformers==4.10.3 + +# Next: resolve conflict between sentence-transfomers and pytorch-lightning +# pip uninstall sentencepiece diff --git a/minigpt4/common/vqa_tools/aokvqa/evaluation/eval_predictions.py b/minigpt4/common/vqa_tools/aokvqa/evaluation/eval_predictions.py new file mode 100644 index 0000000000000000000000000000000000000000..a7b5dbe6f66849ff503177ab7e6c38ae20f5a34b --- /dev/null +++ b/minigpt4/common/vqa_tools/aokvqa/evaluation/eval_predictions.py @@ -0,0 +1,97 @@ +import argparse +import pathlib +import json +import glob + +from load_aokvqa import load_aokvqa + + +def eval_aokvqa(dataset, preds, multiple_choice=False, strict=True): + + if isinstance(dataset, list): + dataset = { dataset[i]['question_id'] : dataset[i] for i in range(len(dataset)) } + + if multiple_choice is False: + dataset = {k:v for k,v in dataset.items() if v['difficult_direct_answer'] is False} + + if strict: + dataset_qids = set(dataset.keys()) + preds_qids = set(preds.keys()) + assert dataset_qids.issubset(preds_qids) + + # dataset = q_id (str) : dataset element (dict) + # preds = q_id (str) : prediction (str) + + acc = [] + + for q in dataset.keys(): + if q not in preds.keys(): + acc.append(0.0) + continue + + pred = preds[q] + choices = dataset[q]['choices'] + direct_answers = dataset[q]['direct_answers'] + + ## Multiple Choice setting + if multiple_choice: + if strict: + assert pred in choices, 'Prediction must be a valid choice' + correct_choice_idx = dataset[q]['correct_choice_idx'] + acc.append( float(pred == choices[correct_choice_idx]) ) + ## Direct Answer setting + else: + num_match = sum([pred.lower() == da.lower() for da in direct_answers]) + vqa_acc = min(1.0, num_match / 3.0) + acc.append(vqa_acc) + + acc = sum(acc) / len(acc) * 100 + + return acc + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') + parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True) + parser.add_argument('--preds', type=str, required=True, dest='prediction_files') + args = parser.parse_args() + + dataset = load_aokvqa(args.aokvqa_dir, args.split) + + for prediction_file in glob.glob(args.prediction_files): + predictions = json.load(open(prediction_file, 'r')) + + # Multiple choice + + mc_predictions = {} + + for q in predictions.keys(): + if 'multiple_choice' in predictions[q].keys(): + mc_predictions[q] = predictions[q]['multiple_choice'] + + if mc_predictions != {}: + mc_acc = eval_aokvqa( + dataset, + mc_predictions, + multiple_choice=True, + strict=False + ) + print(prediction_file, 'MC', mc_acc) + + # Direct Answer + + da_predictions = {} + + for q in predictions.keys(): + if 'direct_answer' in predictions[q].keys(): + da_predictions[q] = 
predictions[q]['direct_answer'] + + if da_predictions != {}: + da_acc = eval_aokvqa( + dataset, + da_predictions, + multiple_choice=False, + strict=False + ) + print(prediction_file, 'DA', da_acc) diff --git a/minigpt4/common/vqa_tools/aokvqa/evaluation/load_aokvqa.py b/minigpt4/common/vqa_tools/aokvqa/evaluation/load_aokvqa.py new file mode 100644 index 0000000000000000000000000000000000000000..3e3dd49c668e56a7e306e1f15d7f73ad32fa31ac --- /dev/null +++ b/minigpt4/common/vqa_tools/aokvqa/evaluation/load_aokvqa.py @@ -0,0 +1,13 @@ +import os +import json + + +def load_aokvqa(aokvqa_dir, split, version='v1p0'): + assert split in ['train', 'val', 'test', 'test_w_ans'] + dataset = json.load(open( + os.path.join(aokvqa_dir, f"aokvqa_{version}_{split}.json") + )) + return dataset + +def get_coco_path(split, image_id, coco_dir): + return os.path.join(coco_dir, f"{split}2017", f"{image_id:012}.jpg") diff --git a/minigpt4/common/vqa_tools/aokvqa/evaluation/prepare_predictions.py b/minigpt4/common/vqa_tools/aokvqa/evaluation/prepare_predictions.py new file mode 100644 index 0000000000000000000000000000000000000000..202f00c0f14904483146187116c7ac78c75c1a6c --- /dev/null +++ b/minigpt4/common/vqa_tools/aokvqa/evaluation/prepare_predictions.py @@ -0,0 +1,31 @@ +import argparse +import pathlib +import json + +from load_aokvqa import load_aokvqa + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') + parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True) + parser.add_argument('--mc', type=argparse.FileType('r'), dest='mc_pred_file') + parser.add_argument('--da', type=argparse.FileType('r'), dest='da_pred_file') + parser.add_argument('--out', type=argparse.FileType('w'), dest='output_file') + args = parser.parse_args() + assert args.mc_pred_file or args.da_pred_file + + dataset = load_aokvqa(args.aokvqa_dir, args.split) + mc_preds = json.load(args.mc_pred_file) if args.mc_pred_file else None + da_preds = json.load(args.da_pred_file) if args.da_pred_file else None + predictions = {} + + for d in dataset: + q = d['question_id'] + predictions[q] = {} + if mc_preds and q in mc_preds.keys(): + predictions[q]['multiple_choice'] = mc_preds[q] + if da_preds and q in da_preds.keys(): + predictions[q]['direct_answer'] = da_preds[q] + + json.dump(predictions, args.output_file) diff --git a/minigpt4/common/vqa_tools/aokvqa/evaluation/remap_predictions.py b/minigpt4/common/vqa_tools/aokvqa/evaluation/remap_predictions.py new file mode 100644 index 0000000000000000000000000000000000000000..40ba155d5fc8bbc3b8d0a1cfdd00c43114626258 --- /dev/null +++ b/minigpt4/common/vqa_tools/aokvqa/evaluation/remap_predictions.py @@ -0,0 +1,44 @@ +import argparse +import pathlib +import json +from tqdm import tqdm + +from sentence_transformers import SentenceTransformer +from sentence_transformers.util import cos_sim + +from load_aokvqa import load_aokvqa + + +def map_to_choices(dataset, predictions, device='cpu'): + if isinstance(dataset, list): + dataset = { dataset[i]['question_id'] : dataset[i] for i in range(len(dataset)) } + + if all([p in dataset[q]['choices'] for q, p in predictions.items()]): + return predictions + + model = SentenceTransformer('sentence-transformers/average_word_embeddings_glove.6B.300d') + model.to(device) + for q in tqdm(predictions.keys()): + choices = dataset[q]['choices'] + if predictions[q] not in choices: + choice_embeddings = model.encode([predictions[q]] + 
choices, convert_to_tensor=True) + a_idx = cos_sim(choice_embeddings[0], choice_embeddings[1:]).argmax().item() + predictions[q] = choices[a_idx] + + return predictions + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') + parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True) + parser.add_argument('--pred', type=argparse.FileType('r'), required=True, dest='prediction_file') + parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file') + args = parser.parse_args() + + + dataset = load_aokvqa(args.aokvqa_dir, args.split) + predictions = json.load(args.prediction_file) + predictions = map_to_choices(dataset, predictions) + + json.dump(predictions, args.output_file) diff --git a/minigpt4/common/vqa_tools/aokvqa/gpt3/README.md b/minigpt4/common/vqa_tools/aokvqa/gpt3/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fc1fd6bb66f6f660a6bb0ae9b7904425c216f41a --- /dev/null +++ b/minigpt4/common/vqa_tools/aokvqa/gpt3/README.md @@ -0,0 +1,14 @@ +## Querying GPT-3 + +To follow our experiments which use GPT-3, you must have access to the [OpenAI API](https://openai.com/api/) (at cost). Please retrieve your [organization](https://beta.openai.com/account/org-settings) and [API](https://beta.openai.com/account/api-keys) keys and set them in your environment variables. + +```bash +export OPENAI_ORG=.... +export OPENAI_API_KEY=... +``` + +For producing predictions for both DA and MC settings, run: +```bash +python gpt3/query_gpt3.py --aokvqa-dir ${AOKVQA_DIR} --split val --out ${PREDS_DIR}/gpt3_val-da.json +python remap_predictions.py --aokvqa-dir ${AOKVQA_DIR} --split val --pred ${PREDS_DIR}/gpt3_val-da.json --out ${PREDS_DIR}/gpt3_val-mc.json +``` diff --git a/minigpt4/common/vqa_tools/aokvqa/gpt3/caption_inputs.py b/minigpt4/common/vqa_tools/aokvqa/gpt3/caption_inputs.py new file mode 100644 index 0000000000000000000000000000000000000000..21174341f137aa10f9f9667c89f52613458ec3bb --- /dev/null +++ b/minigpt4/common/vqa_tools/aokvqa/gpt3/caption_inputs.py @@ -0,0 +1,23 @@ +import os +import json +import argparse +import pathlib + +from load_aokvqa import load_aokvqa + + +parser = argparse.ArgumentParser() +parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') +parser.add_argument('--coco-dir', type=pathlib.Path, required=True, dest='coco_dir') +parser.add_argument('--split', type=str, choices=['train', 'val'], required=True) +parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file') +args = parser.parse_args() + +aokvqa_set = load_aokvqa(args.aokvqa_dir, args.split) + +coco_captions = json.load(open(os.path.join(args.coco_dir, 'annotations', f'captions_{args.split}2017.json')))['annotations'] +coco_captions = {c['image_id'] : c['caption'] for c in coco_captions} + +captions = { d['question_id'] : coco_captions[d['image_id']] for d in aokvqa_set } + +json.dump(captions, args.output_file) diff --git a/minigpt4/common/vqa_tools/aokvqa/gpt3/query_gpt3.py b/minigpt4/common/vqa_tools/aokvqa/gpt3/query_gpt3.py new file mode 100644 index 0000000000000000000000000000000000000000..4a0890097500c9521af6bee85d7c0a3abd7c67c2 --- /dev/null +++ b/minigpt4/common/vqa_tools/aokvqa/gpt3/query_gpt3.py @@ -0,0 +1,79 @@ +import os +import random +import json +from tqdm import tqdm +import argparse +import pathlib + +import openai +openai.organization = 
os.getenv('OPENAI_ORG') +openai.api_key = os.getenv('OPENAI_API_KEY') + +from load_aokvqa import load_aokvqa + + +random.seed(0) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') + parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True) + parser.add_argument('--n', type=int, default=10, dest='num_examples') + parser.add_argument('--train-context', type=argparse.FileType('r'), dest='train_context_file') + parser.add_argument('--prefix', type=str, default='', dest='prompt_prefix') + parser.add_argument('--include-choices', action='store_true', dest='include_choices') + parser.add_argument('--context', type=argparse.FileType('r'), dest='context_file') + parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file') + args = parser.parse_args() + + + train_set = load_aokvqa(args.aokvqa_dir, 'train') + eval_set = load_aokvqa(args.aokvqa_dir, args.split) + + train_context = {} + context = {} + if args.context_file is not None: + train_context = json.load(args.train_context_file) + context = json.load(args.context_file) + + predictions = {} + + for d in tqdm(eval_set): + q = d['question_id'] + + prompt = args.prompt_prefix + for e in random.sample(train_set, args.num_examples): + prompt += prompt_element(e, + context=train_context.get(q, None), + include_choices=args.include_choices, + answer=True + ) + prompt += '\n\n' + + prompt += prompt_element(d, + context=context.get(q, None), + include_choices=args.include_choices, + answer=False + ) + + response = openai.Completion.create( + engine="text-curie-001", + prompt=prompt, + temperature=0.0, + max_tokens=10, + ) + + predictions[q] = response.choices[0].text.strip() + + json.dump(predictions, args.output_file) + + +def prompt_element(d, context=None, include_choices=False, answer=False): + return (f"Context: {context}\n" if context is not None else '') + \ + f"Q: {d['question']}\n" + \ + (f"Choices: {', '.join(d['choices'])}.\n" if include_choices else '') + \ + f"A:" + (f" {d['choices'][d['correct_choice_idx']]}" if answer else '') + +if __name__ == '__main__': + main() diff --git a/minigpt4/common/vqa_tools/aokvqa/gpt3/rationale_inputs.py b/minigpt4/common/vqa_tools/aokvqa/gpt3/rationale_inputs.py new file mode 100644 index 0000000000000000000000000000000000000000..411d1eeb72b5a67419239e0a9b4a31dff7257ada --- /dev/null +++ b/minigpt4/common/vqa_tools/aokvqa/gpt3/rationale_inputs.py @@ -0,0 +1,16 @@ +import json +import argparse +import pathlib + +from load_aokvqa import load_aokvqa + + +parser = argparse.ArgumentParser() +parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') +parser.add_argument('--split', type=str, choices=['train', 'val', 'test_w_ans'], required=True) +parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file') +args = parser.parse_args() + +aokvqa_set = load_aokvqa(args.aokvqa_dir, args.split) +rationales = {d['question_id'] : d['rationales'][0] for d in aokvqa_set} +json.dump(rationales, args.output_file) diff --git a/minigpt4/common/vqa_tools/aokvqa/heuristics/README.md b/minigpt4/common/vqa_tools/aokvqa/heuristics/README.md new file mode 100644 index 0000000000000000000000000000000000000000..67c8632ec3bc8a92c631e29072b44f67083a40f0 --- /dev/null +++ b/minigpt4/common/vqa_tools/aokvqa/heuristics/README.md @@ -0,0 +1,11 @@ +## Heuristics + +```bash +# These scripts accept the same arguments. 
+# heuristics/random_unweighted.py +# heuristics/random_weighted.py +# heuristics/most_common_answer.py + +python heuristics/random_unweighted.py --aokvqa-dir ${AOKVQA_DIR} --split val --mc --out ${PREDS_DIR}/random-unweighted_val-mc.json +# Exclude --mc for the direct answer setting +``` diff --git a/minigpt4/common/vqa_tools/aokvqa/heuristics/most_common_answer.py b/minigpt4/common/vqa_tools/aokvqa/heuristics/most_common_answer.py new file mode 100644 index 0000000000000000000000000000000000000000..59a27bc410e306f502a8b6b0d0e15255cbbfd45f --- /dev/null +++ b/minigpt4/common/vqa_tools/aokvqa/heuristics/most_common_answer.py @@ -0,0 +1,39 @@ +import os +import json +import argparse +import pathlib +from collections import Counter + +from load_aokvqa import load_aokvqa + + +parser = argparse.ArgumentParser() +parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') +parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True) +parser.add_argument('--mc', action='store_true', dest='multiple_choice') +parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file') +args = parser.parse_args() + + +train_set = load_aokvqa(args.aokvqa_dir, 'train') +train_freq = dict(Counter( + [d['choices'][d['correct_choice_idx']] for d in train_set] +)) +most_common_answer = max(train_freq.keys(), key=train_freq.get) + +## + +eval_set = load_aokvqa(args.aokvqa_dir, args.split) + +predictions = {} + +for d in eval_set: + q = d['question_id'] + predictions[q] = most_common_answer + + if args.multiple_choice: + choices = [c for c in d['choices'] if c in train_freq.keys()] + if len(choices) > 0: + predictions[q] = max(choices, key=train_freq.get) + +json.dump(predictions, args.output_file) diff --git a/minigpt4/common/vqa_tools/aokvqa/heuristics/random_unweighted.py b/minigpt4/common/vqa_tools/aokvqa/heuristics/random_unweighted.py new file mode 100644 index 0000000000000000000000000000000000000000..cfcf900f9ef785db6b23409ecdc71e8859730f75 --- /dev/null +++ b/minigpt4/common/vqa_tools/aokvqa/heuristics/random_unweighted.py @@ -0,0 +1,38 @@ +import os +import json +from random import seed, sample +import argparse +import pathlib + +from load_aokvqa import load_aokvqa + + +parser = argparse.ArgumentParser() +parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') +parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True) +parser.add_argument('--mc', action='store_true', dest='multiple_choice') +parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file') +args = parser.parse_args() + +seed(0) + +train_set = load_aokvqa(args.aokvqa_dir, 'train') + +if args.multiple_choice is False: + choices = list(set( + [d['choices'][d['correct_choice_idx']] for d in train_set] + )) + +## + +predictions = {} + +eval_set = load_aokvqa(args.aokvqa_dir, args.split) + +for d in eval_set: + q = d['question_id'] + if args.multiple_choice: + choices = d['choices'] + predictions[q] = sample(choices, 1)[0] + +json.dump(predictions, args.output_file) diff --git a/minigpt4/common/vqa_tools/aokvqa/heuristics/random_weighted.py b/minigpt4/common/vqa_tools/aokvqa/heuristics/random_weighted.py new file mode 100644 index 0000000000000000000000000000000000000000..2ccfa614a3dcffd75427381e6eccaba3be2987d6 --- /dev/null +++ b/minigpt4/common/vqa_tools/aokvqa/heuristics/random_weighted.py @@ -0,0 +1,46 @@ +import os +import json +import numpy as np +import argparse 
+import pathlib +from collections import Counter + +from load_aokvqa import load_aokvqa + + +parser = argparse.ArgumentParser() +parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') +parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True) +parser.add_argument('--mc', action='store_true', dest='multiple_choice') +parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file') +args = parser.parse_args() + +np.random.seed(0) + +train_set = load_aokvqa(args.aokvqa_dir, 'train') +train_freq = dict(Counter( + [d['choices'][d['correct_choice_idx']] for d in train_set] +)) + +if args.multiple_choice is False: + choices = list(train_freq.keys()) + probs = [f / len(train_set) for f in train_freq.values()] + +## + +predictions = {} + +eval_set = load_aokvqa(args.aokvqa_dir, args.split) + +for d in eval_set: + if args.multiple_choice: + choices = d['choices'] + probs = [train_freq.get(c, 0) for c in choices] + if probs == [0, 0, 0, 0]: + probs = [1, 1, 1, 1] + probs = [p / sum(probs) for p in probs] + + q = d['question_id'] + predictions[q] = np.random.choice(choices, size=1, p=probs)[0] + +json.dump(predictions, args.output_file) diff --git a/minigpt4/common/vqa_tools/aokvqa/load_aokvqa.py b/minigpt4/common/vqa_tools/aokvqa/load_aokvqa.py new file mode 100644 index 0000000000000000000000000000000000000000..3e3dd49c668e56a7e306e1f15d7f73ad32fa31ac --- /dev/null +++ b/minigpt4/common/vqa_tools/aokvqa/load_aokvqa.py @@ -0,0 +1,13 @@ +import os +import json + + +def load_aokvqa(aokvqa_dir, split, version='v1p0'): + assert split in ['train', 'val', 'test', 'test_w_ans'] + dataset = json.load(open( + os.path.join(aokvqa_dir, f"aokvqa_{version}_{split}.json") + )) + return dataset + +def get_coco_path(split, image_id, coco_dir): + return os.path.join(coco_dir, f"{split}2017", f"{image_id:012}.jpg") diff --git a/minigpt4/common/vqa_tools/aokvqa/transfer_experiments/README.md b/minigpt4/common/vqa_tools/aokvqa/transfer_experiments/README.md new file mode 100644 index 0000000000000000000000000000000000000000..dc5138d297ced13a2d631968105431bdb624d14c --- /dev/null +++ b/minigpt4/common/vqa_tools/aokvqa/transfer_experiments/README.md @@ -0,0 +1,41 @@ +## Transfer Learning Experiments + +We use the following training/prediction scripts for the classifier, zero-shot, and contrastive experiments in Table 3. + +```bash +## Training +python transfer_experiments/train.py --aokvqa-dir ${AOKVQA_DIR} --vocab ${AOKVQA_DIR}/large_vocab_train.csv --log-dir ${LOG_DIR} + +--backbone clip --clip-model-type ViT-B/32 --train-features ${FEATURES_DIR}/clip-ViT-B-32_train.pt --val-features ${FEATURES_DIR}/clip-ViT-B-32_val.pt +--inputs question # OR --inputs image # OR --inputs question image +# OR +--backbone resnet --train-features ${FEATURES_DIR}/resnet_train.pt --val-features ${FEATURES_DIR}/resnet_val.pt --inputs image +# OR +--backbone bert --train-features ${FEATURES_DIR}/bert_train.pt --val-features ${FEATURES_DIR}/bert_val.pt --inputs question + +--objective classifier +# OR +--objective contrastive --vocab-features ${FEATURE_DIR}/clip-ViT-B-32_large_vocab.pt +``` + +You can make predictions for CLIP zero-shot or from a classifier/contrastive checkpoint trained above. 
+ +```bash +## Predicting +python transfer_experiments/predict.py --aokvqa-dir ${AOKVQA_DIR} --out ${PREDS_DIR}/clip-classifier_val-mc.json + +--split val # or test +--features ${FEATURE_DIR}/clip-ViT-B-32_val.pt # adjust for backbone and eval split + +--ckpt path/to/model.ckpt +# OR +--zero-shot --clip-model-type ViT-B/32 +--inputs question # OR --inputs image # OR --inputs question image + +--mc # Multiple-choice. Exclude for direct-answer. + +# IF classifier OR direct-answer +--vocab ${AOKVQA_DIR}/large_vocab_train.csv +# IF contrastive/zero-shot AND direct-answer +--vocab-features ${FEATURES_DIR}/clip-ViT-B-32_large_vocab.pt +``` diff --git a/minigpt4/common/vqa_tools/aokvqa/transfer_experiments/predict.py b/minigpt4/common/vqa_tools/aokvqa/transfer_experiments/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..d2fbb4272bcc3bcf5f0d4cc1a5860976fb3fd3ac --- /dev/null +++ b/minigpt4/common/vqa_tools/aokvqa/transfer_experiments/predict.py @@ -0,0 +1,126 @@ +import sys +import os +import argparse +import pathlib +from tqdm import tqdm +import json + +import torch +import torch.nn as nn + +# https://github.com/PyTorchLightning/pytorch-lightning/issues/11663 +import sentencepiece; import pytorch_lightning as pl; import clip + +from transfer_experiments.train import LinearClassifier +from load_aokvqa import load_aokvqa +from evaluation.remap_predictions import map_to_choices + + +parser = argparse.ArgumentParser() +parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True) +parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') +parser.add_argument('--features', type=pathlib.Path, required=True) +parser.add_argument('--out', type=argparse.FileType('w'), dest='output_file') +# +parser_weights = parser.add_mutually_exclusive_group(required=True) + +parser_weights.add_argument('--ckpt', type=pathlib.Path, dest='checkpoint_path') + +parser_weights.add_argument('--zero-shot', action='store_true', dest='clip_zero_shot') +parser.add_argument('--inputs', nargs='+', type=str, choices=['question', 'image'], required=('--zero-shot' in sys.argv)) +# +parser.add_argument('--vocab', type=argparse.FileType('r')) +parser.add_argument('--vocab-features', type=pathlib.Path, dest='vocab_features') +parser.add_argument('--mc', action='store_true', dest='multiple_choice') + +parser.add_argument('--clip-model-type', type=str, + choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'], + dest='clip_model_type', required=('--zero-shot' in sys.argv and '--mc' in sys.argv)) +# +args = parser.parse_args() + + +## Load dataset + +aokvqa_set = load_aokvqa(args.aokvqa_dir, args.split) + +## Load models + +device = "cuda" if torch.cuda.is_available() else "cpu" + +if args.checkpoint_path is not None: + classifier = LinearClassifier.load_from_checkpoint(args.checkpoint_path) + classifier.to(device) + hp = classifier.hparams +elif args.clip_zero_shot: + classifier = nn.Identity().to(device) + hp = pl.utilities.AttributeDict(backbone='clip', clip_model_type=args.clip_model_type, objective='zero-shot', inputs=args.inputs) + +# Load input features + +embeddings = torch.load(args.features) +if hp.backbone == 'clip': + for q in embeddings.keys(): + embeddings[q]['question'] = embeddings[q]['question'] / embeddings[q]['question'].norm(dim=-1, keepdim=True) + embeddings[q]['image'] = embeddings[q]['image'] / embeddings[q]['image'].norm(dim=-1, keepdim=True) + +# Load vocab, vocab 
features, clip + +if (hp.objective == 'classifier') or \ + (hp.objective in ['contrastive', 'zero-shot'] and args.multiple_choice is False): + vocab = args.vocab.read().splitlines() + +if hp.objective in ['contrastive', 'zero-shot']: + if args.multiple_choice is False: + vocab_features = torch.load(args.vocab_features).cpu() + vocab_features /= vocab_features.norm(dim=-1, keepdim=True) + else: + clip_model = clip.load(hp.clip_model_type, device=device)[0] + logit_scale = clip_model.logit_scale.exp().cpu() + +## Prediction loop + +predictions = {} + +with torch.no_grad(): + for o in tqdm(aokvqa_set): + q = o['question_id'] + + # Load input embedding (from question / image) + if hp.objective == 'zero-shot' and ('question' in hp.inputs and 'image' in hp.inputs): + e = embeddings[q]['question'] + embeddings[q]['image'] + elif 'question' in hp.inputs and 'image' in hp.inputs: + e = torch.cat((embeddings[q]['question'], embeddings[q]['image'])) + elif 'question' in hp.inputs: + e = embeddings[q]['question'] + elif 'image' in hp.inputs: + e = embeddings[q]['image'] + + # Pass inputs through model + e = e.unsqueeze(0).to(device) + x = classifier(e)[0].cpu() + + # Predict + if hp.objective in ['contrastive', 'zero-shot']: + if args.multiple_choice: + vocab = o['choices'] + # Encode choices + vocab_features = clip.tokenize(vocab).to(device) + vocab_features = torch.stack([ + clip_model.encode_text(v.unsqueeze(0)) for v in vocab_features + ], dim=1)[0] + vocab_features /= vocab_features.norm(dim=-1, keepdim=True) + vocab_features = vocab_features.float().cpu() + + x = logit_scale * x @ vocab_features.t() + x = x.softmax(dim=-1) + + predictions[q] = vocab[x.argmax().item()] + +## Save and evaluate predictions + +# Map prediction to nearest neighbor choice (by word embeddings) +if args.multiple_choice and hp.objective == 'classifier': + predictions = map_to_choices(aokvqa_set, predictions) + +json.dump(predictions, args.output_file) diff --git a/minigpt4/common/vqa_tools/aokvqa/transfer_experiments/train.py b/minigpt4/common/vqa_tools/aokvqa/transfer_experiments/train.py new file mode 100644 index 0000000000000000000000000000000000000000..ac48b5ad7fbc72a063e187a9097441769abe954f --- /dev/null +++ b/minigpt4/common/vqa_tools/aokvqa/transfer_experiments/train.py @@ -0,0 +1,263 @@ +import os +import sys +import json +import argparse +import pathlib +import random + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader + +# https://github.com/PyTorchLightning/pytorch-lightning/issues/11663 +import sentencepiece; import pytorch_lightning as pl + +import torchmetrics.functional as MF + +from load_aokvqa import load_aokvqa + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir') + parser.add_argument('--vocab', type=argparse.FileType('r'), required=True) + parser.add_argument('--log-dir', type=pathlib.Path, dest='log_dir', required=True) + # + parser.add_argument('--backbone', type=str, choices=['clip', 'resnet', 'bert'], required=True) + parser.add_argument('--clip-model-type', type=str, + choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'], + dest='clip_model_type', required=('clip' in sys.argv)) + parser.add_argument('--train-features', type=pathlib.Path, required=True, dest='train_features') + parser.add_argument('--val-features', type=pathlib.Path, required=True, dest='val_features') + 
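+    # Precomputed CLIP embeddings of the answer vocabulary; only needed for the
+    # contrastive objective (required=('contrastive' in sys.argv) enforces this).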
parser.add_argument('--vocab-features', type=pathlib.Path, required=('contrastive' in sys.argv), dest='vocab_features') + # + parser.add_argument('--objective', type=str, choices=['classifier', 'contrastive'], required=True) + parser.add_argument('--inputs', nargs='+', type=str, choices=['question', 'image'], required=True) + # Defaults + parser.add_argument('--bs', type=int, default=128, dest='batch_size') + parser.add_argument('--lr', type=float, default=0.01) + parser.add_argument('--epochs', type=int, default=500) + parser.add_argument('--gpus', type=int, default=1) + args = parser.parse_args() + + pl.seed_everything(1) + vocab = args.vocab.read().splitlines() + + ## Data loading + + dm = AokvqaEmbeddingsDataModule( + args.aokvqa_dir, + args.train_features, + args.val_features, + args.objective, + args.backbone, + args.inputs, + vocab, + args.vocab_features, + batch_size=args.batch_size, + num_workers=16 + ) + + ## Model definition + + model = LinearClassifier( + args.objective, + args.backbone, + args.clip_model_type, + args.inputs, + len(vocab), + args.lr + ) + + ## Training and testing loops + + logger = pl.loggers.TensorBoardLogger( + args.log_dir, + name=f'{args.backbone}-{args.objective}', + version=f"inputs:{'+'.join(args.inputs)}" + ) + + trainer = pl.Trainer( + logger=logger, + gpus=args.gpus, + max_epochs=args.epochs, + callbacks=[ + pl.callbacks.ModelCheckpoint( + monitor="val_acc", + filename="{epoch:02d}-{val_acc:.2f}", + mode="max" + ) + ], + ) + + trainer.fit(model, dm) + + +class AokvqaEmbeddingsDataset(Dataset): + def __init__(self, aokvqa_dir, split, input_features, objective, backbone, inputs, vocab, vocab_features): + + aokvqa_set = load_aokvqa(aokvqa_dir, split) + + assert ( backbone == 'resnet' and inputs == ['image'] and objective == 'classifier' ) \ + or ( backbone == 'bert' and inputs == ['question'] and objective == 'classifier' ) \ + or ( backbone == 'clip' ) + + embeddings = torch.load(input_features) + if backbone == 'clip': + for q in embeddings.keys(): + embeddings[q]['question'] /= embeddings[q]['question'].norm(dim=-1, keepdim=True) + embeddings[q]['image'] /= embeddings[q]['image'].norm(dim=-1, keepdim=True) + if objective == 'contrastive': + vocab_embeddings = torch.load(vocab_features) + vocab_embeddings /= vocab_embeddings.norm(dim=-1, keepdim=True) + + self.objective = objective + self.vocab_len = len(vocab) + + self.embeddings = [] + self.answers = [] + + for o in aokvqa_set: + correct_answers = set([o['choices'][o['correct_choice_idx']]] + o['direct_answers']) + correct_answers = [vocab.index(a) for a in correct_answers if a in vocab] + if self.objective == 'contrastive': + correct_answers = [vocab_embeddings[a] for a in correct_answers] + if len(correct_answers) == 0: continue + self.answers.append(correct_answers) + + q = o['question_id'] + if 'question' in inputs and 'image' in inputs: + e = torch.cat((embeddings[q]['question'], embeddings[q]['image'])) + elif 'question' in inputs and 'image' not in inputs: + e = embeddings[q]['question'] + elif 'question' not in inputs and 'image' in inputs: + e = embeddings[q]['image'] + self.embeddings.append(e) + + def __getitem__(self, index): + e = self.embeddings[index] + a = self.answers[index] + if self.objective == 'classifier': + a = torch.sum(F.one_hot(torch.tensor(a), num_classes=self.vocab_len), dim=0) + elif self.objective == 'contrastive': + a = random.sample(a, 1)[0] + return e, a + + def __len__(self): + return len(self.embeddings) + + +class 
AokvqaEmbeddingsDataModule(pl.LightningDataModule): + + def __init__(self, aokvqa_dir, train_features, val_features, objective, backbone, inputs, vocab, vocab_features, batch_size=1, num_workers=0): + super().__init__() + self.aokvqa_dir = aokvqa_dir + self.train_features = train_features + self.val_features = val_features + self.objective = objective + self.backbone = backbone + self.inputs = inputs + self.vocab = vocab + self.vocab_features = vocab_features + self.batch_size = batch_size + self.num_workers = num_workers + + def setup(self, stage=None): + self.train_dataset = AokvqaEmbeddingsDataset( + self.aokvqa_dir, 'train', self.train_features, self.objective, + self.backbone, self.inputs, self.vocab, self.vocab_features + ) + self.val_dataset = AokvqaEmbeddingsDataset( + self.aokvqa_dir, 'val', self.val_features, self.objective, + self.backbone, self.inputs, self.vocab, self.vocab_features + ) + + def train_dataloader(self): + return DataLoader( + self.train_dataset, batch_size=self.batch_size, shuffle=True, + num_workers=int(0.8 * self.num_workers) + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, batch_size=self.batch_size, shuffle=False, + num_workers=int(0.2 * self.num_workers) + ) + + +class LinearClassifier(pl.LightningModule): + def __init__(self, objective, backbone, clip_model_type, inputs, vocab_len, lr=0.001): + super().__init__() + self.save_hyperparameters(ignore=['lr']) + self.lr = lr + + if self.hparams.backbone == 'clip': + clip_dim = { + 'RN50' : 1024, + 'RN50x4' : 640, + 'RN50x16' : 768, + 'RN50x64' : 1024, + 'RN101' : 512, + 'ViT-B/32' : 512, + 'ViT-B/16' : 512, + 'ViT-L/14' : 768, + 'ViT-L/14@336px' : 768, + }[clip_model_type] + emb_dim = clip_dim * len(inputs) + elif self.hparams.backbone == 'resnet': + emb_dim = 2048 + elif self.hparams.backbone == 'bert': + emb_dim = 768 + + if self.hparams.objective == 'classifier': + out_dim = vocab_len + elif self.hparams.objective == 'contrastive': + out_dim = clip_dim + + self.linear = nn.Linear(emb_dim, out_dim) + + def forward(self, x): + x = self.linear(x) + if self.hparams.objective == 'classifier': + x = torch.sigmoid(x) + return x + + def compute_loss(self, batch): + x, y = batch + + y_pred = self.forward(x) + + if self.hparams.objective == 'classifier': + loss = F.binary_cross_entropy(y_pred, y.float()) + elif self.hparams.objective == 'contrastive': + indices = torch.arange(0, x.shape[0], dtype=torch.int64, device=self.device) + sim = (y_pred @ y.T).softmax(dim=-1) + loss = F.cross_entropy(sim, indices) + + if self.hparams.objective == 'classifier': + acc = MF.f1_score(y_pred, y) + elif self.hparams.objective == 'contrastive': + acc = torch.mean(sim[indices, indices]) + + return loss, acc + + def training_step(self, batch, batch_idx): + loss, acc = self.compute_loss(batch) + self.log("train_loss", loss) + self.log("train_acc", acc) + return loss + + def validation_step(self, batch, batch_idx): + loss, acc = self.compute_loss(batch) + self.log("val_loss", loss) + self.log("val_acc", acc) + return loss + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) + return optimizer + + +if __name__ == '__main__': + main() diff --git a/minigpt4/common/vqa_tools/vqa.py b/minigpt4/common/vqa_tools/vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..a386b9094b0528b33e7511aff4027f30459a7ff7 --- /dev/null +++ b/minigpt4/common/vqa_tools/vqa.py @@ -0,0 +1,211 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. 
+ SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +__author__ = "aagrawal" +__version__ = "0.9" + +# Interface for accessing the VQA dataset. + +# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: +# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py). + +# The following functions are defined: +# VQA - VQA class that loads VQA annotation file and prepares data structures. +# getQuesIds - Get question ids that satisfy given filter conditions. +# getImgIds - Get image ids that satisfy given filter conditions. +# loadQA - Load questions and answers with the specified question ids. +# showQA - Display the specified questions and answers. +# loadRes - Load result file and create result object. + +# Help on each function can be accessed by: "help(COCO.function)" + +import json +import datetime +import copy + + +class VQA: + def __init__(self, annotation_file=None, question_file=None): + """ + Constructor of VQA helper class for reading and visualizing questions and answers. + :param annotation_file (str): location of VQA annotation file + :return: + """ + # load dataset + self.dataset = {} + self.questions = {} + self.qa = {} + self.qqa = {} + self.imgToQA = {} + if not annotation_file == None and not question_file == None: + print("loading VQA annotations and questions into memory...") + time_t = datetime.datetime.utcnow() + dataset = json.load(open(annotation_file, "r")) + questions = json.load(open(question_file, "r")) + self.dataset = dataset + self.questions = questions + self.createIndex() + + def createIndex(self): + # create index + print("creating index...") + imgToQA = {ann["image_id"]: [] for ann in self.dataset["annotations"]} + qa = {ann["question_id"]: [] for ann in self.dataset["annotations"]} + qqa = {ann["question_id"]: [] for ann in self.dataset["annotations"]} + for ann in self.dataset["annotations"]: + imgToQA[ann["image_id"]] += [ann] + qa[ann["question_id"]] = ann + for ques in self.questions["questions"]: + qqa[ques["question_id"]] = ques + print("index created!") + + # create class members + self.qa = qa + self.qqa = qqa + self.imgToQA = imgToQA + + def info(self): + """ + Print information about the VQA annotation file. + :return: + """ + for key, value in self.datset["info"].items(): + print("%s: %s" % (key, value)) + + def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]): + """ + Get question ids that satisfy given filter conditions. 
default skips that filter + :param imgIds (int array) : get question ids for given imgs + quesTypes (str array) : get question ids for given question types + ansTypes (str array) : get question ids for given answer types + :return: ids (int array) : integer array of question ids + """ + imgIds = imgIds if type(imgIds) == list else [imgIds] + quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] + ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] + + if len(imgIds) == len(quesTypes) == len(ansTypes) == 0: + anns = self.dataset["annotations"] + else: + if not len(imgIds) == 0: + anns = sum( + [self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA], + [], + ) + else: + anns = self.dataset["annotations"] + anns = ( + anns + if len(quesTypes) == 0 + else [ann for ann in anns if ann["question_type"] in quesTypes] + ) + anns = ( + anns + if len(ansTypes) == 0 + else [ann for ann in anns if ann["answer_type"] in ansTypes] + ) + ids = [ann["question_id"] for ann in anns] + return ids + + def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]): + """ + Get image ids that satisfy given filter conditions. default skips that filter + :param quesIds (int array) : get image ids for given question ids + quesTypes (str array) : get image ids for given question types + ansTypes (str array) : get image ids for given answer types + :return: ids (int array) : integer array of image ids + """ + quesIds = quesIds if type(quesIds) == list else [quesIds] + quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] + ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] + + if len(quesIds) == len(quesTypes) == len(ansTypes) == 0: + anns = self.dataset["annotations"] + else: + if not len(quesIds) == 0: + anns = sum( + [self.qa[quesId] for quesId in quesIds if quesId in self.qa], [] + ) + else: + anns = self.dataset["annotations"] + anns = ( + anns + if len(quesTypes) == 0 + else [ann for ann in anns if ann["question_type"] in quesTypes] + ) + anns = ( + anns + if len(ansTypes) == 0 + else [ann for ann in anns if ann["answer_type"] in ansTypes] + ) + ids = [ann["image_id"] for ann in anns] + return ids + + def loadQA(self, ids=[]): + """ + Load questions and answers with the specified question ids. + :param ids (int array) : integer ids specifying question ids + :return: qa (object array) : loaded qa objects + """ + if type(ids) == list: + return [self.qa[id] for id in ids] + elif type(ids) == int: + return [self.qa[ids]] + + def showQA(self, anns): + """ + Display the specified annotations. + :param anns (array of object): annotations to display + :return: None + """ + if len(anns) == 0: + return 0 + for ann in anns: + quesId = ann["question_id"] + print("Question: %s" % (self.qqa[quesId]["question"])) + for ans in ann["answers"]: + print("Answer %d: %s" % (ans["answer_id"], ans["answer"])) + + def loadRes(self, resFile, quesFile): + """ + Load result file and return a result object. + :param resFile (str) : file name of result file + :return: res (obj) : result api object + """ + res = VQA() + res.questions = json.load(open(quesFile)) + res.dataset["info"] = copy.deepcopy(self.questions["info"]) + res.dataset["task_type"] = copy.deepcopy(self.questions["task_type"]) + res.dataset["data_type"] = copy.deepcopy(self.questions["data_type"]) + res.dataset["data_subtype"] = copy.deepcopy(self.questions["data_subtype"]) + res.dataset["license"] = copy.deepcopy(self.questions["license"]) + + print("Loading and preparing results... 
") + time_t = datetime.datetime.utcnow() + anns = json.load(open(resFile)) + assert type(anns) == list, "results is not an array of objects" + annsQuesIds = [ann["question_id"] for ann in anns] + assert set(annsQuesIds) == set( + self.getQuesIds() + ), "Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file." + for ann in anns: + quesId = ann["question_id"] + if res.dataset["task_type"] == "Multiple Choice": + assert ( + ann["answer"] in self.qqa[quesId]["multiple_choices"] + ), "predicted answer is not one of the multiple choices" + qaAnn = self.qa[quesId] + ann["image_id"] = qaAnn["image_id"] + ann["question_type"] = qaAnn["question_type"] + ann["answer_type"] = qaAnn["answer_type"] + print( + "DONE (t=%0.2fs)" % ((datetime.datetime.utcnow() - time_t).total_seconds()) + ) + + res.dataset["annotations"] = anns + res.createIndex() + return res diff --git a/minigpt4/common/vqa_tools/vqa_eval.py b/minigpt4/common/vqa_tools/vqa_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..ee808b349bb6166c744338b02af2bc84a68650ff --- /dev/null +++ b/minigpt4/common/vqa_tools/vqa_eval.py @@ -0,0 +1,324 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +# coding=utf-8 + +__author__ = "aagrawal" + +# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: +# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py). +import sys +import re + + +class VQAEval: + def __init__(self, vqa=None, vqaRes=None, n=2): + self.n = n + self.accuracy = {} + self.evalQA = {} + self.evalQuesType = {} + self.evalAnsType = {} + self.vqa = vqa + self.vqaRes = vqaRes + if vqa is not None: + self.params = {"question_id": vqa.getQuesIds()} + self.contractions = { + "aint": "ain't", + "arent": "aren't", + "cant": "can't", + "couldve": "could've", + "couldnt": "couldn't", + "couldn'tve": "couldn't've", + "couldnt've": "couldn't've", + "didnt": "didn't", + "doesnt": "doesn't", + "dont": "don't", + "hadnt": "hadn't", + "hadnt've": "hadn't've", + "hadn'tve": "hadn't've", + "hasnt": "hasn't", + "havent": "haven't", + "hed": "he'd", + "hed've": "he'd've", + "he'dve": "he'd've", + "hes": "he's", + "howd": "how'd", + "howll": "how'll", + "hows": "how's", + "Id've": "I'd've", + "I'dve": "I'd've", + "Im": "I'm", + "Ive": "I've", + "isnt": "isn't", + "itd": "it'd", + "itd've": "it'd've", + "it'dve": "it'd've", + "itll": "it'll", + "let's": "let's", + "maam": "ma'am", + "mightnt": "mightn't", + "mightnt've": "mightn't've", + "mightn'tve": "mightn't've", + "mightve": "might've", + "mustnt": "mustn't", + "mustve": "must've", + "neednt": "needn't", + "notve": "not've", + "oclock": "o'clock", + "oughtnt": "oughtn't", + "ow's'at": "'ow's'at", + "'ows'at": "'ow's'at", + "'ow'sat": "'ow's'at", + "shant": "shan't", + "shed've": "she'd've", + "she'dve": "she'd've", + "she's": "she's", + "shouldve": "should've", + "shouldnt": "shouldn't", + "shouldnt've": "shouldn't've", + "shouldn'tve": "shouldn't've", + "somebody'd": "somebodyd", + "somebodyd've": "somebody'd've", + "somebody'dve": "somebody'd've", + "somebodyll": "somebody'll", + "somebodys": "somebody's", + "someoned": "someone'd", + "someoned've": 
"someone'd've", + "someone'dve": "someone'd've", + "someonell": "someone'll", + "someones": "someone's", + "somethingd": "something'd", + "somethingd've": "something'd've", + "something'dve": "something'd've", + "somethingll": "something'll", + "thats": "that's", + "thered": "there'd", + "thered've": "there'd've", + "there'dve": "there'd've", + "therere": "there're", + "theres": "there's", + "theyd": "they'd", + "theyd've": "they'd've", + "they'dve": "they'd've", + "theyll": "they'll", + "theyre": "they're", + "theyve": "they've", + "twas": "'twas", + "wasnt": "wasn't", + "wed've": "we'd've", + "we'dve": "we'd've", + "weve": "we've", + "werent": "weren't", + "whatll": "what'll", + "whatre": "what're", + "whats": "what's", + "whatve": "what've", + "whens": "when's", + "whered": "where'd", + "wheres": "where's", + "whereve": "where've", + "whod": "who'd", + "whod've": "who'd've", + "who'dve": "who'd've", + "wholl": "who'll", + "whos": "who's", + "whove": "who've", + "whyll": "why'll", + "whyre": "why're", + "whys": "why's", + "wont": "won't", + "wouldve": "would've", + "wouldnt": "wouldn't", + "wouldnt've": "wouldn't've", + "wouldn'tve": "wouldn't've", + "yall": "y'all", + "yall'll": "y'all'll", + "y'allll": "y'all'll", + "yall'd've": "y'all'd've", + "y'alld've": "y'all'd've", + "y'all'dve": "y'all'd've", + "youd": "you'd", + "youd've": "you'd've", + "you'dve": "you'd've", + "youll": "you'll", + "youre": "you're", + "youve": "you've", + } + self.manualMap = { + "none": "0", + "zero": "0", + "one": "1", + "two": "2", + "three": "3", + "four": "4", + "five": "5", + "six": "6", + "seven": "7", + "eight": "8", + "nine": "9", + "ten": "10", + } + self.articles = ["a", "an", "the"] + + self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)") + self.commaStrip = re.compile("(\d)(,)(\d)") + self.punct = [ + ";", + r"/", + "[", + "]", + '"', + "{", + "}", + "(", + ")", + "=", + "+", + "\\", + "_", + "-", + ">", + "<", + "@", + "`", + ",", + "?", + "!", + ] + + def evaluate(self, quesIds=None): + if quesIds == None: + quesIds = [quesId for quesId in self.params["question_id"]] + gts = {} + res = {} + for quesId in quesIds: + gts[quesId] = self.vqa.qa[quesId] + res[quesId] = self.vqaRes.qa[quesId] + + # ================================================= + # Compute accuracy + # ================================================= + accQA = [] + accQuesType = {} + accAnsType = {} + print("computing accuracy") + step = 0 + for quesId in quesIds: + resAns = res[quesId]["answer"] + resAns = resAns.replace("\n", " ") + resAns = resAns.replace("\t", " ") + resAns = resAns.strip() + resAns = self.processPunctuation(resAns) + resAns = self.processDigitArticle(resAns) + gtAcc = [] + gtAnswers = [ans["answer"] for ans in gts[quesId]["answers"]] + if len(set(gtAnswers)) > 1: + for ansDic in gts[quesId]["answers"]: + ansDic["answer"] = self.processPunctuation(ansDic["answer"]) + for gtAnsDatum in gts[quesId]["answers"]: + otherGTAns = [ + item for item in gts[quesId]["answers"] if item != gtAnsDatum + ] + matchingAns = [item for item in otherGTAns if item["answer"] == resAns] + acc = min(1, float(len(matchingAns)) / 3) + gtAcc.append(acc) + quesType = gts[quesId]["question_type"] + ansType = gts[quesId]["answer_type"] + avgGTAcc = float(sum(gtAcc)) / len(gtAcc) + accQA.append(avgGTAcc) + if quesType not in accQuesType: + accQuesType[quesType] = [] + accQuesType[quesType].append(avgGTAcc) + if ansType not in accAnsType: + accAnsType[ansType] = [] + accAnsType[ansType].append(avgGTAcc) + self.setEvalQA(quesId, avgGTAcc) + 
self.setEvalQuesType(quesId, quesType, avgGTAcc) + self.setEvalAnsType(quesId, ansType, avgGTAcc) + if step % 100 == 0: + self.updateProgress(step / float(len(quesIds))) + step = step + 1 + + self.setAccuracy(accQA, accQuesType, accAnsType) + print("Done computing accuracy") + + def processPunctuation(self, inText): + outText = inText + for p in self.punct: + if (p + " " in inText or " " + p in inText) or ( + re.search(self.commaStrip, inText) != None + ): + outText = outText.replace(p, "") + else: + outText = outText.replace(p, " ") + outText = self.periodStrip.sub("", outText, re.UNICODE) + return outText + + def processDigitArticle(self, inText): + outText = [] + tempText = inText.lower().split() + for word in tempText: + word = self.manualMap.setdefault(word, word) + if word not in self.articles: + outText.append(word) + else: + pass + for wordId, word in enumerate(outText): + if word in self.contractions: + outText[wordId] = self.contractions[word] + outText = " ".join(outText) + return outText + + def setAccuracy(self, accQA, accQuesType, accAnsType): + self.accuracy["overall"] = round(100 * float(sum(accQA)) / len(accQA), self.n) + self.accuracy["perQuestionType"] = { + quesType: round( + 100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]), + self.n, + ) + for quesType in accQuesType + } + self.accuracy["perAnswerType"] = { + ansType: round( + 100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n + ) + for ansType in accAnsType + } + + def setEvalQA(self, quesId, acc): + self.evalQA[quesId] = round(100 * acc, self.n) + + def setEvalQuesType(self, quesId, quesType, acc): + if quesType not in self.evalQuesType: + self.evalQuesType[quesType] = {} + self.evalQuesType[quesType][quesId] = round(100 * acc, self.n) + + def setEvalAnsType(self, quesId, ansType, acc): + if ansType not in self.evalAnsType: + self.evalAnsType[ansType] = {} + self.evalAnsType[ansType][quesId] = round(100 * acc, self.n) + + def updateProgress(self, progress): + barLength = 20 + status = "" + if isinstance(progress, int): + progress = float(progress) + if not isinstance(progress, float): + progress = 0 + status = "error: progress var must be float\r\n" + if progress < 0: + progress = 0 + status = "Halt...\r\n" + if progress >= 1: + progress = 1 + status = "Done...\r\n" + block = int(round(barLength * progress)) + text = "\rFinshed Percent: [{0}] {1}% {2}".format( + "#" * block + "-" * (barLength - block), int(progress * 100), status + ) + sys.stdout.write(text) + sys.stdout.flush() diff --git a/minigpt4/configs/datasets/cc_sbu/align.yaml b/minigpt4/configs/datasets/cc_sbu/align.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ba1bf9e8f577735108fe3d1d0559867378b6d88c --- /dev/null +++ b/minigpt4/configs/datasets/cc_sbu/align.yaml @@ -0,0 +1,5 @@ +datasets: + cc_sbu_align: + data_type: images + build_info: + storage: "/ibex/project/c2133/minigpt4_1/MiniGPT-4/minigpt4/configs/datasets/cc_sbu_align" diff --git a/minigpt4/configs/datasets/cc_sbu/defaults.yaml b/minigpt4/configs/datasets/cc_sbu/defaults.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7188033863a5cfd8710209d9bd490025e40ec39d --- /dev/null +++ b/minigpt4/configs/datasets/cc_sbu/defaults.yaml @@ -0,0 +1,5 @@ +datasets: + cc_sbu: + data_type: images + build_info: + storage: /ibex/project/c2133/blip_dataset/cc3m_256/cc3m_cc12m_sbu/{00000..01255}.tar diff --git a/minigpt4/configs/datasets/cmd_video/default.yaml b/minigpt4/configs/datasets/cmd_video/default.yaml new file 
mode 100644 index 0000000000000000000000000000000000000000..306aa8b57d41546b364a69cbe3b1ab0f1c9a2716 --- /dev/null +++ b/minigpt4/configs/datasets/cmd_video/default.yaml @@ -0,0 +1,15 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + cmd_video: + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + vis_root: /ibex/ai/reference/videos/CondensedMovies/data/images + ann_paths: [datasets/training_datasets/video_text_data/cmd/train.json] + cc_path: datasets/training_datasets/video_text_data/cmd/caption.json diff --git a/minigpt4/configs/datasets/laion/defaults.yaml b/minigpt4/configs/datasets/laion/defaults.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c19b90a01e693680431cc5af3ed16cbc75baf54c --- /dev/null +++ b/minigpt4/configs/datasets/laion/defaults.yaml @@ -0,0 +1,5 @@ +datasets: + laion: + data_type: images + build_info: + storage: /ibex/project/c2133/blip_dataset/laion_1b/laion_gpu/{00000..10488}.tar diff --git a/minigpt4/configs/datasets/video_chatgpt/default.yaml b/minigpt4/configs/datasets/video_chatgpt/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a44d0a69c1c6bcab1a6f9bc133cb12aaef59c4e0 --- /dev/null +++ b/minigpt4/configs/datasets/video_chatgpt/default.yaml @@ -0,0 +1,20 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + video_chatgpt: + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + ann_paths: [datasets/training_datasets/video_text_data/video_instruct_100/VideoInstruct100K.json] + vis_root: /ibex/project/c2090/datasets/VideoInstruct100K/ + valid: + ann_path: "datasets/video_text_data/validation/all_datasets_samples_val_qa.json" + videos_path: "/ibex/project/c2090/datasets/VideoInstruct100K/test_videos/all_datasets_samples_val" + subtitles_path: "inference_subtitles" + annotations_keys: ['question','answer','video_id'] + add_subtitles: True \ No newline at end of file diff --git a/minigpt4/configs/datasets/webvid/default.yaml b/minigpt4/configs/datasets/webvid/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..98b0b775b3b34bc32383797baaa6619320f486e0 --- /dev/null +++ b/minigpt4/configs/datasets/webvid/default.yaml @@ -0,0 +1,15 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. 
+ # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + webvid: + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + ann_paths: [datasets/training_datasets/video_text_data/webvid/train.json] + vis_root: /ibex/ai/reference/videos/webvid/data/videos + subtitles_path: /ibex/project/c2090/datasets/Webvid/webvid_val_subtitles/ diff --git a/minigpt4/configs/default.yaml b/minigpt4/configs/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ff5a6a23fa2e3914938631b96c71fdf723dbbc10 --- /dev/null +++ b/minigpt4/configs/default.yaml @@ -0,0 +1,5 @@ +env: + # For default users + # cache_root: "cache" + # For internal use with persistent storage + cache_root: "/export/home/.cache/minigpt4" diff --git a/minigpt4/configs/models/minigpt4.yaml b/minigpt4/configs/models/minigpt4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..95899ce34fcb3bdad5c031ec431bdf0b25d7f4f4 --- /dev/null +++ b/minigpt4/configs/models/minigpt4.yaml @@ -0,0 +1,35 @@ +model: + arch: mini_gpt4_1 + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + freeze_vit: True + freeze_qformer: True + model_type: "vit_h" + device: "cuda" + + # Q-Former + num_query_token: 32 + + # Vicuna + llama_model: "lmsys/vicuna-13b-v1.1" + + # generation configs + prompt: "" + +preprocess: + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/minigpt4/configs/models/minigpt4v.yaml b/minigpt4/configs/models/minigpt4v.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ae4418b6e151615733b8e2c8b3fbe4abab1759e7 --- /dev/null +++ b/minigpt4/configs/models/minigpt4v.yaml @@ -0,0 +1,35 @@ +model: + arch: mini_gpt4v + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + freeze_vit: True + freeze_qformer: True + model_type: "vit_h" + device: "cuda" + + # Q-Former + num_query_token: 32 + + # Vicuna + llama_model: "lmsys/vicuna-13b-v1.1" + + # generation configs + prompt: "" + +preprocess: + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/minigpt4/conversation/__init__.py b/minigpt4/conversation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/minigpt4/conversation/conversation.py b/minigpt4/conversation/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..62d74aa53f63ef814fe60082761393f82202018e --- /dev/null +++ b/minigpt4/conversation/conversation.py @@ -0,0 +1,224 @@ +import argparse +import time +from PIL import Image + +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer +from transformers import StoppingCriteria, StoppingCriteriaList + +import dataclasses +from enum import auto, Enum +from typing import List, Tuple, Any + +from minigpt4.common.registry import registry + + +class SeparatorStyle(Enum): + """Different separator style.""" + SINGLE = auto() + TWO 
= auto() + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + system: str + roles: List[str] + messages: List[List[str]] + offset: int + # system_img: List[Image.Image] = [] + sep_style: SeparatorStyle = SeparatorStyle.SINGLE + sep: str = "" + sep2: str = "" + + skip_next: bool = False + conv_id: Any = None + + def get_prompt(self): + if self.sep_style == SeparatorStyle.SINGLE: + # ret = self.system + self.sep + ret = self.system +"" + for role, message in self.messages: + if message: + # ret += role + ": " + message + self.sep + ret+= role + message + # ret+= role + message + else: + # ret += role + ":" + # ret += self.sep2 + role + ret += role + return ret + elif self.sep_style == SeparatorStyle.TWO: + seps = [self.sep, self.sep2] + # ret = self.system + seps[0] + ret = self.system+"" + for i, (role, message) in enumerate(self.messages): + if message: + # ret += role + ": " + message + seps[i % 2] + ret += role+message+seps[i%2] + else: + # ret += role + ":" + ret += role + return ret + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + def append_message(self, role, message): + self.messages.append([role, message]) + + def to_gradio_chatbot(self): + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def copy(self): + return Conversation( + system=self.system, + # system_img=self.system_img, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + conv_id=self.conv_id) + + def dict(self): + return { + "system": self.system, + # "system_img": self.system_img, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + "conv_id": self.conv_id, + } + + +class StoppingCriteriaSub(StoppingCriteria): + + def __init__(self, stops=[], encounters=1): + super().__init__() + self.stops = stops + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): + for stop in self.stops: + if torch.all((stop == input_ids[0][-len(stop):])).item(): + return True + + return False + + +CONV_VISION = Conversation( + # system="Give the following image: ImageContent. " + # "You will be able to see the image once I provide it to you. Please answer my questions.", + system = "", + roles = (r"[INST] ",r" [/INST]"), + messages=[], + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="", +) + + +class Chat: + def __init__(self, model, vis_processor, device='cuda:0'): + self.device = device + self.model = model + self.vis_processor = vis_processor + + self.conv = CONV_VISION.copy() + self.img_list = [] + self.raw_answers = [] + + stop_words_ids = [torch.tensor([2]).to(self.device)] + self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) + + def reset(self): + self.conv.messages = [] + self.img_list = [] + # self.img_list = [img for img in self.conv.system_img] + self.raw_answers = [] + + def ask(self, text, conv): + if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \ + and conv.messages[-1][1][-6:] == '': # last message is image. 
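+            # The last turn is the user's image placeholder, so append the new
+            # question to that turn instead of opening a separate user message.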
+ conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text]) + else: + conv.append_message(conv.roles[0], text) + + def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9, + repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000): + conv.append_message(conv.roles[1], None) + embs = self.get_context_emb(conv, img_list) + + current_max_len = embs.shape[1] + max_new_tokens + if current_max_len - max_length > 0: + print('Warning: The number of tokens in current conversation exceeds the max length. ' + 'The model will not see the contexts outside the range.') + begin_idx = max(0, current_max_len - max_length) + + embs = embs[:, begin_idx:] + + outputs = self.model.llama_model.generate( + inputs_embeds=embs, + max_new_tokens=max_new_tokens, + stopping_criteria=self.stopping_criteria, + num_beams=num_beams, + min_length=min_length, + top_p=top_p, + repetition_penalty=repetition_penalty, + length_penalty=length_penalty, + temperature=temperature, + do_sample=False, + ) + output_token = outputs[0] + if output_token[0] == 0: + output_token = output_token[1:] + output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False) + self.raw_answers.append(output_text) + output_text = output_text.split('')[0] # remove the stop sign '###' + output_text = output_text.replace("", "") + output_text = output_text.split(r'[/INST]')[-1].strip() + self.conv.messages[-1][1] = output_text + return output_text, output_token.cpu().numpy() + + def upload_img(self, image): + if isinstance(image, str): # is a image path + raw_image = Image.open(image).convert('RGB') + image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) + elif isinstance(image, Image.Image): + raw_image = image + image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) + elif isinstance(image, torch.Tensor): + if len(image.shape) == 3: + image = image.unsqueeze(0) + image = image.to(self.device) + + image_emb, _ = self.model.encode_img(image) + self.img_list.append(image_emb) + self.conv.append_message(self.conv.roles[0], "") + msg = "Received." + # self.conv.append_message(self.conv.roles[1], msg) + return msg + + def get_context_emb(self, conv, img_list): + prompt = conv.get_prompt() + prompt_segs = prompt.split('') + assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images." + seg_tokens = [ + self.model.llama_tokenizer( + seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids + # only add bos to the first seg + for i, seg in enumerate(prompt_segs) + ] + + seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens] + mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] + mixed_embs = torch.cat(mixed_embs, dim=1) + return mixed_embs diff --git a/minigpt4/datasets/__init__.py b/minigpt4/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/minigpt4/datasets/builders/__init__.py b/minigpt4/datasets/builders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..33fc9bc963dd530869b4ffbc3650417ec015ca63 --- /dev/null +++ b/minigpt4/datasets/builders/__init__.py @@ -0,0 +1,124 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. 
+ SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config +from minigpt4.datasets.builders.image_text_pair_builder import ( + LaionBuilder, + RefVisualGenomeBuilder, + OpenImageBuilder, + LocNaCOCOBuilder, + LlavaDetailBuilder, + LlavaReasonBuilder, + NavR2RBuilder, + PaintPTCOCOBuilder, + PaintRLCOCOBuilder, + PaintRLSCOCOBuilder, + PaintPixelCOCO32Builder, + PaintPixelCOCO64Builder, + PaintLanRLOpaqueCOCOBuilder, + SegRefCOCO32Builder, + SegRefCOCOG32Builder, + SegRefCOCOP32Builder, + SegRefCOCO64Builder, + SegRefCOCOG64Builder, + SegRefCOCOP64Builder, + CMDVideoBuilder, + WebVidBuilder, + VideoChatGPTBuilder, +) +from minigpt4.datasets.builders.vqa_builder import ( + COCOVQABuilder, + OKVQABuilder, +# AOKVQABuilder, + COCOVQGBuilder, +# OKVQGBuilder, +# AOKVQGBuilder, + SingleSlideVQABuilder, + OCRVQABuilder +) +from minigpt4.common.registry import registry + +__all__ = [ + "LaionBuilder", + "RefVisualGenomeBuilder", + "OpenImageBuilder", + "SingleSlideVQABuilder", + "COCOVQABuilder", + "COCOVQGBuilder", + "SingleSlideVQABuilder", + "OCRVQABuilder", + "LocNaCOCOBuilder", + "LlavaDetailBuilder", + "NavR2RBuilder", + "PaintPTCOCOBuilder", + "PaintRLCOCOBuilder", + "PaintRLSCOCOBuilder", + "PaintLanRLOpaqueCOCOBuilder", + "PaintPixelCOCO32Builder", + "PaintPixelCOCO64Builder", + "SegRefCOCO32Builder", + "SegRefCOCOG32Builder", + "SegRefCOCOP32Builder", + "SegRefCOCO64Builder", + "SegRefCOCOG64Builder", + "SegRefCOCOP64Builder", + "CMDVideoBuilder", + "WebVidBuilder", + "VideoChatGPTBuilder", +] + + +def load_dataset(name, cfg_path=None, vis_path=None, data_type=None): + """ + Example + + >>> dataset = load_dataset("coco_caption", cfg=None) + >>> splits = dataset.keys() + >>> print([len(dataset[split]) for split in splits]) + + """ + if cfg_path is None: + cfg = None + else: + cfg = load_dataset_config(cfg_path) + + try: + builder = registry.get_builder_class(name)(cfg) + except TypeError: + print( + f"Dataset {name} not found. Available datasets:\n" + + ", ".join([str(k) for k in dataset_zoo.get_names()]) + ) + exit(1) + + if vis_path is not None: + if data_type is None: + # use default data type in the config + data_type = builder.config.data_type + + assert ( + data_type in builder.config.build_info + ), f"Invalid data_type {data_type} for {name}." + + builder.config.build_info.get(data_type).storage = vis_path + + dataset = builder.build_datasets() + return dataset + + +class DatasetZoo: + def __init__(self) -> None: + self.dataset_zoo = { + k: list(v.DATASET_CONFIG_DICT.keys()) + for k, v in sorted(registry.mapping["builder_name_mapping"].items()) + } + + def get_names(self): + return list(self.dataset_zoo.keys()) + + +dataset_zoo = DatasetZoo() diff --git a/minigpt4/datasets/builders/base_dataset_builder.py b/minigpt4/datasets/builders/base_dataset_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..4b607e3c0a8abaa6b1ccbc711e27ff3755f5ec11 --- /dev/null +++ b/minigpt4/datasets/builders/base_dataset_builder.py @@ -0,0 +1,236 @@ +""" + This file is from + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. 
+ SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import os +import shutil +import warnings + +from omegaconf import OmegaConf +import torch.distributed as dist +from torchvision.datasets.utils import download_url + +import minigpt4.common.utils as utils +from minigpt4.common.dist_utils import is_dist_avail_and_initialized, is_main_process +from minigpt4.common.registry import registry +from minigpt4.processors.base_processor import BaseProcessor + + + +class BaseDatasetBuilder: + train_dataset_cls, eval_dataset_cls = None, None + + def __init__(self, cfg=None): + super().__init__() + + if cfg is None: + # help to create datasets from default config. + self.config = load_dataset_config(self.default_config_path()) + elif isinstance(cfg, str): + self.config = load_dataset_config(cfg) + else: + # when called from task.build_dataset() + self.config = cfg + + self.data_type = self.config.data_type + + self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} + self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} + + def build_datasets(self): + # download, split, etc... + # only called on 1 GPU/TPU in distributed + + if is_main_process(): + self._download_data() + + if is_dist_avail_and_initialized(): + dist.barrier() + + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + datasets = self.build() # dataset['train'/'val'/'test'] + + return datasets + + def build_processors(self): + vis_proc_cfg = self.config.get("vis_processor") + txt_proc_cfg = self.config.get("text_processor") + + if vis_proc_cfg is not None: + vis_train_cfg = vis_proc_cfg.get("train") + vis_eval_cfg = vis_proc_cfg.get("eval") + + self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg) + self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg) + + if txt_proc_cfg is not None: + txt_train_cfg = txt_proc_cfg.get("train") + txt_eval_cfg = txt_proc_cfg.get("eval") + + self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg) + self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg) + + @staticmethod + def _build_proc_from_cfg(cfg): + return ( + registry.get_processor_class(cfg.name).from_config(cfg) + if cfg is not None + else None + ) + + @classmethod + def default_config_path(cls, type="default"): + return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type]) + + def _download_data(self): + self._download_ann() + self._download_vis() + + def _download_ann(self): + """ + Download annotation files if necessary. + All the vision-language datasets should have annotations of unified format. + + storage_path can be: + (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative. + (2) basename/dirname: will be suffixed with base name of URL if dirname is provided. + + Local annotation paths should be relative. 
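+        Local source files are copied into place with shutil.copyfile (existing
+        destinations are reused); URLs are downloaded with torchvision's download_url.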
+ """ + anns = self.config.build_info.annotations + + splits = anns.keys() + + cache_root = registry.get_path("cache_root") + + for split in splits: + info = anns[split] + + urls, storage_paths = info.get("url", None), info.storage + + if isinstance(urls, str): + urls = [urls] + if isinstance(storage_paths, str): + storage_paths = [storage_paths] + + assert len(urls) == len(storage_paths) + + for url_or_filename, storage_path in zip(urls, storage_paths): + # if storage_path is relative, make it full by prefixing with cache_root. + if not os.path.isabs(storage_path): + storage_path = os.path.join(cache_root, storage_path) + + dirname = os.path.dirname(storage_path) + if not os.path.exists(dirname): + os.makedirs(dirname) + + if os.path.isfile(url_or_filename): + src, dst = url_or_filename, storage_path + if not os.path.exists(dst): + shutil.copyfile(src=src, dst=dst) + else: + logging.info("Using existing file {}.".format(dst)) + else: + if os.path.isdir(storage_path): + # if only dirname is provided, suffix with basename of URL. + raise ValueError( + "Expecting storage_path to be a file path, got directory {}".format( + storage_path + ) + ) + else: + filename = os.path.basename(storage_path) + + download_url(url=url_or_filename, root=dirname, filename=filename) + + def _download_vis(self): + + storage_path = self.config.build_info.get(self.data_type).storage + storage_path = utils.get_cache_path(storage_path) + + if not os.path.exists(storage_path): + warnings.warn( + f""" + The specified path {storage_path} for visual inputs does not exist. + Please provide a correct path to the visual inputs or + refer to datasets/download_scripts/README.md for downloading instructions. + """ + ) + + def build(self): + """ + Create by split datasets inheriting torch.utils.data.Datasets. + + # build() can be dataset-specific. Overwrite to customize. 
+ """ + self.build_processors() + + build_info = self.config.build_info + + ann_info = build_info.annotations + vis_info = build_info.get(self.data_type) + + datasets = dict() + for split in ann_info.keys(): + if split not in ["train", "val", "test"]: + continue + + is_train = split == "train" + + # processors + vis_processor = ( + self.vis_processors["train"] + if is_train + else self.vis_processors["eval"] + ) + text_processor = ( + self.text_processors["train"] + if is_train + else self.text_processors["eval"] + ) + + # annotation path + ann_paths = ann_info.get(split).storage + if isinstance(ann_paths, str): + ann_paths = [ann_paths] + + abs_ann_paths = [] + for ann_path in ann_paths: + if not os.path.isabs(ann_path): + ann_path = utils.get_cache_path(ann_path) + abs_ann_paths.append(ann_path) + ann_paths = abs_ann_paths + + # visual data storage path + vis_path = os.path.join(vis_info.storage, split) + + if not os.path.isabs(vis_path): + # vis_path = os.path.join(utils.get_cache_path(), vis_path) + vis_path = utils.get_cache_path(vis_path) + + if not os.path.exists(vis_path): + warnings.warn("storage path {} does not exist.".format(vis_path)) + + # create datasets + dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls + datasets[split] = dataset_cls( + vis_processor=vis_processor, + text_processor=text_processor, + ann_paths=ann_paths, + vis_root=vis_path, + ) + + return datasets + + +def load_dataset_config(cfg_path): + cfg = OmegaConf.load(cfg_path).datasets + cfg = cfg[list(cfg.keys())[0]] + + return cfg diff --git a/minigpt4/datasets/builders/image_text_pair_builder.py b/minigpt4/datasets/builders/image_text_pair_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..812a8b3e2a8a1afff6492f27c19f54c6489d34f6 --- /dev/null +++ b/minigpt4/datasets/builders/image_text_pair_builder.py @@ -0,0 +1,1080 @@ +import os +import logging +import warnings + +from minigpt4.common.registry import registry +from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder +from minigpt4.datasets.datasets.laion_dataset import LaionDataset +from minigpt4.datasets.datasets.vg_dataset import ReferVisualGenomeDataset +from minigpt4.datasets.datasets.open_images import OpenImageDataset,OpenBboxToObjectDataset +from minigpt4.datasets.datasets.locna_dataset import LocNaCOCODataset +from minigpt4.datasets.datasets.llava_dataset import LlavaDetailDataset, LlavaReasonDataset, LlavaConversationDataset +from minigpt4.datasets.datasets.lvis_dataset import LVISBBOXDataset,LVISBboxToObjectDataset +from minigpt4.datasets.datasets.text_caps import TextCapBboxToObjectDataset, TextCapDataset +from minigpt4.datasets.datasets.coco_caption import COCOCapDataset,COCOCapEvalDataset +from minigpt4.datasets.datasets.coyo_dataset import COYOCaptionWDSDataset,COYOBoxToPhraseWDSDataset,COYOPhraseToBoxWDSDataset +# , COYOBBoxPhraseDataset +from minigpt4.datasets.datasets.grounded_detailed_image_caption_dataset import GroundedDetailDataset +from minigpt4.datasets.datasets.reasoning_dataset import ReasoningDataset +from minigpt4.datasets.datasets.video_datasets import CMDVideoDataset, WebVidDataset,VideoChatGPTDataset,Video_validation_Dataset +from minigpt4.datasets.datasets.cot import CoTDataset +from minigpt4.datasets.datasets.unnatural_instruction import UnnaturalDataset +from minigpt4.datasets.datasets.caption_reasoning import CaptionReasonDataset +from minigpt4.datasets.datasets.aok_vqa_reasoning_datasets import AOKVQAReasoningDataset +from 
minigpt4.datasets.datasets.paint_dataset import PaintPTCOCODataset, PaintRLCOCODataset, PaintPixelCOCODataset, SegReferCOCODataset, PaintLanRLOpaqueCOCODataset +from minigpt4.datasets.datasets.nav_dataset import NavR2RDataset + +@registry.register_builder("yifan_reasoning") +class LlavaDetailBuilder(BaseDatasetBuilder): + train_dataset_cls = AOKVQAReasoningDataset + DATASET_CONFIG_DICT = { + "default": "configs/datasets/aokvqa_reasoning/defaults.yaml", + } + + def build_datasets(self): + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + self.build_processors() + build_info = self.config.build_info + datasets = dict() + + # create datasets + dataset_cls = self.train_dataset_cls + datasets['train'] = dataset_cls( + vis_processor=self.vis_processors["train"], + text_processor=self.text_processors["train"], + ann_paths=build_info.ann_path, + vis_root=build_info.image_path, + ) + + return datasets + + +@registry.register_builder("caption_reasoning") +class CaptionReasoningBuilder(BaseDatasetBuilder): + train_dataset_cls = CaptionReasonDataset + DATASET_CONFIG_DICT = { + "default": "configs/datasets/mm_reasoning/mm_reasoning.yaml", + } + + def build_datasets(self): + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + self.build_processors() + build_info = self.config.build_info + datasets = dict() + + # create datasets + dataset_cls = self.train_dataset_cls + + # print("ann_path",build_info.ann_path) + # print("vis root",build_info.image_path ) + + datasets['train'] = dataset_cls( + vis_processor=self.vis_processors['train'], + text_processor=self.text_processors['train'], + ann_path=build_info.ann_path, + vis_root=build_info.image_path, + ) + + + return datasets + + +@registry.register_builder("unnatural_instruction") +class UnnaturalInstructionBuilder(BaseDatasetBuilder): + train_dataset_cls = UnnaturalDataset + DATASET_CONFIG_DICT = { + "default": "configs/datasets/nlp/unnatural_instruction.yaml", + } + + def build_datasets(self): + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + self.build_processors() + build_info = self.config.build_info + datasets = dict() + + # create datasets + dataset_cls = self.train_dataset_cls + datasets['train'] = dataset_cls( + text_processor=self.text_processors["train"], + ann_path=build_info.ann_path, + ) + + return datasets + +@registry.register_builder("cot") +class CoTBuilder(BaseDatasetBuilder): + train_dataset_cls = CoTDataset + DATASET_CONFIG_DICT = { + "default": "configs/datasets/nlp/cot.yaml", + } + + def build_datasets(self): + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. 
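+        # Text-only (chain-of-thought) data: no visual processor is needed, so only
+        # the text processor and the annotation path are passed to the dataset class.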
+ logging.info("Building datasets...") + self.build_processors() + build_info = self.config.build_info + datasets = dict() + + # create datasets + dataset_cls = self.train_dataset_cls + datasets['train'] = dataset_cls( + text_processor=self.text_processors["train"], + ann_path=build_info.ann_path, + ) + + return datasets + + + + +@registry.register_builder("coco_caption") +class COCOCapBuilder(BaseDatasetBuilder): + train_dataset_cls = COCOCapDataset + eval_dataset_cls = COCOCapEvalDataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/coco/caption.yaml", + "eval": "configs/datasets/coco/caption.yaml", + } + + +@registry.register_builder("open_images") +class OpenImageBuilder(BaseDatasetBuilder): + train_dataset_cls = OpenImageDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/open_images/default.yaml"} + + def _download_ann(self): + pass + + def _download_vis(self): + pass + + def build(self): + self.build_processors() + + build_info = self.config.build_info + + datasets = dict() + split = "train" + + # create datasets + # [NOTE] return inner_datasets (wds.DataPipeline) + dataset_cls = self.train_dataset_cls + datasets[split] = dataset_cls( + vis_processor=self.vis_processors[split], + text_processor=self.text_processors[split], + location=build_info.storage, + ).inner_dataset + + return datasets + + + +@registry.register_builder("open_images_bbox_to_object") +class OpenBboxToObjectuilder(BaseDatasetBuilder): + train_dataset_cls = OpenBboxToObjectDataset + DATASET_CONFIG_DICT = {"default": "configs/datasets/open_images/default_bbox.yaml"} + + def _download_ann(self): + pass + + def _download_vis(self): + pass + + def build(self): + self.build_processors() + + build_info = self.config.build_info + + datasets = dict() + split = "train" + + # create datasets + # [NOTE] return inner_datasets (wds.DataPipeline) + dataset_cls = self.train_dataset_cls + datasets[split] = dataset_cls( + vis_processor=self.vis_processors[split], + text_processor=self.text_processors[split], + location=build_info.storage, + ).inner_dataset + + return datasets + + +@registry.register_builder("lvis_images_bbox") +class LVISBBOxBuilder(BaseDatasetBuilder): + train_dataset_cls = LVISBBOXDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/lvis/default_bbox.yaml"} + + def _download_ann(self): + pass + + def _download_vis(self): + pass + + def build(self): + self.build_processors() + + build_info = self.config.build_info + + datasets = dict() + split = "train" + + # create datasets + # [NOTE] return inner_datasets (wds.DataPipeline) + dataset_cls = self.train_dataset_cls + datasets[split] = dataset_cls( + vis_processor=self.vis_processors[split], + text_processor=self.text_processors[split], + location=build_info.storage, + ).inner_dataset + + return datasets + + + +@registry.register_builder("lvis_bbox_to_object") +class LVISBBoxToObjectBuilder(BaseDatasetBuilder): + train_dataset_cls = LVISBboxToObjectDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/lvis/bbox_to_object.yaml"} + + def _download_ann(self): + pass + + def _download_vis(self): + pass + + def build(self): + self.build_processors() + + build_info = self.config.build_info + + datasets = dict() + split = "train" + + # create datasets + # [NOTE] return inner_datasets (wds.DataPipeline) + dataset_cls = self.train_dataset_cls + datasets[split] = dataset_cls( + vis_processor=self.vis_processors[split], + text_processor=self.text_processors[split], + location=build_info.storage, + ).inner_dataset + + return 
datasets + + + + +@registry.register_builder("spatial_reasoning") +class ReasoningBuilder(BaseDatasetBuilder): + train_dataset_cls = ReasoningDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/reasoning/default.yaml"} + + def _download_ann(self): + pass + + def _download_vis(self): + pass + + def build(self): + self.build_processors() + + build_info = self.config.build_info + + datasets = dict() + split = "train" + + # create datasets + # [NOTE] return inner_datasets (wds.DataPipeline) + dataset_cls = self.train_dataset_cls + datasets[split] = dataset_cls( + vis_processor=self.vis_processors[split], + text_processor=self.text_processors[split], + ann_path=build_info.ann_path, + vis_root=build_info.image_path, + ) + + return datasets + + + + + +@registry.register_builder("textcaps_caption") +class TextcapCaptionBuilder(BaseDatasetBuilder): + train_dataset_cls = TextCapDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/textcaps/caption.yaml"} + + def _download_ann(self): + pass + + def _download_vis(self): + pass + + def build(self): + self.build_processors() + + build_info = self.config.build_info + + datasets = dict() + split = "train" + + # create datasets + # [NOTE] return inner_datasets (wds.DataPipeline) + dataset_cls = self.train_dataset_cls + datasets[split] = dataset_cls( + vis_processor=self.vis_processors[split], + text_processor=self.text_processors[split], + ann_path=build_info.ann_path, + vis_root=build_info.image_path, + ) + + return datasets + + + + + +@registry.register_builder("coyo_caption") +class CoyoCaptionBuilder(BaseDatasetBuilder): + train_dataset_cls = COYOCaptionWDSDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/coyo/default.yaml"} + + def _download_ann(self): + pass + + def _download_vis(self): + pass + + def build(self): + self.build_processors() + + build_info = self.config.build_info + + datasets = dict() + split = "train" + + # create datasets + # [NOTE] return inner_datasets (wds.DataPipeline) + dataset_cls = self.train_dataset_cls + datasets[split] = dataset_cls( + vis_processor=self.vis_processors[split], + text_processor=self.text_processors[split], + location=build_info.storage, + ).inner_dataset + + return datasets + + + +@registry.register_builder("coyo_bbox_phrase") +class CoyoBboxPhraseBuilder(BaseDatasetBuilder): + train_dataset_cls = COYOBoxToPhraseWDSDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/coyo/bbox_phrase.yaml"} + + def _download_ann(self): + pass + + def _download_vis(self): + pass + + def build(self): + self.build_processors() + + build_info = self.config.build_info + + datasets = dict() + split = "train" + + # create datasets + # [NOTE] return inner_datasets (wds.DataPipeline) + dataset_cls = self.train_dataset_cls + datasets[split] = dataset_cls( + vis_processor=self.vis_processors[split], + text_processor=self.text_processors[split], + location=build_info.storage, + ).inner_dataset + + return datasets + + +@registry.register_builder("coyo_phrase_bbox") +class CoyoBboxPhraseBuilder(BaseDatasetBuilder): + train_dataset_cls = COYOPhraseToBoxWDSDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/coyo/phrase_bbox.yaml"} + + def _download_ann(self): + pass + + def _download_vis(self): + pass + + def build(self): + self.build_processors() + + build_info = self.config.build_info + + datasets = dict() + split = "train" + + # create datasets + # [NOTE] return inner_datasets (wds.DataPipeline) + dataset_cls = self.train_dataset_cls + datasets[split] = dataset_cls( + 
vis_processor=self.vis_processors[split], + text_processor=self.text_processors[split], + location=build_info.storage, + ).inner_dataset + + return datasets + + + + + + + +@registry.register_builder("textcaps_ocr") +class TextcapCaptionBuilder(BaseDatasetBuilder): + train_dataset_cls = TextCapBboxToObjectDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/textcaps/ocr.yaml"} + + def _download_ann(self): + pass + + def _download_vis(self): + pass + + def build(self): + self.build_processors() + + build_info = self.config.build_info + + datasets = dict() + split = "train" + + # create datasets + # [NOTE] return inner_datasets (wds.DataPipeline) + dataset_cls = self.train_dataset_cls + datasets[split] = dataset_cls( + vis_processor=self.vis_processors[split], + text_processor=self.text_processors[split], + ann_path=build_info.ann_path, + vis_root=build_info.image_path, + ) + + return datasets + + + + + +@registry.register_builder("laion") +class LaionBuilder(BaseDatasetBuilder): + train_dataset_cls = LaionDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"} + + def _download_ann(self): + pass + + def _download_vis(self): + pass + + def build(self): + self.build_processors() + + build_info = self.config.build_info + + datasets = dict() + split = "train" + + # create datasets + # [NOTE] return inner_datasets (wds.DataPipeline) + dataset_cls = self.train_dataset_cls + datasets[split] = dataset_cls( + vis_processor=self.vis_processors[split], + text_processor=self.text_processors[split], + location=build_info.storage, + ).inner_dataset + + return datasets + + +@registry.register_builder("locna_coco") +class LocNaCOCOBuilder(BaseDatasetBuilder): + train_dataset_cls = LocNaCOCODataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/coco/defaults_locna.yaml", + } + + def build_datasets(self): + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + self.build_processors() + + build_info = self.config.build_info + ann_paths = build_info.annotations.train.storage + + datasets = dict() + + for ann_path in ann_paths: + if not os.path.exists(ann_path): + warnings.warn("storage path {} does not exist.".format(ann_path)) + + # create datasets + dataset_cls = self.train_dataset_cls + datasets['train'] = dataset_cls( + vis_processor=self.vis_processors["train"], + text_processor=self.text_processors["train"], + ann_paths=ann_paths, + vis_root=build_info.images.storage, + ) + + return datasets + + +@registry.register_builder("llava_detail") +class LlavaDetailBuilder(BaseDatasetBuilder): + train_dataset_cls = LlavaDetailDataset + DATASET_CONFIG_DICT = { + "default": "configs/datasets/llava/detail.yaml", + } + + def build_datasets(self): + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. 
+ logging.info("Building datasets...") + self.build_processors() + build_info = self.config.build_info + datasets = dict() + + # create datasets + dataset_cls = self.train_dataset_cls + datasets['train'] = dataset_cls( + vis_processor=self.vis_processors["train"], + text_processor=self.text_processors["train"], + ann_path=build_info.ann_path, + vis_root=build_info.image_path, + ) + + return datasets + +@registry.register_builder("grounded_detailed_image_caption") +class GroundedCaptionBuilder(BaseDatasetBuilder): + train_dataset_cls = GroundedDetailDataset + DATASET_CONFIG_DICT = { + "default": "configs/datasets/grounded_image_caption/default.yaml", + } + + def build_datasets(self): + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + self.build_processors() + build_info = self.config.build_info + datasets = dict() + + # create datasets + dataset_cls = self.train_dataset_cls + datasets['train'] = dataset_cls( + vis_processor=self.vis_processors["train"], + text_processor=self.text_processors["train"], + ann_path=build_info.ann_path, + vis_root=build_info.image_path, + ) + + return datasets + + + + +@registry.register_builder("llava_reason") +class LlavaReasonBuilder(BaseDatasetBuilder): + train_dataset_cls = LlavaReasonDataset + DATASET_CONFIG_DICT = { + "default": "configs/datasets/llava/reason.yaml", + } + + def build_datasets(self): + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + self.build_processors() + build_info = self.config.build_info + datasets = dict() + + # create datasets + dataset_cls = self.train_dataset_cls + datasets['train'] = dataset_cls( + vis_processor=self.vis_processors["train"], + text_processor=self.text_processors["train"], + ann_path=build_info.ann_path, + vis_root=build_info.image_path, + ) + + return datasets + + + + + +@registry.register_builder("llava_conversation") +class LlavaReasonBuilder(BaseDatasetBuilder): + train_dataset_cls = LlavaConversationDataset + DATASET_CONFIG_DICT = { + "default": "configs/datasets/llava/conversation.yaml", + } + + def build_datasets(self): + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + self.build_processors() + build_info = self.config.build_info + datasets = dict() + + # create datasets + dataset_cls = self.train_dataset_cls + datasets['train'] = dataset_cls( + vis_processor=self.vis_processors["train"], + text_processor=self.text_processors["train"], + ann_path=build_info.ann_path, + vis_root=build_info.image_path, + ) + + return datasets + + +class AllRefCOCOBuilder(BaseDatasetBuilder): + + def build_datasets(self): + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. 
+ logging.info("Building datasets...") + self.build_processors() + + build_info = self.config.build_info + image_path = build_info.image_path + ann_path = build_info.ann_path + + datasets = dict() + + if not os.path.exists(image_path): + warnings.warn("image path {} does not exist.".format(image_path)) + if not os.path.exists(ann_path): + warnings.warn("ann path {} does not exist.".format(ann_path)) + + # create datasets + dataset_cls = self.train_dataset_cls + datasets['train'] = dataset_cls( + vis_processor=self.vis_processors["train"], + text_processor=self.text_processors["train"], + ann_path=ann_path, + vis_root=image_path, + dataset=build_info.dataset, + splitBy=build_info.splitBy + ) + + return datasets + + +@registry.register_builder("refvg") +class RefVisualGenomeBuilder(BaseDatasetBuilder): + train_dataset_cls = ReferVisualGenomeDataset + DATASET_CONFIG_DICT = { + "default": "configs/datasets/vg/ref.yaml", + } + + def build_datasets(self): + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + self.build_processors() + + build_info = self.config.build_info + data_dir = build_info.data_dir + datasets = dict() + + # create datasets + dataset_cls = self.train_dataset_cls + datasets['train'] = dataset_cls( + vis_processor=self.vis_processors["train"], + text_processor=self.text_processors["train"], + data_dir=data_dir, + ) + + return datasets + + +@registry.register_builder("cmd_video") +class CMDVideoBuilder(BaseDatasetBuilder): + train_dataset_cls = CMDVideoDataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/cmd_video/default.yaml", + } + + def build_datasets(self): + # download, split, etc... + # only called on 1 GPU/TPU in distributed + + self.build_processors() + + build_info = self.config.build_info + datasets = dict() + + # create datasets + dataset_cls = self.train_dataset_cls + datasets['train'] = dataset_cls( + vis_processor=self.vis_processors["train"], + text_processor=self.text_processors["train"], + vis_root=build_info.vis_root, + ann_paths=build_info.ann_paths, + cc_path=build_info.cc_path + ) + + return datasets + + +@registry.register_builder("webvid") +class WebVidBuilder(BaseDatasetBuilder): + train_dataset_cls = WebVidDataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/webvid/default.yaml", + } + + def build_datasets(self): + # download, split, etc... + # only called on 1 GPU/TPU in distributed + + self.build_processors() + + build_info = self.config.build_info + datasets = dict() + + # create datasets + dataset_cls = self.train_dataset_cls + datasets['train'] = dataset_cls( + vis_processor=self.vis_processors["train"], + text_processor=self.text_processors["train"], + vis_root=build_info.vis_root, + ann_paths=build_info.ann_paths, + subtitles_path=build_info.subtitles_path, + ) + + return datasets + + +@registry.register_builder("video_chatgpt") +class VideoChatGPTBuilder(BaseDatasetBuilder): + train_dataset_cls = VideoChatGPTDataset + eval_dataset_cls=Video_validation_Dataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/video_chatgpt/default.yaml", + } + print(DATASET_CONFIG_DICT) + + def build_datasets(self): + # download, split, etc... 
+ # only called on 1 GPU/TPU in distributed + self.build_processors() + + build_info = self.config.build_info + datasets = dict() + + # create datasets + dataset_cls = self.train_dataset_cls + datasets['train'] = dataset_cls( + vis_processor=self.vis_processors["train"], + text_processor=self.text_processors["train"], + vis_root=build_info.vis_root, + ann_paths=build_info.ann_paths, + ) + + return datasets + +@registry.register_builder("r2r") +class NavR2RBuilder(BaseDatasetBuilder): + train_dataset_cls = NavR2RDataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/nav/r2r.yaml", + } + + def build_datasets(self): + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + self.build_processors() + + build_info = self.config.build_info + datasets = dict() + + # create datasets + dataset_cls = self.train_dataset_cls + datasets['train'] = dataset_cls( + vis_processor=self.vis_processors["train"], + text_processor=self.text_processors["train"], + data_root=build_info.data_root + ) + + return datasets + + +@registry.register_builder("paintcoco") +class PaintPTCOCOBuilder(BaseDatasetBuilder): + train_dataset_cls = PaintPTCOCODataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/paint/coco.yaml", + } + + def build_datasets(self): + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + self.build_processors() + + build_info = self.config.build_info + img_root = build_info.img_root + stroke_root = build_info.stroke_root + max_step = build_info.max_step + + datasets = dict() + + # create datasets + dataset_cls = self.train_dataset_cls + datasets['train'] = dataset_cls( + vis_processor=self.vis_processors["train"], + text_processor=self.text_processors["train"], + img_root=img_root, + stroke_root=stroke_root, + max_step=max_step + ) + + return datasets + + +class PaintRLCOCOBuilderBase(BaseDatasetBuilder): + train_dataset_cls = PaintRLCOCODataset + + def build_datasets(self): + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + self.build_processors() + + build_info = self.config.build_info + img_root = build_info.img_root + stroke_root = build_info.stroke_root + max_step = build_info.max_step + single_stroke = build_info.single_stroke + + datasets = dict() + + # create datasets + dataset_cls = self.train_dataset_cls + datasets['train'] = dataset_cls( + vis_processor=self.vis_processors["train"], + text_processor=self.text_processors["train"], + img_root=img_root, + stroke_root=stroke_root, + max_step=max_step, + single_stroke=single_stroke + ) + + return datasets + + +@registry.register_builder("paintrlcoco") +class PaintRLCOCOBuilder(PaintRLCOCOBuilderBase): + DATASET_CONFIG_DICT = { + "default": "configs/datasets/paint/rl_coco.yaml", + } + + +@registry.register_builder("paintrlscoco") +class PaintRLSCOCOBuilder(PaintRLCOCOBuilderBase): + DATASET_CONFIG_DICT = { + "default": "configs/datasets/paint/rls_coco.yaml", + } + + +@registry.register_builder("paintlanrlsococo") +class PaintLanRLOpaqueCOCOBuilder(BaseDatasetBuilder): + train_dataset_cls = PaintLanRLOpaqueCOCODataset + DATASET_CONFIG_DICT = { + "default": "configs/datasets/paint/lan_rls_o_coco.yaml", + } + + def build_datasets(self): + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. 
+ logging.info("Building datasets...") + self.build_processors() + + build_info = self.config.build_info + img_root = build_info.img_root + stroke_root = build_info.stroke_root + max_step = build_info.max_step + single_stroke = build_info.single_stroke + ann_path = build_info.ann_path + + datasets = dict() + + # create datasets + dataset_cls = self.train_dataset_cls + datasets['train'] = dataset_cls( + vis_processor=self.vis_processors["train"], + text_processor=self.text_processors["train"], + img_root=img_root, + stroke_root=stroke_root, + ann_path=ann_path, + max_step=max_step, + single_stroke=single_stroke + ) + + return datasets + + +class PaintPixelCOCOBuilder(BaseDatasetBuilder): + train_dataset_cls = PaintPixelCOCODataset + + def build(self): + """ + Create by split datasets inheriting torch.utils.data.Datasets. + + # build() can be dataset-specific. Overwrite to customize. + """ + self.build_processors() + + build_info = self.config.build_info + + ann_info = build_info.annotations + vis_info = build_info.get(self.data_type) + res = build_info.res + + datasets = dict() + split = 'train' + + # annotation path + ann_paths = ann_info.get(split).storage + if isinstance(ann_paths, str): + ann_paths = [ann_paths] + + # visual data storage path + vis_path = os.path.join(vis_info.storage, split) + + # create datasets + dataset_cls = self.train_dataset_cls + datasets[split] = dataset_cls( + vis_processor=self.vis_processors["train"], + text_processor=self.text_processors["train"], + ann_paths=ann_paths, + vis_root=vis_path, + res=res + ) + + return datasets + + +@registry.register_builder("paintpixelcoco32") +class PaintPixelCOCO32Builder(PaintPixelCOCOBuilder): + train_dataset_cls = PaintPixelCOCODataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/paint/pixel_coco_32.yaml", + } + + +@registry.register_builder("paintpixelcoco64") +class PaintPixelCOCO64Builder(PaintPixelCOCOBuilder): + train_dataset_cls = PaintPixelCOCODataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/paint/pixel_coco_64.yaml", + } + + +class AllSegRefCOCOBuilder(BaseDatasetBuilder): + train_dataset_cls = SegReferCOCODataset + + def build_datasets(self): + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. 
+ logging.info("Building datasets...") + self.build_processors() + + build_info = self.config.build_info + image_path = build_info.image_path + ann_path = build_info.ann_path + res = build_info.res + + datasets = dict() + + if not os.path.exists(image_path): + warnings.warn("image path {} does not exist.".format(image_path)) + if not os.path.exists(ann_path): + warnings.warn("ann path {} does not exist.".format(ann_path)) + + # create datasets + dataset_cls = self.train_dataset_cls + datasets['train'] = dataset_cls( + vis_processor=self.vis_processors["train"], + text_processor=self.text_processors["train"], + ann_path=ann_path, + vis_root=image_path, + res=res, + dataset=build_info.dataset, + splitBy=build_info.splitBy + ) + + return datasets + + +@registry.register_builder("segrefcoco32") +class SegRefCOCO32Builder(AllSegRefCOCOBuilder): + DATASET_CONFIG_DICT = { + "default": "configs/datasets/paint/segrefcoco32.yaml", + } + + +@registry.register_builder("segrefcocop32") +class SegRefCOCOP32Builder(AllSegRefCOCOBuilder): + DATASET_CONFIG_DICT = { + "default": "configs/datasets/paint/segrefcocop32.yaml", + } + + +@registry.register_builder("segrefcocog32") +class SegRefCOCOG32Builder(AllSegRefCOCOBuilder): + DATASET_CONFIG_DICT = { + "default": "configs/datasets/paint/segrefcocog32.yaml", + } + + +@registry.register_builder("segrefcoco64") +class SegRefCOCO64Builder(AllSegRefCOCOBuilder): + DATASET_CONFIG_DICT = { + "default": "configs/datasets/paint/segrefcoco64.yaml", + } + + +@registry.register_builder("segrefcocop64") +class SegRefCOCOP64Builder(AllSegRefCOCOBuilder): + DATASET_CONFIG_DICT = { + "default": "configs/datasets/paint/segrefcocop64.yaml", + } + + +@registry.register_builder("segrefcocog64") +class SegRefCOCOG64Builder(AllSegRefCOCOBuilder): + DATASET_CONFIG_DICT = { + "default": "configs/datasets/paint/segrefcocog64.yaml", + } diff --git a/minigpt4/datasets/builders/vqa_builder.py b/minigpt4/datasets/builders/vqa_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..0d9309e5e866023a3223614ea03d18ed3a51dff2 --- /dev/null +++ b/minigpt4/datasets/builders/vqa_builder.py @@ -0,0 +1,131 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. 
+ SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder + +from minigpt4.common.registry import registry +from minigpt4.datasets.datasets.aok_vqa_datasets import AOKVQADataset +from minigpt4.datasets.datasets.aok_vqa_reasoning_datasets import AOKVQAReasoningDataset +#, AOKVQGDataset, AOKVQAEvalDataset +from minigpt4.datasets.datasets.coco_vqa_datasets import COCOVQADataset, COCOVQGDataset, COCOVQAEvalDataset +# from minigpt4.datasets.datasets.vg_vqa_datasets import VGVQADataset +from minigpt4.datasets.datasets.gqa_datasets import GQADataset, GQAEvalDataset +from minigpt4.datasets.datasets.doc_dataset import SingleSlideVQADataset, OCRVQADataset + + + +@registry.register_builder("coco_vqa") +class COCOVQABuilder(BaseDatasetBuilder): + train_dataset_cls = COCOVQADataset + eval_dataset_cls = COCOVQAEvalDataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/coco/defaults_vqa.yaml", + "eval": "configs/datasets/coco/eval_vqa.yaml", + } + + +# @registry.register_builder("vg_vqa") +# class VGVQABuilder(BaseDatasetBuilder): +# train_dataset_cls = VGVQADataset +# DATASET_CONFIG_DICT = {"default": "configs/datasets/vg/defaults_vqa.yaml"} + + +@registry.register_builder("ok_vqa") +class OKVQABuilder(COCOVQABuilder): + DATASET_CONFIG_DICT = { + "default": "configs/datasets/okvqa/defaults.yaml", + } + + +@registry.register_builder("aok_vqa") +class AOKVQABuilder(BaseDatasetBuilder): + train_dataset_cls = AOKVQADataset + # eval_dataset_cls = AOKVQAEvalDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/aokvqa/defaults.yaml"} + +@registry.register_builder("aok_vqa_reasoning") +class AOKVQABuilder(BaseDatasetBuilder): + train_dataset_cls = AOKVQAReasoningDataset + # eval_dataset_cls = AOKVQAEvalDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/aokvqa_reasoning/defaults.yaml"} + + +@registry.register_builder("gqa") +class GQABuilder(BaseDatasetBuilder): + train_dataset_cls = GQADataset + # eval_dataset_cls = GQAEvalDataset + + DATASET_CONFIG_DICT = { + # "default": "configs/datasets/gqa/defaults.yaml", + # "balanced_val": "configs/datasets/gqa/balanced_val.yaml", + "default": "configs/datasets/gqa/balanced_val.yaml", + # "balanced_testdev": "configs/datasets/gqa/balanced_testdev.yaml", + } + + + +@registry.register_builder("coco_vqg") +class COCOVQGBuilder(BaseDatasetBuilder): + train_dataset_cls = COCOVQGDataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/coco/defaults_vqg.yaml", + } + + +@registry.register_builder("ok_vqg") +class OKVQGBuilder(COCOVQGBuilder): + DATASET_CONFIG_DICT = { + "default": "configs/datasets/okvqa/defaults_vqg.yaml", + } + + +# @registry.register_builder("aok_vqg") +# class AOKVQGBuilder(BaseDatasetBuilder): +# train_dataset_cls = AOKVQGDataset + +# DATASET_CONFIG_DICT = {"default": "configs/datasets/aokvqa/defaults_vqg.yaml"} + + +class DocumentVQABuilder(BaseDatasetBuilder): + def _download_ann(self): + pass + + def _download_vis(self): + pass + + def build(self): + self.build_processors() + build_info = self.config.build_info + + datasets = dict() + split = "train" + + dataset_cls = self.train_dataset_cls + datasets[split] = dataset_cls( + vis_processor=self.vis_processors[split], + text_processor=self.text_processors[split], + vis_root=build_info.image_path, + ann_path=build_info.ann_path + ) + + return datasets + + 
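Each builder above follows the same pattern: @registry.register_builder binds a dataset name to a builder class, and DATASET_CONFIG_DICT points at the YAML file whose build_info supplies the annotation and image paths. A minimal usage sketch of how such a builder is typically resolved and built (illustrative only, not part of this patch; it assumes the registry exposes get_builder_class and that BaseDatasetBuilder accepts the loaded config in its constructor, as in LAVIS):

# Sketch: resolve a registered builder by name and build its split dict.
from minigpt4.common.registry import registry
from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config

def build_named_dataset(name, cfg_path):
    builder_cls = registry.get_builder_class(name)         # class bound by @registry.register_builder(name)
    builder = builder_cls(load_dataset_config(cfg_path))   # load the dataset section of the YAML config
    return builder.build_datasets()                        # dict keyed by split, e.g. {"train": <Dataset>}

# datasets = build_named_dataset("ocrvqa", "configs/datasets/doc/ocrvqa.yaml")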
+@registry.register_builder("sslidevqa") +class SingleSlideVQABuilder(DocumentVQABuilder): + train_dataset_cls = SingleSlideVQADataset + DATASET_CONFIG_DICT = {"default": "configs/datasets/doc/sslidevqa.yaml"} + + +@registry.register_builder("ocrvqa") +class OCRVQABuilder(DocumentVQABuilder): + train_dataset_cls = OCRVQADataset + DATASET_CONFIG_DICT = {"default": "configs/datasets/doc/ocrvqa.yaml"} \ No newline at end of file diff --git a/minigpt4/datasets/data_utils.py b/minigpt4/datasets/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..773b10facf26e89f71db6f7841a0377f93f1a2a9 --- /dev/null +++ b/minigpt4/datasets/data_utils.py @@ -0,0 +1,199 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import gzip +import logging +import os +import random as rnd +import tarfile +import zipfile +import random +from typing import List +from tqdm import tqdm + +import decord +from decord import VideoReader +import webdataset as wds +import numpy as np +import torch +from torch.utils.data.dataset import IterableDataset + +from minigpt4.common.registry import registry +from minigpt4.datasets.datasets.base_dataset import ConcatDataset + + +decord.bridge.set_bridge("torch") +MAX_INT = registry.get("MAX_INT") + + +class ChainDataset(wds.DataPipeline): + r"""Dataset for chaining multiple :class:`DataPipeline` s. + + This class is useful to assemble different existing dataset streams. The + chaining operation is done on-the-fly, so concatenating large-scale + datasets with this class will be efficient. + + Args: + datasets (iterable of IterableDataset): datasets to be chained together + """ + def __init__(self, datasets: List[wds.DataPipeline]) -> None: + super().__init__() + self.datasets = datasets + self.prob = [] + self.names = [] + for dataset in self.datasets: + if hasattr(dataset, 'name'): + self.names.append(dataset.name) + else: + self.names.append('Unknown') + if hasattr(dataset, 'sample_ratio'): + self.prob.append(dataset.sample_ratio) + else: + self.prob.append(1) + logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.") + + def __iter__(self): + datastreams = [iter(dataset) for dataset in self.datasets] + while True: + select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0] + yield next(select_datastream) + + +def apply_to_sample(f, sample): + if len(sample) == 0: + return {} + + def _apply(x): + if torch.is_tensor(x): + return f(x) + elif isinstance(x, dict): + return {key: _apply(value) for key, value in x.items()} + elif isinstance(x, list): + return [_apply(x) for x in x] + else: + return x + + return _apply(sample) + + +def move_to_cuda(sample): + def _move_to_cuda(tensor): + return tensor.cuda() + + return apply_to_sample(_move_to_cuda, sample) + + +def prepare_sample(samples, cuda_enabled=True): + if cuda_enabled: + samples = move_to_cuda(samples) + + # TODO fp16 support + + return samples + + +def reorg_datasets_by_split(datasets, batch_sizes): + """ + Organizes datasets by split. + + Args: + datasets: dict of torch.utils.data.Dataset objects by name. + + Returns: + Dict of datasets by split {split_name: List[Datasets]}. 
+ """ + # if len(datasets) == 1: + # return datasets[list(datasets.keys())[0]] + # else: + reorg_datasets = dict() + reorg_batch_sizes = dict() + + # reorganize by split + for dataset_name, dataset in datasets.items(): + for split_name, dataset_split in dataset.items(): + if split_name not in reorg_datasets: + reorg_datasets[split_name] = [dataset_split] + reorg_batch_sizes[split_name] = [batch_sizes[dataset_name]] + else: + reorg_datasets[split_name].append(dataset_split) + reorg_batch_sizes[split_name].append(batch_sizes[dataset_name]) + + return reorg_datasets, reorg_batch_sizes + + +def concat_datasets(datasets): + """ + Concatenates multiple datasets into a single dataset. + + It supports may-style datasets and DataPipeline from WebDataset. Currently, does not support + generic IterableDataset because it requires creating separate samplers. + + Now only supports conctenating training datasets and assuming validation and testing + have only a single dataset. This is because metrics should not be computed on the concatenated + datasets. + + Args: + datasets: dict of torch.utils.data.Dataset objects by split. + + Returns: + Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets, + "val" and "test" remain the same. + + If the input training datasets contain both map-style and DataPipeline datasets, returns + a tuple, where the first element is a concatenated map-style dataset and the second + element is a chained DataPipeline dataset. + + """ + # concatenate datasets in the same split + for split_name in datasets: + if split_name != "train": + assert ( + len(datasets[split_name]) == 1 + ), "Do not support multiple {} datasets.".format(split_name) + datasets[split_name] = datasets[split_name][0] + else: + iterable_datasets, map_datasets = [], [] + for dataset in datasets[split_name]: + if isinstance(dataset, wds.DataPipeline): + logging.info( + "Dataset {} is IterableDataset, can't be concatenated.".format( + dataset + ) + ) + iterable_datasets.append(dataset) + elif isinstance(dataset, IterableDataset): + raise NotImplementedError( + "Do not support concatenation of generic IterableDataset." + ) + else: + map_datasets.append(dataset) + + # if len(iterable_datasets) > 0: + # concatenate map-style datasets and iterable-style datasets separately + if len(iterable_datasets) > 1: + chained_datasets = ( + ChainDataset(iterable_datasets) + ) + elif len(iterable_datasets) == 1: + chained_datasets = iterable_datasets[0] + else: + chained_datasets = None + + concat_datasets = ( + ConcatDataset(map_datasets) if len(map_datasets) > 0 else None + ) + + train_datasets = concat_datasets, chained_datasets + train_datasets = tuple([x for x in train_datasets if x is not None]) + train_datasets = ( + train_datasets[0] if len(train_datasets) == 1 else train_datasets + ) + + datasets[split_name] = train_datasets + + return datasets + diff --git a/minigpt4/datasets/datasets/__init__.py b/minigpt4/datasets/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/minigpt4/datasets/datasets/aok_vqa_datasets.py b/minigpt4/datasets/datasets/aok_vqa_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..b65b42d267b1184578615ea19610b60a9b54a5ae --- /dev/null +++ b/minigpt4/datasets/datasets/aok_vqa_datasets.py @@ -0,0 +1,212 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. 
+ SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from collections import OrderedDict +import json +import os +import random +import torch + +from PIL import Image + +from minigpt4.datasets.datasets.vqa_datasets import VQADataset #, VQAEvalDataset + + +class __DisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + return OrderedDict( + { + "file": ann["image"], + "question": ann["question"], + "question_id": ann["question_id"], + "direct_answers": "; ".join(ann["direct_answers"]), + "choices": "; ".join(ann["choices"]), + "correct_choice": ann["choices"][ann["correct_choice_idx"]], + "image": sample["image"], + } + ) + + +class AOKVQADataset(VQADataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + self.instruction_pool =[ + "[vqa] {}", + "[vqa] Based on the image, respond to this question with a short answer: {}" + ] + + exist_annotation = [] + for ann in self.annotation: + image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1]) + if os.path.exists(image_path): + exist_annotation.append(ann) + self.annotation = exist_annotation + + def get_data(self, index): + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + question = self.text_processor(ann["question"]) + + answer_key = "direct_answers" + + # print("answer key", answer_key) + # for answer in ann[answer_key]: + # print(answer) + + answer_weight = {} + for answer in ann[answer_key]: + if answer in answer_weight.keys(): + answer_weight[answer] += 1 / len(ann[answer_key]) + else: + answer_weight[answer] = 1 / len(ann[answer_key]) + + answers = list(answer_weight.keys()) + weights = list(answer_weight.values()) + + answer = random.choices(answers, weights=weights, k=1)[0] # random sample an answer according to weights + + return { + "image": image, + "question": question, + "answer": answer, + } + + def __getitem__(self, index): + data = self.get_data(index) + question = self.text_processor(data["question"]) + instruction = random.choice(self.instruction_pool).format(question) + + instruction = " {} ".format(instruction) + + answer = self.text_processor(data['answer']) + + + return { + "image": data['image'], + "instruction_input": instruction, + "answer": answer, + } + + +class AOKVQGDataset(AOKVQADataset): + + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + self.instruction_pool = [ + 'Given the image, generate a question whose answer is: {}', + 'Based on the image, provide a question with the answer: {}', + 'Given the visual representation, create a question for which the answer is "{}"', + 'From the image provided, craft a question that leads to the reply: {}', + 'Considering the picture, come up with a question where the answer is: {}', + 'Taking the image into account, generate an question that has the answer: {}' + ] + + def __getitem__(self, index): + data = self.get_data(index) + instruction = random.choice(self.instruction_pool).format(data['answer']) + + return { + "image": data['image'], + "instruction_input": instruction, + "answer": data['question'], + } + + +# class AOKVQAEvalDataset(VQAEvalDataset, 
__DisplMixin): +# def __init__(self, vis_processor, text_processor, vis_root, ann_paths): +# """ +# vis_root (string): Root directory of images (e.g. coco/images/) +# ann_root (string): directory to store the annotation file +# """ +# +# self.vis_root = vis_root +# +# self.annotation = json.load(open(ann_paths[0])) +# +# answer_list_path = ann_paths[1] +# if os.path.exists(answer_list_path): +# self.answer_list = json.load(open(answer_list_path)) +# else: +# self.answer_list = None +# +# try: +# self.coco_fmt_qust_file = ann_paths[2] +# self.coco_fmt_anno_file = ann_paths[3] +# except IndexError: +# self.coco_fmt_qust_file = None +# self.coco_fmt_anno_file = None +# +# self.vis_processor = vis_processor +# self.text_processor = text_processor +# +# self._add_instance_ids() +# +# def collater(self, samples): +# ( +# image_list, +# question_list, +# question_id_list, +# instance_id_list, +# choices_list, +# correct_choice_idx_list, +# direct_answers_list, +# ) = ([], [], [], [], [], [], []) +# +# for sample in samples: +# image_list.append(sample["image"]) +# question_list.append(sample["text_input"]) +# question_id_list.append(sample["question_id"]) +# instance_id_list.append(sample["instance_id"]) +# choices_list.append(sample["choices"]) +# correct_choice_idx_list.append(sample["correct_choice_idx"]) +# direct_answers_list.append(sample["direct_answers"]) +# +# return { +# "image": torch.stack(image_list, dim=0), +# "text_input": question_list, +# "question_id": question_id_list, +# "instance_id": instance_id_list, +# "choices": choices_list, +# "correct_choice_idx": correct_choice_idx_list, +# "direct_answers": direct_answers_list, +# } +# +# def __getitem__(self, index): +# ann = self.annotation[index] +# +# image_path = os.path.join(self.vis_root, ann["image"]) +# image = Image.open(image_path).convert("RGB") +# +# image = self.vis_processor(image) +# question = self.text_processor(ann["question"]) +# +# choices = ann["choices"] +# if "correct_choice_idx" in ann: +# correct_choice_idx = ann["correct_choice_idx"] +# else: +# correct_choice_idx = None +# +# if "direct_answers" in ann: +# direct_answers = ann["direct_answers"] +# else: +# direct_answers = None +# +# return { +# "image": image, +# "text_input": question, +# "question_id": ann["question_id"], +# "instance_id": ann["instance_id"], +# "choices": choices, +# "correct_choice_idx": correct_choice_idx, +# "direct_answers": direct_answers, +# } diff --git a/minigpt4/datasets/datasets/aok_vqa_reasoning_datasets.py b/minigpt4/datasets/datasets/aok_vqa_reasoning_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..14ed1bbddaf91a8fae375623a4ed35a26100f098 --- /dev/null +++ b/minigpt4/datasets/datasets/aok_vqa_reasoning_datasets.py @@ -0,0 +1,262 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. 
+ SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from collections import OrderedDict +import json +import os +import random +import torch +from torch.utils.data import Dataset + +from PIL import Image + +from minigpt4.datasets.datasets.vqa_datasets import VQADataset #, VQAEvalDataset + + +class __DisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + return OrderedDict( + { + "file": ann["image"], + "question": ann["question"], + "question_id": ann["question_id"], + "direct_answers": "; ".join(ann["direct_answers"]), + "choices": "; ".join(ann["choices"]), + "correct_choice": ann["choices"][ann["correct_choice_idx"]], + "image": sample["image"], + } + ) + + +class AOKVQAReasoningDataset(Dataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + # super().__init__(vis_processor, text_processor, vis_root, ann_paths) + # self.instruction_pool = [ + # '{}', + # 'Question: {}', + # '{} A short answer to the question is', + # 'Q: {} A:', + # 'Answer the following question based on the image content. Question: {} Short answer:', + # # 'Given the image, answer the following question with no more than three words. {}', + # 'Based on the image, respond to this question with a short answer: {}.', + # 'Use the provided image to answer the question: {} Provide your answer as short as possible.', + # 'What is the answer to the following question? "{}"', + # 'Given this image, answer this question concisely: {} ', + # 'The question "{}" can be answered using the image. A short answer is' + # ] + # self.instruction_pool =[ + # "[vqa] {}", + # "[vqa] Based on the image, respond to this question with a short answer: {}" + # ] + self.vis_processor = vis_processor + self.text_processor = text_processor + self.vis_root = vis_root + self.instruction_pool =[ + "[vqa] {}" + ] + annotation = [] + with open(ann_paths, 'r') as f: + for line in f.readlines(): + json_data = json.loads(line) + annotation.append(json_data) + + exist_annotation = [] + for ann in annotation: + image_path = os.path.join(self.vis_root, ann["image_path"].split('/')[-1]) + if os.path.exists(image_path): + exist_annotation.append(ann) + else: + print("does not exists", image_path) + self.annotation = exist_annotation + + def __len__(self): + return len(self.annotation) + + def get_data(self, index): + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image_path"].split('/')[-1]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + question = self.text_processor(ann["question"]) + + rationales = ann["analysis"] + + + + # print("answer key", answer_key) + # for answer in ann[answer_key]: + # print(answer) + + # answer_weight = {} + # for answer in ann[answer_key]: + # if answer in answer_weight.keys(): + # answer_weight[answer] += 1 / len(ann[answer_key]) + # else: + # answer_weight[answer] = 1 / len(ann[answer_key]) + + # answers = list(answer_weight.keys()) + # weights = list(answer_weight.values()) + + # answer = random.choices(answers, weights=weights, k=1)[0] # random sample an answer according to weights + # choices = ann["choices"] + + # print("question",question) + # print("answer", rationales) + return { + "image": image, + "question": question, + # "answer": analysis, + "reason":rationales, + # "choice":choices + } + + def __getitem__(self, index): + data = self.get_data(index) + 
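        # Note for the __getitem__ logic below: each "reason" annotation packs a free-form
        # analysis and a short answer as "<analysis>\nAnswer: <answer>". Half of the samples
        # append the analysis (plus a trailing "\nAnswer:") to the instruction, leaving only
        # the short answer as the target; the other half keep the full
        # "analysis\nAnswer: <answer>" string as the target.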
question = self.text_processor(data["question"]) + instruction = random.choice(self.instruction_pool).format(question) + + instruction = " {} ".format(instruction) + + random_index = random.randint(0,1) + # reason = random.choice(data["reason"]) + answer = data["reason"] + + analysis = answer.split("\nAnswer:")[0] + answer = answer.split("\nAnswer:")[-1] + + # answer = data["reaso"] + + if random_index ==0: + instruction = instruction+analysis+"\nAnswer:" + + elif random_index==1: + answer = analysis+"\nAnswer:"+answer + + + return { + "image": data['image'], + "instruction_input": instruction, + "answer": answer, + } + + +class AOKVQGDataset(AOKVQAReasoningDataset): + + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + self.instruction_pool = [ + 'Given the image, generate a question whose answer is: {}', + 'Based on the image, provide a question with the answer: {}', + 'Given the visual representation, create a question for which the answer is "{}"', + 'From the image provided, craft a question that leads to the reply: {}', + 'Considering the picture, come up with a question where the answer is: {}', + 'Taking the image into account, generate an question that has the answer: {}' + ] + + def __getitem__(self, index): + data = self.get_data(index) + instruction = random.choice(self.instruction_pool).format(data['answer']) + # instruction = "###Human: {}###Assistant: ".format(instruction) + + return { + "image": data['image'], + "instruction_input": instruction, + "answer": data['question'], + } + + +# class AOKVQAEvalDataset(VQAEvalDataset, __DisplMixin): +# def __init__(self, vis_processor, text_processor, vis_root, ann_paths): +# """ +# vis_root (string): Root directory of images (e.g. 
coco/images/) +# ann_root (string): directory to store the annotation file +# """ +# +# self.vis_root = vis_root +# +# self.annotation = json.load(open(ann_paths[0])) +# +# answer_list_path = ann_paths[1] +# if os.path.exists(answer_list_path): +# self.answer_list = json.load(open(answer_list_path)) +# else: +# self.answer_list = None +# +# try: +# self.coco_fmt_qust_file = ann_paths[2] +# self.coco_fmt_anno_file = ann_paths[3] +# except IndexError: +# self.coco_fmt_qust_file = None +# self.coco_fmt_anno_file = None +# +# self.vis_processor = vis_processor +# self.text_processor = text_processor +# +# self._add_instance_ids() +# +# def collater(self, samples): +# ( +# image_list, +# question_list, +# question_id_list, +# instance_id_list, +# choices_list, +# correct_choice_idx_list, +# direct_answers_list, +# ) = ([], [], [], [], [], [], []) +# +# for sample in samples: +# image_list.append(sample["image"]) +# question_list.append(sample["text_input"]) +# question_id_list.append(sample["question_id"]) +# instance_id_list.append(sample["instance_id"]) +# choices_list.append(sample["choices"]) +# correct_choice_idx_list.append(sample["correct_choice_idx"]) +# direct_answers_list.append(sample["direct_answers"]) +# +# return { +# "image": torch.stack(image_list, dim=0), +# "text_input": question_list, +# "question_id": question_id_list, +# "instance_id": instance_id_list, +# "choices": choices_list, +# "correct_choice_idx": correct_choice_idx_list, +# "direct_answers": direct_answers_list, +# } +# +# def __getitem__(self, index): +# ann = self.annotation[index] +# +# image_path = os.path.join(self.vis_root, ann["image"]) +# image = Image.open(image_path).convert("RGB") +# +# image = self.vis_processor(image) +# question = self.text_processor(ann["question"]) +# +# choices = ann["choices"] +# if "correct_choice_idx" in ann: +# correct_choice_idx = ann["correct_choice_idx"] +# else: +# correct_choice_idx = None +# +# if "direct_answers" in ann: +# direct_answers = ann["direct_answers"] +# else: +# direct_answers = None +# +# return { +# "image": image, +# "text_input": question, +# "question_id": ann["question_id"], +# "instance_id": ann["instance_id"], +# "choices": choices, +# "correct_choice_idx": correct_choice_idx, +# "direct_answers": direct_answers, +# } diff --git a/minigpt4/datasets/datasets/base_dataset.py b/minigpt4/datasets/datasets/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..81d58372ad888c525af681932a642a36fa3a91a7 --- /dev/null +++ b/minigpt4/datasets/datasets/base_dataset.py @@ -0,0 +1,75 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import json +from typing import Iterable + +from torch.utils.data import Dataset, ConcatDataset +from torch.utils.data.dataloader import default_collate + + +class BaseDataset(Dataset): + def __init__( + self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[] + ): + """ + vis_root (string): Root directory of images (e.g. 
coco/images/) + ann_root (string): directory to store the annotation file + """ + self.vis_root = vis_root + + self.annotation = [] + # print("ann paths", ann_paths) + for ann_path in ann_paths: + # print("ann_path", ann_path) + ann = json.load(open(ann_path, "r")) + if isinstance(ann, dict): + self.annotation.extend(json.load(open(ann_path, "r"))['annotations']) + # self.annotation.extend(json.load(open(ann_path, "r"))) + else: + self.annotation.extend(json.load(open(ann_path, "r"))) + + self.vis_processor = vis_processor + self.text_processor = text_processor + + self._add_instance_ids() + + def __len__(self): + return len(self.annotation) + + def collater(self, samples): + return default_collate(samples) + + def set_processors(self, vis_processor, text_processor): + self.vis_processor = vis_processor + self.text_processor = text_processor + + def _add_instance_ids(self, key="instance_id"): + for idx, ann in enumerate(self.annotation): + ann[key] = str(idx) + + +class ConcatDataset(ConcatDataset): + def __init__(self, datasets: Iterable[Dataset]) -> None: + super().__init__(datasets) + + def collater(self, samples): + # TODO For now only supports datasets with same underlying collater implementations + + all_keys = set() + for s in samples: + all_keys.update(s) + + shared_keys = all_keys + for s in samples: + shared_keys = shared_keys & set(s.keys()) + + samples_shared_keys = [] + for s in samples: + samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys}) + + return self.datasets[0].collater(samples_shared_keys) diff --git a/minigpt4/datasets/datasets/caption_datasets.py b/minigpt4/datasets/datasets/caption_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..be40164cf38df218f7c1e96e8c8ef31c18cce841 --- /dev/null +++ b/minigpt4/datasets/datasets/caption_datasets.py @@ -0,0 +1,150 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +from collections import OrderedDict + +from minigpt4.datasets.datasets.base_dataset import BaseDataset +from PIL import Image +import random + +class __DisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + + return OrderedDict( + { + "file": ann["image"], + "caption": ann["caption"], + "image": sample["image"], + } + ) + + +class CaptionDataset(BaseDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. 
coco/images/) + ann_root (string): directory to store the annotation file + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + self.img_ids = {} + n = 0 + for ann in self.annotation: + img_id = ann["image_id"] + if img_id not in self.img_ids.keys(): + self.img_ids[img_id] = n + n += 1 + + def __getitem__(self, index): + + # TODO this assumes image input, not general enough + ann = self.annotation[index] + + img_file = '{:0>12}.jpg'.format(ann["image_id"]) + image_path = os.path.join(self.vis_root, img_file) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + caption = self.text_processor(ann["caption"]) + + return { + "image": image, + "answer": caption, + "image_id": self.img_ids[ann["image_id"]], + } + + +class COCOCaptionDataset(BaseDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + self.img_ids = {} + n = 0 + + self.filter_anntation = [] + + for ann in self.annotation: + if "train" in ann["image"]: + self.filter_anntation.append(ann) + self.annotation = self.filter_anntation + + for ann in self.annotation: + img_id = ann["image_id"] + if img_id not in self.img_ids.keys(): + self.img_ids[img_id] = n + n += 1 + + self.instruction_pool = [ + 'Briefly describe this image.', + 'Provide a concise depiction of this image.', + 'Present a short description of this image.', + 'Summarize this image in a few words.', + 'A short image caption:', + 'A short image description:', + 'A photo of ', + 'An image that shows ', + 'Write a short description for the image. ', + 'Write a description for the photo.', + 'Provide a description of what is presented in the photo.', + 'Briefly describe the content of the image.', + 'Can you briefly explain what you see in the image?', + 'Could you use a few words to describe what you perceive in the photo?', + 'Please provide a short depiction of the picture.', + 'Using language, provide a short account of the image.', + 'Use a few words to illustrate what is happening in the picture.', + ] + def __getitem__(self, index): + + # TODO this assumes image input, not general enough + ann = self.annotation[index] + + # img_file = '{:0>12}.jpg'.format(ann["image_id"]) + img_file = ann["image"].split("/")[-1] + image_path = os.path.join(self.vis_root, img_file) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + caption = self.text_processor(ann["caption"]) + + instruction = random.choice(self.instruction_pool) + instruction = " [caption] {} ".format(instruction) + + return { + "image": image, + "answer": caption, + "instruction_input": instruction, + } + +class CaptionEvalDataset(BaseDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. 
coco/images/) + ann_root (string): directory to store the annotation file + split (string): val or test + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + def __getitem__(self, index): + + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + + return { + "image": image, + "image_id": ann["image_id"], + "instance_id": ann["instance_id"], + } diff --git a/minigpt4/datasets/datasets/caption_reasoning.py b/minigpt4/datasets/datasets/caption_reasoning.py new file mode 100644 index 0000000000000000000000000000000000000000..7eb81a86db3fa1a80cda65f649ef3eb61fd0773c --- /dev/null +++ b/minigpt4/datasets/datasets/caption_reasoning.py @@ -0,0 +1,120 @@ +import os +import json +import pickle +import random +import time +import itertools + +import numpy as np +from PIL import Image +import skimage.io as io +import matplotlib.pyplot as plt +from matplotlib.collections import PatchCollection +from matplotlib.patches import Polygon, Rectangle +from torch.utils.data import Dataset +import webdataset as wds + + +from minigpt4.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset + +from collections import OrderedDict + + +class __DisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + + return OrderedDict( + { + "file": ann["image"], + "question": ann["question"], + "question_id": ann["question_id"], + "answers": "; ".join(ann["answer"]), + "image": sample["image"], + } + ) + + +# class CaptionReasonDataset(VQADataset, __DisplMixin): +class CaptionReasonDataset(Dataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_path): + self.vis_root = vis_root + + self.vis_processor = vis_processor + self.text_processor = text_processor + + self.instruction_pool =[ + "[reasoning] {}" + ] + # print(ann_path) + with open(ann_path, 'r') as f: + self.ann = json.load(f) + + + # exist_annotation = [] + # for ann in self.annotation: + # image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1]) + # if os.path.exists(image_path): + # exist_annotation.append(ann) + # self.annotation = exist_annotation + + + def get_data(self, index): + ann = self.ann[index] + + image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + question = self.text_processor(ann["question"]) + question_id = ann["question_id"] + + answer_weight = {} + for answer in ann["answer"]: + if answer in answer_weight.keys(): + answer_weight[answer] += 1 / len(ann["answer"]) + else: + answer_weight[answer] = 1 / len(ann["answer"]) + + answers = list(answer_weight.keys()) + weights = list(answer_weight.values()) + + answer = random.choices(answers, weights=weights, k=1)[0] # random sample an answer according to weights + + + + grounded_caption = ann["grounded_caption"] + detailed_caption = ann["detailed_caption"] + return { + "image": image, + "question": question, + "question_id": question_id, + "answer": answer, + "detailed_caption": detailed_caption, + "grounded_caption": grounded_caption + } + + def __len__(self): + return len(self.ann) + + def __getitem__(self, index): + data = self.get_data(index) + + question =data['question'] + detailed_caption = data["detailed_caption"] + grounded_caption = data["grounded_caption"] + + instruction = random.choice(self.instruction_pool).format(question) + instruction = " 
{}".format(instruction) + + answer = grounded_caption+" short answer: "+data['answer'] + # print("instruction", instruction) + # print("answer", answer) + + + return { + "image": data['image'], + "question_id": data["question_id"], + "instruction_input": instruction, + "answer": answer, + } diff --git a/minigpt4/datasets/datasets/coco_caption.py b/minigpt4/datasets/datasets/coco_caption.py new file mode 100644 index 0000000000000000000000000000000000000000..4fd39cf9538febd65f0f7eca2d8a6e9a2afb81de --- /dev/null +++ b/minigpt4/datasets/datasets/coco_caption.py @@ -0,0 +1,120 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +import json +import torch +import numpy as np + +from PIL import Image +from PIL import ImageFile + +ImageFile.LOAD_TRUNCATED_IMAGES = True + +from minigpt4.datasets.datasets.caption_datasets import COCOCaptionDataset, CaptionEvalDataset + +COCOCapDataset = COCOCaptionDataset + + + + + +class COCOCapEvalDataset(CaptionEvalDataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + split (string): val or test + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + def __getitem__(self, index): + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + + img_id = ann["image"].split("/")[-1].strip(".jpg").split("_")[-1] + + return { + "image": image, + "image_id": img_id, + "instance_id": ann["instance_id"], + } + + +class NoCapsEvalDataset(CaptionEvalDataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + split (string): val or test + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + def __getitem__(self, index): + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + + img_id = ann["img_id"] + + return { + "image": image, + "image_id": img_id, + "instance_id": ann["instance_id"], + } + + +class RefCOCOEvalData(torch.utils.data.Dataset): + def __init__(self, loaded_data, vis_processor, root_path): + self.loaded_data = loaded_data + self.root_path = root_path + self.vis_processor = vis_processor + + def __len__(self): + return len(self.loaded_data) + + def __getitem__(self, idx): + data = self.loaded_data[idx] + img_id = data['img_id'] + sent = data['sents'] + image_path = os.path.join(self.root_path, f'{img_id[:27]}.jpg') + image = Image.open(image_path).convert('RGB') + image = self.vis_processor(image) + question = f"[refer] where is {sent}?" 
+ return image, question, img_id + +class EvalCaptionData(torch.utils.data.Dataset): + def __init__(self, loaded_data, vis_processor, root_path): + self.loaded_data = loaded_data + self.root_path = root_path + self.vis_processor = vis_processor + ann = dict() + for item in self.loaded_data: + image_id = item['image_id'] + ann[image_id] = item['image'] + self.ann = [{'image_id':image_id, 'image': ann[image_id]} for image_id in ann] + + def __len__(self): + return len(self.ann) + + def __getitem__(self, idx): + data = self.ann[idx] + image_id = data['image_id'] + img_file = data['image'].split('/')[-1] + image_path = os.path.join(self.root_path, img_file) + image = Image.open(image_path).convert('RGB') + + image = self.vis_processor(image) + question = f"[caption] please describe this image?" + return image, question, image_id diff --git a/minigpt4/datasets/datasets/coco_vqa_datasets.py b/minigpt4/datasets/datasets/coco_vqa_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..6b06828e1af9ac4b93edcd67143c076c05af7961 --- /dev/null +++ b/minigpt4/datasets/datasets/coco_vqa_datasets.py @@ -0,0 +1,184 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +import json +import random + +from PIL import Image + +from minigpt4.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset + +from collections import OrderedDict + + +class __DisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + + return OrderedDict( + { + "file": ann["image"], + "question": ann["question"], + "question_id": ann["question_id"], + "answers": "; ".join(ann["answer"]), + "image": sample["image"], + } + ) + + +class COCOVQADataset(VQADataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + self.instruction_pool =[ + "[vqa] {}", + "[vqa] Based on the image, respond to this question with a short answer: {}" + ] + + exist_annotation = [] + for ann in self.annotation: + image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1]) + if os.path.exists(image_path): + exist_annotation.append(ann) + self.annotation = exist_annotation + + + def get_data(self, index): + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + question = self.text_processor(ann["question"]) + question_id = ann["question_id"] + + answer_weight = {} + for answer in ann["answer"]: + if answer in answer_weight.keys(): + answer_weight[answer] += 1 / len(ann["answer"]) + else: + answer_weight[answer] = 1 / len(ann["answer"]) + + answers = list(answer_weight.keys()) + weights = list(answer_weight.values()) + + answer = random.choices(answers, weights=weights, k=1)[0] # random sample an answer according to weights + + if "unk" in answer: + print("cocovqa", answer) + + return { + "image": image, + "question": question, + "question_id": question_id, + "answer": answer, + } + + def __getitem__(self, index): + data = self.get_data(index) + instruction = random.choice(self.instruction_pool).format(data['question']) + instruction = " {} ".format(instruction) + + return { + "image": data['image'], + "question_id": data["question_id"], + 
"instruction_input": instruction, + "answer": self.text_processor(data['answer']), + } + + +class COCOVQGDataset(COCOVQADataset): + + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + self.instruction_pool = [ + 'Given the image, generate a question whose answer is: {}', + 'Based on the image, provide a question with the answer: {}', + 'Given the visual representation, create a question for which the answer is "{}"', + 'From the image provided, craft a question that leads to the reply: {}', + 'Considering the picture, come up with a question where the answer is: {}', + 'Taking the image into account, generate an question that has the answer: {}' + ] + + def __getitem__(self, index): + data = self.get_data(index) + instruction = random.choice(self.instruction_pool).format(data['answer']) + instruction = " {}".format(instruction) + + return { + "image": data['image'], + "question_id": data["question_id"], + "instruction_input": instruction, + "answer": data['question'], + } + + + +class COCOVQAEvalDataset(VQAEvalDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + + self.instruction_pool = [ +# '{}', +# 'Question: {}', +# '{} A short answer to the question is', +# 'Q: {} A:', + 'Question: {} Short answer:', +# 'Given the image, answer the following question with no more than three words. {}', +# 'Based on the image, respond to this question with a short answer: {}.', +# 'Use the provided image to answer the question: {} Provide your answer as short as possible.', +# 'What is the answer to the following question? "{}"', +# 'The question "{}" can be answered using the image. 
A short answer is' + ] +# print('vis_root', vis_root) + self.vis_root = vis_root + + self.annotation = json.load(open(ann_paths[0])) + + answer_list_path = ann_paths[1] + if os.path.exists(answer_list_path): + self.answer_list = json.load(open(answer_list_path)) + else: + self.answer_list = None + + try: + self.coco_fmt_qust_file = ann_paths[2] + self.coco_fmt_anno_file = ann_paths[3] + except IndexError: + self.coco_fmt_qust_file = None + self.coco_fmt_anno_file = None + + self.vis_processor = vis_processor + self.text_processor = text_processor + + self._add_instance_ids() + + def __getitem__(self, index): + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + question = self.text_processor(ann["question"]) + + instruction = random.choice(self.instruction_pool).format(question) + instruction = " {} ".format(instruction) + + return { + "image": image, + 'image_path': image_path, + "question": question, + "question_id": ann["question_id"], + "instruction_input": instruction, + "instance_id": ann["instance_id"], + } diff --git a/minigpt4/datasets/datasets/cot.py b/minigpt4/datasets/datasets/cot.py new file mode 100644 index 0000000000000000000000000000000000000000..3ebe89ef0011c49b71373252302aa2f4d05f9dd1 --- /dev/null +++ b/minigpt4/datasets/datasets/cot.py @@ -0,0 +1,43 @@ +import os +import json +import pickle +import random +import time +import itertools + +import numpy as np +from PIL import Image +import skimage.io as io +import matplotlib.pyplot as plt +from matplotlib.collections import PatchCollection +from matplotlib.patches import Polygon, Rectangle +from torch.utils.data import Dataset +import webdataset as wds + +from minigpt4.datasets.datasets.base_dataset import BaseDataset +from minigpt4.datasets.datasets.caption_datasets import CaptionDataset + + +class CoTDataset(Dataset): + def __init__(self, text_processor, ann_path): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + + self.text_processor = text_processor + + with open(ann_path, 'r') as f: + self.ann = json.load(f) + + def __len__(self): + return len(self.ann) + + def __getitem__(self, index): + info = self.ann[index] + input = info["inputs"] + target = info["targets"] + return { + "instruction_input": input, + "answer": target, + } diff --git a/minigpt4/datasets/datasets/coyo_dataset.py b/minigpt4/datasets/datasets/coyo_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4b581ce1983983ec35f4ff3db2f8b3479a98529a --- /dev/null +++ b/minigpt4/datasets/datasets/coyo_dataset.py @@ -0,0 +1,469 @@ +import os +import json +import pickle +import random +import time +import itertools + +import numpy as np +from PIL import Image +import skimage.io as io +import matplotlib.pyplot as plt +from matplotlib.collections import PatchCollection +from matplotlib.patches import Polygon, Rectangle +from torch.utils.data import Dataset +import webdataset as wds +from minigpt4.datasets.datasets.base_dataset import BaseDataset +from minigpt4.datasets.datasets.caption_datasets import CaptionDataset +from minigpt4.datasets.datasets.base_dataset import BaseDataset + + +class COYOCaptionWDSDataset(BaseDataset): + def __init__(self, vis_processor, text_processor, location): + super().__init__(vis_processor=vis_processor, text_processor=text_processor) + """ + vis_root (string): Root directory of images (e.g. 
coco/images/) + ann_root (string): directory to store the annotation file + """ + + self.inner_dataset = wds.DataPipeline( + wds.ResampledShards(location), + wds.tarfile_to_samples(handler=wds.warn_and_continue), + wds.shuffle(1000, handler=wds.warn_and_continue), + wds.decode("pilrgb", handler=wds.warn_and_continue), + wds.to_tuple("jpg", "json"), + wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue), + wds.map(self.to_dict, handler=wds.warn_and_continue), + ) + + self.instruction_pool = [ + '[grounding] Briefly describe this image with grounding objects.', + '[grounding] Provide a concise depiction of this image with grounding objects.', + '[grounding] Present a short description of this image with grounding objects.', + '[grounding] Summarize this image in a few words with grounding objects.', + '[grounding] A short image caption with grounding objects:', + '[grounding] A short image description with grounding objects:', + '[grounding] Write a short description for the image with grounding objects.', + '[grounding] Write a description for the photo with grounding objects.', + '[grounding] Briefly describe the content of the image with grounding objects.', + '[grounding] Please provide a short depiction of the picture with grounding objects.', + ] + + # self.instruction_pool = [ + # '[grounding] Briefly describe this image.', + # '[grounding] Provide a concise depiction of this image.', + # '[grounding] Present a short description of this image.', + # '[grounding] Summarize this image in a few words.', + # '[grounding] A short image caption:', + # '[grounding] A short image description:', + # '[grounding] A photo of', + # '[grounding] An image that shows', + # '[grounding] Write a short description for the image.', + # '[grounding] Write a description for the photo.', + # '[grounding] Provide a description of what is presented in the photo.', + # '[grounding] Briefly describe the content of the image.', + # '[grounding] Can you briefly explain what you see in the image?', + # '[grounding] Could you use a few words to describe what you perceive in the photo?', + # '[grounding] Please provide a short depiction of the picture.', + # '[grounding] Using language, provide a short account of the image.', + # '[grounding] Use a few words to illustrate what is happening in the picture.', + # ] + + def generate_ground_caption(self,image_caption, phrases, bounding_boxes): + + grounded_caption = image_caption + + # Iterate over the phrases and bounding boxes + phrase_bbox={} + for phrase, bbox in zip(phrases, bounding_boxes): + # Replace the phrase with the grounded HTML format + # print(phrase, bbox, type(phrase), type(bbox)) + + if phrase not in phrase_bbox.keys(): + grounded_phrase = "
<p>{}</p>
".format(phrase) + grounded_phrase_bbox = grounded_phrase+str(bbox) + else: + grounded_phrase = phrase_bbox[phrase] + + grounded_phrase_bbox = grounded_phrase+""+str(bbox) + + phrase_bbox[phrase] = grounded_phrase_bbox + + + grounded_caption = grounded_caption.replace(phrase, grounded_phrase_bbox) + + return grounded_caption + + + def preprocess_ground_caption(self, sample): + + # info = self.ann["data"][index] + image_id = sample[1]["id"] + + + caption = sample[1]["caption"] + ref_exps = sample[1]["noun_chunks"] + image_size = 100 + + bboxs = [] + ref_phrases = [] + for item in ref_exps: + phrase_start = int(item[0]) + phrase_end = int(item[1]) + + x_min = item[2] + y_min = item[3] + x_max = item[4] + y_max = item[5] + ref_phrase = caption[phrase_start: phrase_end] + + x1 = int(x_min*image_size) + y1 = int(y_min*image_size) + x2 = int(x_max*image_size) + y2 = int(y_max*image_size) + assert x1>=0 and x1<=image_size + assert x2>=0 and x2<=image_size + assert y1>=0 and y1<=image_size + assert y2>=0 and y2<=image_size + # print(x1, y2, x2, y2) + bbox = [str(x1),str(y1),str(x2),str(y2)] + # bbox = "<"+str(x1)+"><"+str(y1)+"><"+str(x2)+"><"+str(y2)+">" + bbox = "{{<{}><{}><{}><{}>}}".format(*bbox) + bboxs.append(bbox) + ref_phrases.append(ref_phrase) + + grounded_caption = self.generate_ground_caption(caption, ref_phrases,bboxs) + + + + return { + "answer": grounded_caption + } + + + def to_dict(self, sample): + data = self.preprocess_ground_caption(sample) + + instruction = random.choice(self.instruction_pool) + instruction = " {} ".format(instruction) + + answer = self.text_processor(data['answer']) + return { + "image": sample[0], + "instruction_input": instruction, + "answer": answer, + } + + + +class COYOBoxToPhraseWDSDataset(BaseDataset): + def __init__(self, vis_processor, text_processor, location): + super().__init__(vis_processor=vis_processor, text_processor=text_processor) + """ + vis_root (string): Root directory of images (e.g. 
coco/images/) + ann_root (string): directory to store the annotation file + """ + + self.inner_dataset = wds.DataPipeline( + wds.ResampledShards(location), + wds.tarfile_to_samples(handler=wds.warn_and_continue), + wds.shuffle(1000, handler=wds.warn_and_continue), + wds.decode("pilrgb", handler=wds.warn_and_continue), + wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), + wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue), + wds.map(self.to_dict, handler=wds.warn_and_continue), + ) + + + self.instruction_pool = [ + "[identify] {}", + "[identify] what object is in this location {}", + "[identify] identify the object present at this location {}", + "[identify] what is it in {}", + "[identify] describe this object in {}", + "[identify] this {} is", + "[identify] the object in {} is", + ] + def bbox_phrase_preprocess(self, sample): + + caption = sample[1]["caption"] + # ref_exps = sample[1]["ref_exps"] + ref_exps = sample[1]["noun_chunks"] + image_size = 100 + + bboxs = [] + ref_phrases = [] + for item in ref_exps: + # print(item) + phrase_start = int(item[0]) + phrase_end = int(item[1]) + + x_min = item[2] + y_min = item[3] + x_max = item[4] + y_max = item[5] + ref_phrase = caption[phrase_start: phrase_end] + + x1 = int(x_min*image_size) + y1 = int(y_min*image_size) + x2 = int(x_max*image_size) + y2 = int(y_max*image_size) + assert x1>=0 and x1<=image_size + assert x2>=0 and x2<=image_size + assert y1>=0 and y1<=image_size + assert y2>=0 and y2<=image_size + + bbox = [str(x1),str(y1),str(x2),str(y2)] + + + # bbox = "<"+str(x1)+"><"+str(y1)+"><"+str(x2)+"><"+str(y2)+">" + bbox = "{{<{}><{}><{}><{}>}}".format(*bbox) + bboxs.append(bbox) + ref_phrases.append(ref_phrase) + + # print(ref_phrase, bbox) + + index = random.randint(0, len(bboxs)-1) + + # Retrieve the corresponding elements + sampled_bbox = bboxs[index] + sampled_phrase = ref_phrases[index] + + return { + "instruction_input": sampled_bbox, + "answer": sampled_phrase, + } + + def to_dict(self, sample): + + data = self.bbox_phrase_preprocess(sample) + + instruction = random.choice(self.instruction_pool).format(data['instruction_input']) + instruction = " {} ".format(instruction) + + answer = self.text_processor(data['answer']) + + return { + "image": sample[0], + "instruction_input": instruction, + "answer": answer, + } + + + +class COYOPhraseToBoxWDSDataset(BaseDataset): + def __init__(self, vis_processor, text_processor, location): + super().__init__(vis_processor=vis_processor, text_processor=text_processor) + """ + vis_root (string): Root directory of images (e.g. 
coco/images/) + ann_root (string): directory to store the annotation file + """ + + self.inner_dataset = wds.DataPipeline( + wds.ResampledShards(location), + wds.tarfile_to_samples(handler=wds.warn_and_continue), + wds.shuffle(1000, handler=wds.warn_and_continue), + wds.decode("pilrgb", handler=wds.warn_and_continue), + wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), + wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue), + wds.map(self.to_dict, handler=wds.warn_and_continue), + ) + + self.instruction_pool = [ + "[refer] {}", + "[refer] give me the location of {}", + "[refer] where is {} ?", + "[refer] from this image, tell me the location of {}", + "[refer] the location of {} is ", + "[refer] could you tell me the location for {}?", + "[refer] where can I locate the {}?", + ] + + # self.instruction_pool = [ + # # "[refer] {}", + # "[refer] give me the bounding box location of {}", + # "[refer] where is bounding box location of {} ?", + # "[refer] from this image, tell me the bounding box location of {}", + # "[refer] the bounding box location of {} is", + # "[refer] could you tell me the bounding box location for {} ?", + # "[refer] where can I locate the bounding box of {} ?", + # ] + def phrase_bbox_preprocess(self, sample): + + caption = sample[1]["caption"] + ref_exps = sample[1]["ref_exps"] + image_size = 100 + + bboxs = [] + ref_phrases = [] + for item in ref_exps: + phrase_start = int(item[0]) + phrase_end = int(item[1]) + + x_min = item[2] + y_min = item[3] + x_max = item[4] + y_max = item[5] + ref_phrase = caption[phrase_start: phrase_end] + + x1 = int(x_min*image_size) + y1 = int(y_min*image_size) + x2 = int(x_max*image_size) + y2 = int(y_max*image_size) + assert x1>=0 and x1<=image_size + assert x2>=0 and x2<=image_size + assert y1>=0 and y1<=image_size + assert y2>=0 and y2<=image_size + + # bbox = "<"+str(x1)+"><"+str(y1)+"><"+str(x2)+"><"+str(y2)+">" + bbox = [str(x1),str(y1),str(x2),str(y2)] + + bbox = "{{<{}><{}><{}><{}>}}".format(*bbox) + bboxs.append(bbox) + ref_phrases.append(ref_phrase) + + index = random.randint(0, len(bboxs)-1) + + # Retrieve the corresponding elements + sampled_bbox = bboxs[index] + sampled_phrase = ref_phrases[index] + + return { + "instruction_input": sampled_phrase, + "answer": sampled_bbox, + } + + + def to_dict(self, sample): + data = self.phrase_bbox_preprocess(sample) + instruction_input = self.text_processor(data['instruction_input']) + instruction = random.choice(self.instruction_pool).format(instruction_input) + instruction = " {} ".format(instruction) + + return { + "image": sample[0], + "instruction_input": instruction, + "answer": data["answer"], + } + + + + +# class COYOBBoxPhraseDataset(Dataset): +# def __init__(self, vis_processor, text_processor, vis_root, ann_path): +# """ +# vis_root (string): Root directory of images (e.g. 
coco/images/) +# ann_root (string): directory to store the annotation file +# """ +# self.vis_root = vis_root + +# self.vis_processor = vis_processor +# self.text_processor = text_processor + +# self.ann = {"data":[]} + + +# with open(ann_path, 'r') as f: +# for line in f.readlines(): +# line = line.strip() +# # print(line, type(line)) +# try: +# item = json.loads(line.strip()) +# except: +# print(line) +# # print(item) +# assert False + +# # print(item, type(item)) +# # assert False +# self.ann["data"].append(item) + + +# self.bbox_phrase_instruction_pool = [ +# " what object is in this bounding box location {} ", +# " what object is in this location {} ", +# " identify the object present at this location {} ", +# " what is it in bounding box location{} ", +# " describe this object in {} ", +# " this {} is ", +# " the object in {} is ", +# " please tell me what is inside the bounding box position {} ", +# " what can you find in the bounding box area at position {}? ", +# " what is the object occupying this area {} ", +# " could you identify the content within the bounding box located at {} ", +# ] + +# def __len__(self): +# return len(self.ann["data"]) + +# def bbox_phrase_preprocess(self, index): + +# info = self.ann["data"][index] +# image_id = info["id"] + +# image_file = str(image_id)+".jpg" +# image_path = os.path.join(self.vis_root, image_file) +# image = Image.open(image_path).convert("RGB") +# image = self.vis_processor(image) + +# caption = info["caption"] +# ref_exps = info["ref_exps"] +# image_size = 100 + +# bboxs = [] +# ref_phrases = [] +# for item in ref_exps: +# # print(item) +# phrase_start = int(item[0]) +# phrase_end = int(item[1]) + +# x_min = item[2] +# y_min = item[3] +# x_max = item[4] +# y_max = item[5] +# ref_phrase = caption[phrase_start: phrase_end] + +# x1 = int(x_min*image_size) +# y1 = int(y_min*image_size) +# x2 = int(x_max*image_size) +# y2 = int(y_max*image_size) +# assert x1>=0 and x1<=image_size +# assert x2>=0 and x2<=image_size +# assert y1>=0 and y1<=image_size +# assert y2>=0 and y2<=image_size + +# bbox = [str(x1),str(y1),str(x2),str(y2)] + + +# # bbox = "<"+str(x1)+"><"+str(y1)+"><"+str(x2)+"><"+str(y2)+">" +# bbox = "{{<{}><{}><{}><{}>}}".format(*bbox) +# bboxs.append(bbox) +# ref_phrases.append(ref_phrase) + +# # print(ref_phrase, bbox) + +# index = random.randint(0, len(bboxs)-1) + +# # Retrieve the corresponding elements +# sampled_bbox = bboxs[index] +# sampled_phrase = ref_phrases[index] + +# return { +# "image": image, +# "instruction_input": sampled_phrase, +# "answer": sampled_bbox, +# "image_id": info['id'], +# } + + + +# def __getitem__(self, index): + +# data = self.preprocess(index) +# instruction = random.choice(self.instruction_pool).format(data['instruction_input']) +# return { +# "image": data['image'], +# "instruction_input": instruction, +# "answer": data['answer'], +# "image_id": data['image_id'], +# } diff --git a/minigpt4/datasets/datasets/dataloader_utils.py b/minigpt4/datasets/datasets/dataloader_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8eaa3a58b0ad42ca7937fb51b46e53511cc3cd0c --- /dev/null +++ b/minigpt4/datasets/datasets/dataloader_utils.py @@ -0,0 +1,162 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. 
+ SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import time +import random +import torch +from minigpt4.datasets.data_utils import move_to_cuda +from torch.utils.data import DataLoader + + +class MultiIterLoader: + """ + A simple wrapper for iterating over multiple iterators. + + Args: + loaders (List[Loader]): List of Iterator loaders. + ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly. + """ + + def __init__(self, loaders, ratios=None): + # assert all loaders has __next__ method + for loader in loaders: + assert hasattr( + loader, "__next__" + ), "Loader {} has no __next__ method.".format(loader) + + if ratios is None: + ratios = [1.0] * len(loaders) + else: + assert len(ratios) == len(loaders) + ratios = [float(ratio) / sum(ratios) for ratio in ratios] + + self.loaders = loaders + self.ratios = ratios + + def __next__(self): + # random sample from each loader by ratio + loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0] + return next(self.loaders[loader_idx]) + + +class PrefetchLoader(object): + """ + Modified from https://github.com/ChenRocks/UNITER. + + overlap compute and cuda data transfer + (copied and then modified from nvidia apex) + """ + + def __init__(self, loader): + self.loader = loader + self.stream = torch.cuda.Stream() + + def __iter__(self): + loader_it = iter(self.loader) + self.preload(loader_it) + batch = self.next(loader_it) + while batch is not None: + is_tuple = isinstance(batch, tuple) + if is_tuple: + task, batch = batch + + if is_tuple: + yield task, batch + else: + yield batch + batch = self.next(loader_it) + + def __len__(self): + return len(self.loader) + + def preload(self, it): + try: + self.batch = next(it) + except StopIteration: + self.batch = None + return + # if record_stream() doesn't work, another option is to make sure + # device inputs are created on the main stream. + # self.next_input_gpu = torch.empty_like(self.next_input, + # device='cuda') + # self.next_target_gpu = torch.empty_like(self.next_target, + # device='cuda') + # Need to make sure the memory allocated for next_* is not still in use + # by the main stream at the time we start copying to next_*: + # self.stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self.stream): + self.batch = move_to_cuda(self.batch) + # more code for the alternative if record_stream() doesn't work: + # copy_ will record the use of the pinned source tensor in this + # side stream. + # self.next_input_gpu.copy_(self.next_input, non_blocking=True) + # self.next_target_gpu.copy_(self.next_target, non_blocking=True) + # self.next_input = self.next_input_gpu + # self.next_target = self.next_target_gpu + + def next(self, it): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + if batch is not None: + record_cuda_stream(batch) + self.preload(it) + return batch + + def __getattr__(self, name): + method = self.loader.__getattribute__(name) + return method + + +def record_cuda_stream(batch): + if isinstance(batch, torch.Tensor): + batch.record_stream(torch.cuda.current_stream()) + elif isinstance(batch, list) or isinstance(batch, tuple): + for t in batch: + record_cuda_stream(t) + elif isinstance(batch, dict): + for t in batch.values(): + record_cuda_stream(t) + else: + pass + + +class IterLoader: + """ + A wrapper to convert DataLoader as an infinite iterator. 
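+ On StopIteration it advances the internal epoch counter, calls
+ sampler.set_epoch() when running distributed, and re-creates the iterator,
+ so callers can keep calling next() across epoch boundaries.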
+ + Modified from: + https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py + """ + + def __init__(self, dataloader: DataLoader, use_distributed: bool = False): + self._dataloader = dataloader + self.iter_loader = iter(self._dataloader) + self._use_distributed = use_distributed + self._epoch = 0 + + @property + def epoch(self) -> int: + return self._epoch + + def __next__(self): + try: + data = next(self.iter_loader) + except StopIteration: + self._epoch += 1 + if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed: + self._dataloader.sampler.set_epoch(self._epoch) + time.sleep(2) # Prevent possible deadlock during epoch transition + self.iter_loader = iter(self._dataloader) + data = next(self.iter_loader) + + return data + + def __iter__(self): + return self + + def __len__(self): + return len(self._dataloader) diff --git a/minigpt4/datasets/datasets/doc_dataset.py b/minigpt4/datasets/datasets/doc_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..232bb73817074ce701025a49a28ab204f9e4a187 --- /dev/null +++ b/minigpt4/datasets/datasets/doc_dataset.py @@ -0,0 +1,280 @@ +import os +import json +import pickle +import random +import time +import itertools + +import numpy as np +from PIL import Image +import skimage.io as io +import matplotlib.pyplot as plt +from matplotlib.collections import PatchCollection +from matplotlib.patches import Polygon, Rectangle +from torch.utils.data import Dataset +import webdataset as wds + +from minigpt4.datasets.datasets.base_dataset import BaseDataset +from minigpt4.datasets.datasets.caption_datasets import CaptionDataset + + +class SingleSlideVQADataset(Dataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_path): + """ + vis_root (string): Root directory of images (e.g. 
coco/images/) + ann_root (string): directory to store the annotation file + """ + self.vis_root = vis_root + + self.vis_processor = vis_processor + self.text_processor = text_processor + self.data = self.create_data(ann_path) + + # self.instruction_pool = [ + # "###Human: {}###Assistant: ", + # "###Human: From this slide, {}###Assistant: ", + # ] + self.instruction_pool = [ + " {}", + " From this slide, {}", + ] + def create_data(self, ann_path): + with open(ann_path, 'r') as f: + samples = f.readlines() + data = [] + for sample in samples: + sample = json.loads(sample) + if len(sample['evidence_pages']) != 1: continue # skip questions that need more than one slide page + page = sample['evidence_pages'][0] + image_name = 'slide_{}_1024.jpg'.format(page) + # assert [int(image_name.split('-')[-2]) for image_name in image_names] == list(range(1, 21)) # check the format + image_path = os.path.join(sample['deck_name'], image_name) + data.append({ + 'qa_id': sample['qa_id'], + 'question': sample['question'], + 'answer': sample['answer'], + 'image_path': image_path + }) + + print("single slide ",len(data)) + return data + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + sample = self.data[index] + image = Image.open(os.path.join(self.vis_root, sample['image_path'])).convert("RGB") + image = self.vis_processor(image) + + # instruction = self.text_processor(sample["question"]) + instruction = random.choice(self.instruction_pool).format(self.text_processor(sample["question"])) + + # instruction = random.choice(self.instruction_pool).format(self.text_processor(sample["question"])) + return { + "image": image, + "instruction_input": instruction, + "answer": sample['answer'], + "qa_id": sample['qa_id'], + } + + +class OCRVQADataset(Dataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_path): + """ + vis_root (string): Root directory of images (e.g. 
coco/images/) + ann_root (string): directory to store the annotation file + """ + self.vis_root = vis_root + + self.vis_processor = vis_processor + self.text_processor = text_processor + self.data = self.create_data(ann_path) + + self.instruction_pool =[ + "Q: {} A: ", + ] + + def create_data(self, ann_path): + processed_data = [] + with open(ann_path, 'r') as f: + data = json.load(f) + for k in data.keys(): + if data[k]['split'] != 1: continue # 1 for training, 2 for validation, 3 for test + ext = os.path.splitext(data[k]['imageURL'])[1] + imageFile = k + ext + assert len(data[k]['questions']) == len(data[k]['answers']) + for q, a in zip(data[k]['questions'], data[k]['answers']): + processed_data.append( + {'question': q, + 'answer': a, + 'image_path': imageFile, + 'image_id': k, + 'title': data[k]['title'], + 'genre': data[k]['genre'], + } + ) + print("ocr vqa", len(processed_data)) + return processed_data + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + sample = self.data[index] + image = Image.open(os.path.join(self.vis_root, sample['image_path'])).convert("RGB") + image = self.vis_processor(image) + question = self.text_processor(sample["question"]) + answer = self.text_processor(sample["answer"]) + + instruction = random.choice(self.instruction_pool).format(question) + instruction = " {} ".format(instruction) + return { + "image": image, + "instruction_input": instruction, + "answer": answer, + "image_id": sample['image_id'] + } + + + + + +class TextOCRDataset(Dataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_path): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + self.vis_root = vis_root + + self.vis_processor = vis_processor + self.text_processor = text_processor + self.data = self.create_data(ann_path) + + self.instruction_pool = [ + " [OCR] {}" + ] + + def create_data(self, ann_path): + processed_data = [] + with open(ann_path, 'r') as f: + data = json.load(f) + for k in data["anns"].keys(): + # ext = os.path.splitext(data[k]['imageURL'])[1] + imageFile = data["anns"][k]["image_id"]+".jpg" + bbox = data["anns"][k]["bbox"] + text = data["anns"][k]["utf8_string"] + # assert len(data[k]['questions']) == len(data[k]['answers']) + # for q, a in zip(data[k]['questions'], data[k]['answers']): + + processed_data.append( + {'bbox': bbox, + 'answer': text, + 'image_path': imageFile, + 'image_id': k, + } + ) + + return processed_data + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + sample = self.data[index] + image = Image.open(os.path.join(self.vis_root, sample['image_path'])).convert("RGB") + width, height = image.size + image = self.vis_processor(image) + + new_bbox ="" + image_size = 100 + bbox = sample['bbox'] + for index in range(len(bbox)): + + x1 = int(bbox[0]/width*image_size) + y1 = int(bbox[1]/height*image_size) + x2 = x1 + int(bbox[2]/width*image_size) + y2 = y1 + int(bbox[3]/height*image_size) + assert x1>=0 and x1<=image_size + assert x2>=0 and x2<=image_size + assert y1>=0 and y1<=image_size + assert y2>=0 and y2<=image_size + + new_bbox = " <"+str(x1)+"><"+str(y1)+"><"+str(x2)+"><"+str(y2)+">" + + instruction = random.choice(self.instruction_pool).format(new_bbox) + + return { + "image": image, + "instruction_input": instruction, + "answer": sample['answer'], + "image_id": sample['image_id'] + } + + + +class PlotVQADataset(Dataset): + def __init__(self, vis_processor, text_processor, vis_root, 
ann_path): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + self.vis_root = vis_root + + self.vis_processor = vis_processor + self.text_processor = text_processor + self.data = self.create_data(ann_path) + + self.instruction_pool = [ + 'Q: {} A:', + ] + + def create_data(self, ann_path): + processed_data = [] + with open(ann_path, 'r') as f: + data = json.load(f) + for da in data["qa_pairs"]: + # ext = os.path.splitext(data[k]['imageURL'])[1] + + imageFile = str(da["image_index"])+".png" + question = da["question_string"] + answer = str(da["answer"]) + # assert len(data[k]['questions']) == len(data[k]['answers']) + # for q, a in zip(data[k]['questions'], data[k]['answers']): + + processed_data.append( + {'question': question, + 'answer': answer, + 'image_path': imageFile, + 'image_id': str(da["image_index"]), + } + ) + + return processed_data + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + sample = self.data[index] + image = Image.open(os.path.join(self.vis_root, sample['image_path'])).convert("RGB") + # width, height = image.size + image = self.vis_processor(image) + + + # image_shape = image.shape + instruction = " {} ".format(sample["question"]) + + instruction = random.choice(self.instruction_pool).format(instruction) + + answer = sample["answer"] + + + return { + "image": image, + "instruction_input": instruction, + "answer": answer, + "image_id": sample['image_id'] + } + diff --git a/minigpt4/datasets/datasets/gqa_datasets.py b/minigpt4/datasets/datasets/gqa_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..610d6a90b49772c330165a80bb2827bd1a1c9d33 --- /dev/null +++ b/minigpt4/datasets/datasets/gqa_datasets.py @@ -0,0 +1,130 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. 
+ SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +import json + +from PIL import Image + +from minigpt4.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset + +from collections import OrderedDict +import random + +class __DisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + + return OrderedDict( + { + "file": ann["image"], + "question": ann["question"], + "question_id": ann["question_id"], + "answers": "; ".join(ann["answer"]), + "image": sample["image"], + } + ) + + +class GQADataset(VQADataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + self.instruction_pool =[ + "[vqa] {}", + "[vqa] Based on the image, respond to this question with a short answer: {}" + ] + + def __getitem__(self, index): + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + question = self.text_processor(ann["question"]) + + instruction = random.choice(self.instruction_pool).format(question) + instruction = " {} ".format(instruction) + + answers = self.text_processor(ann["answer"]) + if "unk" in answers: + print("gqa",answers) + + # print(answers) + + return { + "image": image, + "instruction_input": instruction, + "answer": answers, + # "weights": weights, + } + + +class GQAEvalDataset(VQAEvalDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. gqa/images/) + ann_root (string): directory to store the annotation file + """ + + self.instruction_pool = [ +# '{}', +# 'Question: {}', +# '{} A short answer to the question is', +# 'Q: {} A:', + # '[vqa] Question: {} Short answer:', + "[vqa] Based on the image, respond to this question with a short answer: {}" +# 'Given the image, answer the following question with no more than three words. {}', +# 'Based on the image, respond to this question with a short answer: {}.', +# 'Use the provided image to answer the question: {} Provide your answer as short as possible.', +# 'What is the answer to the following question? "{}"', +# 'The question "{}" can be answered using the image. 
A short answer is' + ] + + self.vis_root = vis_root + + self.annotation = json.load(open(ann_paths[0])) + + ## TODO: support inference method == 'ranking' + answer_list_path = ann_paths[1] if len(ann_paths) > 1 else '' + if os.path.exists(answer_list_path): + self.answer_list = json.load(open(answer_list_path)) + else: + self.answer_list = None + + self.vis_processor = vis_processor + self.text_processor = text_processor + + self._add_instance_ids() + + def __getitem__(self, index): + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + question = self.text_processor(ann["question"]) + + instruction = random.choice(self.instruction_pool).format(question) + instruction = " {} ".format(instruction) + + if "answer" in ann: + # answer is a string + answer = ann["answer"] + else: + answer = None + + return { + "image": image, + "text_input": question, + "answer": answer, + 'image_path': image_path, + "instruction_input": instruction, + "question_id": ann["question_id"], + "instance_id": ann["instance_id"], + } diff --git a/minigpt4/datasets/datasets/grounded_caption_reasoning.py b/minigpt4/datasets/datasets/grounded_caption_reasoning.py new file mode 100644 index 0000000000000000000000000000000000000000..0ee511b6e78d85d3783014d9bdfbcb9e02397d04 --- /dev/null +++ b/minigpt4/datasets/datasets/grounded_caption_reasoning.py @@ -0,0 +1,92 @@ +import os +import json +import pickle +import random +import time +import itertools + +import numpy as np +from PIL import Image +import skimage.io as io +import matplotlib.pyplot as plt +from matplotlib.collections import PatchCollection +from matplotlib.patches import Polygon, Rectangle +from torch.utils.data import Dataset +import webdataset as wds + +from minigpt4.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset + +from collections import OrderedDict + + +class __DisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + + return OrderedDict( + { + "file": ann["image"], + "question": ann["question"], + "question_id": ann["question_id"], + "answers": "; ".join(ann["answer"]), + "image": sample["image"], + } + ) + + +class GroundedCaptionReasonDataset(VQADataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + self.instruction_pool =[ + "[vqa] {}" + ] + + exist_annotation = [] + for ann in self.annotation: + image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1]) + if os.path.exists(image_path): + exist_annotation.append(ann) + self.annotation = exist_annotation + + + def get_data(self, index): + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + question = self.text_processor(ann["question"]) + question_id = ann["question_id"] + + answer_weight = {} + for answer in ann["answer"]: + if answer in answer_weight.keys(): + answer_weight[answer] += 1 / len(ann["answer"]) + else: + answer_weight[answer] = 1 / len(ann["answer"]) + + answers = list(answer_weight.keys()) + weights = list(answer_weight.values()) + + answer = random.choices(answers, weights=weights, k=1)[0] # random sample an answer according to weights + + return { + "image": image, + "question": question, + "question_id": question_id, + "answer": 
answer, + } + + def __getitem__(self, index): + data = self.get_data(index) + instruction = random.choice(self.instruction_pool).format(data['question']) + instruction = " {}".format(instruction) + + return { + "image": data['image'], + "question_id": data["question_id"], + "instruction_input": instruction, + "answer": data['answer'], + } diff --git a/minigpt4/datasets/datasets/grounded_detailed_image_caption_dataset.py b/minigpt4/datasets/datasets/grounded_detailed_image_caption_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2cac659ff203b40e973aa161e6d7ddab854607b9 --- /dev/null +++ b/minigpt4/datasets/datasets/grounded_detailed_image_caption_dataset.py @@ -0,0 +1,64 @@ +import os +import json +import pickle +import random +import time +import itertools + +import numpy as np +from PIL import Image +import skimage.io as io +import matplotlib.pyplot as plt +from matplotlib.collections import PatchCollection +from matplotlib.patches import Polygon, Rectangle +from torch.utils.data import Dataset +import webdataset as wds + +from minigpt4.datasets.datasets.base_dataset import BaseDataset +from minigpt4.datasets.datasets.caption_datasets import CaptionDataset + + +class GroundedDetailDataset(Dataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_path): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + self.vis_root = vis_root + + self.vis_processor = vis_processor + self.text_processor = text_processor + + self.instruction_pool = [ + '[grounding] please describe this image in details', + '[grounding] describe this image as detailed as possible', + '[grounding] summarize this image in details', + '[grounding] give a thorough description of what you see in this image', + ] + + with open(ann_path, 'r') as f: + self.ann = json.load(f) + + def __len__(self): + return len(self.ann) + + def __getitem__(self, index): + info = self.ann[index] + + image_file = 'COCO_train2014_{}.jpg'.format(info['image_id']) + image_path = os.path.join(self.vis_root, image_file) + image = Image.open(image_path).convert("RGB") + image = self.vis_processor(image) + + answer = info['grounded_caption'] + + instruction = random.choice(self.instruction_pool) + + instruction = " {} ".format(instruction) + + return { + "image": image, + "instruction_input": instruction, + "answer": answer, + "image_id": info['image_id'], + } diff --git a/minigpt4/datasets/datasets/laion_dataset.py b/minigpt4/datasets/datasets/laion_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2ee07bb3b9da76eec7329e7af1268c7d0a87216a --- /dev/null +++ b/minigpt4/datasets/datasets/laion_dataset.py @@ -0,0 +1,57 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. 
+ SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" +import random + +import webdataset as wds +from minigpt4.datasets.datasets.base_dataset import BaseDataset + + +class LaionDataset(BaseDataset): + def __init__(self, vis_processor, text_processor, location): + super().__init__(vis_processor=vis_processor, text_processor=text_processor) + self.instruction_pool = [ + 'Briefly describe this image.', + 'Provide a concise depiction of this image.', + 'Present a short description of this image.', + 'Summarize this image in a few words.', + 'A short image caption:', + 'A short image description:', + 'A photo of ', + 'An image that shows ', + 'Write a short description for the image. ', + 'Write a description for the photo.', + 'Provide a description of what is presented in the photo.', + 'Briefly describe the content of the image.', + 'Can you briefly explain what you see in the image?', + 'Could you use a few words to describe what you perceive in the photo?', + 'Please provide a short depiction of the picture.', + 'Using language, provide a short account of the image.', + 'Use a few words to illustrate what is happening in the picture.', + ] + + self.inner_dataset = wds.DataPipeline( + wds.ResampledShards(location), + wds.tarfile_to_samples(handler=wds.warn_and_continue), + wds.shuffle(1000, handler=wds.warn_and_continue), + wds.decode("pilrgb", handler=wds.warn_and_continue), + wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), + wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue), + wds.map(self.to_dict, handler=wds.warn_and_continue), + ) + + def to_dict(self, sample): + instruction = random.choice(self.instruction_pool) + + # instruction = "###Human: {}###Assistant: ".format(instruction) + instruction = " [caption] {} ".format(instruction) + + return { + "image": sample[0], + "instruction_input": instruction, + "answer": self.text_processor(sample[1]["caption"]), + } + diff --git a/minigpt4/datasets/datasets/llava_dataset.py b/minigpt4/datasets/datasets/llava_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0bd728b439beb4b812d0157e2941b1d418faabb2 --- /dev/null +++ b/minigpt4/datasets/datasets/llava_dataset.py @@ -0,0 +1,158 @@ +import os +import json +import pickle +import random +import time +import itertools + +import numpy as np +from PIL import Image +import skimage.io as io +import matplotlib.pyplot as plt +from matplotlib.collections import PatchCollection +from matplotlib.patches import Polygon, Rectangle +from torch.utils.data import Dataset +import webdataset as wds + +from minigpt4.datasets.datasets.base_dataset import BaseDataset +from minigpt4.datasets.datasets.caption_datasets import CaptionDataset + + +class LlavaDetailDataset(Dataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_path): + """ + vis_root (string): Root directory of images (e.g. 
coco/images/) + ann_root (string): directory to store the annotation file + """ + self.vis_root = vis_root + + self.vis_processor = vis_processor + self.text_processor = text_processor + + with open(ann_path, 'r') as f: + self.ann = json.load(f) + + def __len__(self): + return len(self.ann) + + def __getitem__(self, index): + info = self.ann[index] + + image_file = 'COCO_train2014_{}.jpg'.format(info['id']) + image_path = os.path.join(self.vis_root, image_file) + image = Image.open(image_path).convert("RGB") + image = self.vis_processor(image) + + answer = info['conversations'][1]['value'] + instruction = info['conversations'][0]['value'].replace('', '').replace('\n', '').strip() + + instruction = ' {} '.format(self.text_processor(instruction)) + + return { + "image": image, + "instruction_input": instruction, + "answer": answer, + "image_id": info['id'], + } + +class LlavaReasonDataset(Dataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_path): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + self.vis_root = vis_root + + self.vis_processor = vis_processor + self.text_processor = text_processor + + with open(ann_path, 'r') as f: + self.ann = json.load(f) + + def __len__(self): + return len(self.ann) + + def __getitem__(self, index): + info = self.ann[index] + + image_file = 'COCO_train2014_{}.jpg'.format(info['id']) + image_path = os.path.join(self.vis_root, image_file) + image = Image.open(image_path).convert("RGB") + image = self.vis_processor(image) + + answer = info['conversations'][1]['value'] + instruction = info['conversations'][0]['value'].replace('', '').replace('\n', '').strip() + + instruction = ' {} '.format(self.text_processor(instruction)) + + # instruction = ' {} '.format(self.text_processor(instruction)) + # answer = self.text_processor(answer) + + return { + "image": image, + "instruction_input": instruction, + "answer": answer, + "image_id": info['id'], + } + + + + + +class LlavaConversationDataset(Dataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_path, template=['[INST]', '[\INST]']): + """ + vis_root (string): Root directory of images (e.g. 
coco/images/) + ann_root (string): directory to store the annotation file + """ + self.vis_root = vis_root + + self.vis_processor = vis_processor + self.text_processor = text_processor + + self.human_tag = r'[INST]' + self.assistant_tag = r"[\INST]" + + + with open(ann_path, 'r') as f: + self.ann = json.load(f) + + self.connect_sym = "!@#" + + def __len__(self): + return len(self.ann) + + def __getitem__(self, index): + info = self.ann[index] + + image_file = 'COCO_train2014_{}.jpg'.format(info['id']) + image_path = os.path.join(self.vis_root, image_file) + image = Image.open(image_path).convert("RGB") + image = self.vis_processor(image) + + first_instruction = info['conversations'][0]['value'].replace('', '').replace('\n', '').strip() + first_instruction = ' {} '.format(first_instruction) + + questions = [first_instruction] + answers = [] + + for i, item in enumerate(info["conversations"][1:]): + if i % 2 ==0: # assistant + assistant_answer = item["value"] + answers.append(assistant_answer) + else: + human_instruction = item["value"] + questions.append(human_instruction) + + questions = self.connect_sym.join(questions) + # questions = questions.replace("\\\\","\\") + answers = self.connect_sym.join(answers) + + + return { + "image": image, + "conv_q": questions, + 'conv_a': answers, + "image_id": info['id'], + "connect_sym": self.connect_sym + } \ No newline at end of file diff --git a/minigpt4/datasets/datasets/locna_dataset.py b/minigpt4/datasets/datasets/locna_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..07febaa01f95618b622e763e0e89099eb9ac136e --- /dev/null +++ b/minigpt4/datasets/datasets/locna_dataset.py @@ -0,0 +1,68 @@ +import os +import json +import pickle +import random +import time +import itertools + +import numpy as np +from PIL import Image +import skimage.io as io +import matplotlib.pyplot as plt +from matplotlib.collections import PatchCollection +from matplotlib.patches import Polygon, Rectangle +from torch.utils.data import Dataset +import webdataset as wds + +from minigpt4.datasets.datasets.base_dataset import BaseDataset +from minigpt4.datasets.datasets.caption_datasets import CaptionDataset + + +class LocNaCOCODataset(Dataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths, min_len=60): + self.vis_root = vis_root + self.vis_processor = vis_processor + self.text_processor = text_processor + self.min_len = min_len + self.data = self.create_data(ann_paths) + + self.instruction_pool = [ + ' Describe this image in detail.', + ' Take a look at this image and describe what you notice.', + ' Please provide a detailed description of the picture.', + ' Could you describe the contents of this image for me?' 
+ ] + + def create_data(self, ann_paths): + raw_data = [] + for ann_path in ann_paths: + with open(ann_path, 'r') as f: + raw_data.extend([json.loads(line) for line in f]) + + data = [] + for d in raw_data: + if len(d['caption'].split(' ')) < 60: continue + data.append( + {'caption': d['caption'], + 'image_path': '{:012d}.jpg'.format(int(d['image_id'])) + } + ) + return data + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + sample = self.data[index] + image = Image.open(os.path.join(self.vis_root, sample['image_path'])).convert("RGB") + image = self.vis_processor(image) + instruction = random.choice(self.instruction_pool) + instruction = "###Human: {} ###Assistant: ".format(instruction) + + return { + "image": image, + "instruction_input": instruction, + "answer": sample['caption'], + } + + diff --git a/minigpt4/datasets/datasets/lvis_dataset.py b/minigpt4/datasets/datasets/lvis_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6d9d3b3c92cda29c400ca50816e2d4d1d61b8439 --- /dev/null +++ b/minigpt4/datasets/datasets/lvis_dataset.py @@ -0,0 +1,202 @@ +import os +import json +import pickle +import random +import time +import itertools + +import numpy as np +from PIL import Image +import skimage.io as io +import matplotlib.pyplot as plt +from matplotlib.collections import PatchCollection +from matplotlib.patches import Polygon, Rectangle +from torch.utils.data import Dataset + +from minigpt4.datasets.datasets.base_dataset import BaseDataset +from minigpt4.datasets.datasets.caption_datasets import CaptionDataset + + +def sample_object_bbox(objects, bbox): + + + + zipped_list = list(zip(objects, bbox)) + + # Shuffle the zipped list + random.shuffle(zipped_list) + + # Generate the new string with interleaved format + # interleaved_list = str([{'{},{}'.format(obj, str(bbox).replace("[","").replace("]","") )} for obj, bbox in zipped_list]) + + # print("objects", objects) + # print("bbox",bbox) + + interleaved_list = str([{'{},{}'.format(obj, bbox.strip())} for obj, bbox in zipped_list]).replace("'","").replace("[","").replace("]","") + + # interleaved_list = " "+interleaved_list + # print(interleaved_list) + return interleaved_list + +def bbox_to_object(objects, bbox): + + index_sample = random.sample(range(len(objects)),1)[0] + + sample_object = str(objects[index_sample]) + sample_bbox = bbox[index_sample] + # sample_center_point = center_point[index_sample] + + sample_bbox = r"{"+str(sample_bbox) + "}" + return sample_bbox, sample_object + +def object_to_bbox(objects, bbox, center_point): + index_sample = random.sample(range(len(objects)),1)[0] + + sample_object = objects[index_sample] + sample_bbox = bbox[index_sample] + sample_center_point = center_point[index_sample] + + instruction = "what is object and the bounding box in the center coordinate of "+str(sample_center_point)+"? 
" + answer = "{"+str(sample_object)+","+str(sample_bbox)+"}" + + + + return instruction, answer + + +class LVISBBOXDataset(BaseDataset): + def __init__(self, vis_processor, text_processor, location): + super().__init__(vis_processor=vis_processor, text_processor=text_processor) + + self.inner_dataset = wds.DataPipeline( + wds.ResampledShards(location), + wds.tarfile_to_samples(handler=wds.warn_and_continue), + wds.shuffle(1000, handler=wds.warn_and_continue), + wds.decode("pilrgb", handler=wds.warn_and_continue), + wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), + wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue), + wds.map(self.to_dict, handler=wds.warn_and_continue), + ) + + def to_dict(self, sample): + objects = sample[1]["objects"] + boxes = sample[1]["bbox"] + + + new_bboxes = [] + + image_size = sample[0].shape[1] + image_size = 100 + for index in range(len(boxes)): + box = boxes[index] + x1 = int(box[0]*image_size) + y1 = int(box[1]*image_size) + x2 = x1 + int(box[2]*image_size) + y2 = y1 + int(box[3]*image_size) + assert x1>=0 and x1<=image_size + assert x2>=0 and x2<=image_size + assert y1>=0 and y1<=image_size + assert y2>=0 and y2<=image_size + + new_bbox = " <"+str(x1)+"><"+str(y1)+"><"+str(x2)+"><"+str(y2)+">" + # new_bbox = " <"+str(x1)+"><"+str(y1)+"><"+str(x2)+"><"+str(y2)+">" + new_bboxes.append(new_bbox) + + instruction = r"Given an image, identify the objects and their bounding boxes in the format of {object,x1 y1 x2 y2}. " + instruction = " {}".format(self.text_processor(instruction)) + + answer = sample_object_bbox(objects, new_bboxes) + + # print("instruction",instruction) + # print("answer", answer) + + return { + "image": sample[0], + "instruction_input": instruction, + "answer": self.text_processor(answer), + "data_type": "bbox", + "question_split": True + } + + +class LVISBboxToObjectDataset(BaseDataset): + def __init__(self, vis_processor, text_processor, location): + super().__init__(vis_processor=vis_processor, text_processor=text_processor) + + + self.inner_dataset = wds.DataPipeline( + wds.ResampledShards(location), + wds.tarfile_to_samples(handler=wds.warn_and_continue), + wds.shuffle(1000, handler=wds.warn_and_continue), + wds.decode("pilrgb", handler=wds.warn_and_continue), + wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), + wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue), + wds.map(self.to_dict, handler=wds.warn_and_continue), + ) + # self.instruction_pool = [ + # "###Human: what object is in this bounding box location {}###Assistant: ", + # "###Human: what object is in this location {}###Assistant: ", + # "###Human: identify the object present at this location {}###Assistant: ", + # "###Human: what is it in bounding box location{}###Assistant: ", + # "###Human: describe this object in {} ###Assistant: ", + # "###Human: this {} is ###Assistant: ", + # "###Human: the object in {} is ###Assistant: ", + # "###Human: please tell me what is inside the bounding box position {} ###Assistant: ", + # "###Human: what can you find in the bounding box area at position {}? 
###Assistant: ", + # "###Human: what is the object occupying this bbox area {}###Assistant: ", + # "###Human: could you identify the content within the bounding box located at {}###Assistant: ", + # ] + + + self.instruction_pool = [ + "what object is in this bounding box location {} ", + "what object is in this location {} ", + "identify the object present at this location {} ", + "what is it in bounding box location{} ", + "describe this object in {} ", + "this {} is ", + "the object in {} is ", + "please tell me what is inside the bounding box position {} ", + "what can you find in the bounding box area at position {}? ", + "what is the object occupying this area {} ", + "could you identify the content within the bounding box located at {} ", + ] + def to_dict(self, sample): + + objects = sample[1]["objects"] + boxes = sample[1]["bbox"] + + new_bboxes = [] + + image_size = sample[0].shape[1] + image_size= 100 + for index in range(len(boxes)): + box = boxes[index] + x1 = int(box[0]*image_size) + y1 = int(box[1]*image_size) + x2 = x1 + int(box[2]*image_size) + y2 = y1 + int(box[3]*image_size) + assert x1>=0 and x1<=image_size + assert x2>=0 and x2<=image_size + assert y1>=0 and y1<=image_size + assert y2>=0 and y2<=image_size + + new_bbox = "<"+str(x1)+"><"+str(y1)+"><"+str(x2)+"><"+str(y2)+">" + new_bboxes.append(new_bbox) + + bbox, object = bbox_to_object(objects, new_bboxes) + instruction = random.choice(self.instruction_pool).format(bbox) + + # instruction = "###Human: {} ###Assistant: ".format(instruction) + + instruction = " {} ".format(instruction) + + return { + "image": sample[0], + "instruction_input": instruction, + "answer": self.text_processor(object), + "data_type": "bbox", + "question_split": True + } + + diff --git a/minigpt4/datasets/datasets/nav_dataset.py b/minigpt4/datasets/datasets/nav_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7ea27cf02ab0ac817211d285eea9c6cc1e67536b --- /dev/null +++ b/minigpt4/datasets/datasets/nav_dataset.py @@ -0,0 +1,69 @@ +import os +import json +import pickle +import math +import random +import glob +import torch +import time +import itertools + +from torch.utils.data import Dataset +from PIL import Image, ImageDraw + + +class NavR2RDataset(Dataset): + def __init__(self, vis_processor, text_processor, data_root): + """ + vis_root (string): Root directory of images (e.g. 
coco/images/) + ann_root (string): directory to store the annotation file + """ + self.data_root = data_root + self.data_ids = [subfolder.split('/')[-1] for subfolder in glob.glob(os.path.join(self.data_root, '*'))] + + self.vis_processor = vis_processor + self.text_processor = text_processor + self.connect_sym = "!@#" + + def __len__(self): + return len(self.data_ids) + + def preprocess(self, index): + data_id = self.data_ids[index] + with open(os.path.join(self.data_root, data_id, 'data.json'), 'r') as f: + meta_data = json.load(f) + + instructions = meta_data['instructions'] + actions = meta_data['action'] + + frames = [] + for i in range(meta_data['n_steps']): + image_path = os.path.join(self.data_root, data_id, '{}.jpg'.format(i)) + frames.append(self.vis_processor(Image.open(image_path).convert("RGB"))) + + return { + "frames": frames, + "instructions": instructions, + "actions": actions, + "data_id": data_id, + } + + def __getitem__(self, index): + data = self.preprocess(index) + instruction = random.choice(data['instructions']) + instruction = "Command: {}\n\n".format(instruction) + + obs = self.connect_sym.join([' A: ' for _ in data['actions']]) + obs = instruction + obs + act = self.connect_sym.join(data['actions']) + + stacked_frames = torch.stack(data["frames"][:-1], dim=0) + + return { + "image": stacked_frames, + "conv_q": obs, + "conv_a": act, + "connect_sym": self.connect_sym, + "data_id": data['data_id'], + } + \ No newline at end of file diff --git a/minigpt4/datasets/datasets/open_images.py b/minigpt4/datasets/datasets/open_images.py new file mode 100644 index 0000000000000000000000000000000000000000..6d603656d70cc26c04b9bffdc01103cfae4b6922 --- /dev/null +++ b/minigpt4/datasets/datasets/open_images.py @@ -0,0 +1,192 @@ +import os +from PIL import Image +import webdataset as wds +from minigpt4.datasets.datasets.base_dataset import BaseDataset +from minigpt4.datasets.datasets.caption_datasets import CaptionDataset +import json +import random +from webdataset import select + + +def sample_object_bbox(objects, bbox): + + + + zipped_list = list(zip(objects, bbox)) + + # Shuffle the zipped list + random.shuffle(zipped_list) + + # Generate the new string with interleaved format + # interleaved_list = str([{'{},{}'.format(obj, str(bbox).replace("[","").replace("]","") )} for obj, bbox in zipped_list]) + + # print("objects", objects) + # print("bbox",bbox) + + interleaved_list = str([{'{},{}'.format(obj, bbox.strip())} for obj, bbox in zipped_list]).replace("'","").replace("[","").replace("]","") + + # interleaved_list = " "+interleaved_list + # print(interleaved_list) + + return interleaved_list + +def bbox_to_object(objects, bbox): + + index_sample = random.sample(range(len(objects)),1)[0] + + sample_object = str(objects[index_sample]) + sample_bbox = bbox[index_sample] + # sample_center_point = center_point[index_sample] + + sample_bbox = r"{"+str(sample_bbox) + "}" + return sample_bbox, sample_object + +def object_to_bbox(objects, bbox, center_point): + index_sample = random.sample(range(len(objects)),1)[0] + + sample_object = objects[index_sample] + sample_bbox = bbox[index_sample] + sample_center_point = center_point[index_sample] + + instruction = "what is object and the bounding box in the center coordinate of "+str(sample_center_point)+"? 
" + answer = "{"+str(sample_object)+","+str(sample_bbox)+"}" + + + + return instruction, answer + + +class OpenImageDataset(BaseDataset): + def __init__(self, vis_processor, text_processor, location): + super().__init__(vis_processor=vis_processor, text_processor=text_processor) + + print("open Image dataset") + self.inner_dataset = wds.DataPipeline( + wds.ResampledShards(location), + wds.tarfile_to_samples(handler=wds.warn_and_continue), + wds.shuffle(1000, handler=wds.warn_and_continue), + wds.decode("pilrgb", handler=wds.warn_and_continue), + wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), + wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue), + wds.map(self.to_dict, handler=wds.warn_and_continue), + ) + + + def to_dict(self, sample): + + objects = sample[1]["objects"] + boxes = sample[1]["bbox"] + + new_bboxes = [] + + image_size = sample[0].shape[1] + image_size = 100 + for index in range(len(boxes)): + box = boxes[index] + x1 = int(box[0]*image_size) + y1 = int(box[1]*image_size) + x2 = x1 + int(box[2]*image_size) + y2 = y1 + int(box[3]*image_size) + assert x1>=0 and x1<=image_size + assert x2>=0 and x2<=image_size + assert y1>=0 and y1<=image_size + assert y2>=0 and y2<=image_size + + new_bbox = "<"+str(x1)+"><"+str(y1)+"><"+str(x2)+"><"+str(y2)+">" + new_bboxes.append(new_bbox) + + + instruction = r"Given an image, identify the objects and their bounding boxes in the format of {object,x1 y1 x2 y2}. " + instruction = " {} ".format( self.text_processor(instruction)) + + + answer = sample_object_bbox(objects, new_bboxes) + + return { + "image": sample[0], + "instruction_input": instruction, + "answer": self.text_processor(answer), + "data_type": "bbox", + "question_split": True + } + + + + + + +class OpenBboxToObjectDataset(BaseDataset): + def __init__(self, vis_processor, text_processor, location): + super().__init__(vis_processor=vis_processor, text_processor=text_processor) + + + self.inner_dataset = wds.DataPipeline( + wds.ResampledShards(location), + wds.tarfile_to_samples(handler=wds.warn_and_continue), + wds.shuffle(1000, handler=wds.warn_and_continue), + wds.decode("pilrgb", handler=wds.warn_and_continue), + wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), + wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue), + wds.map(self.to_dict, handler=wds.warn_and_continue), + ) + # self.instruction_pool = [ + # "###Human: what object is in this bounding box location {}###Assistant: ", + # "###Human: what object is in this location {}###Assistant: ", + # "###Human: identify the object present at this location {}###Assistant: ", + # "###Human: what is it in bounding box location{}###Assistant: ", + # "###Human: describe this object in {} ###Assistant: ", + # "###Human: this {} is ###Assistant: ", + # "###Human: the object in {} is ###Assistant: ", + # "###Human: please tell me what is inside the bounding box position {} ###Assistant: ", + # "###Human: what can you find in the bounding box area at position {}? 
###Assistant: ", + # "###Human: what is the object occupying this bbox area {}###Assistant: ", + # "###Human: could you identify the content within the bounding box located at {}###Assistant: ", + # ] + + self.instruction_pool = [ + " what object is in this bounding box location {} ", + " what object is in this location {} ", + " identify the object present at this location {} ", + " what is it in bounding box location{} ", + " describe this object in {} ", + " this {} is ", + " the object in {} is ", + " please tell me what is inside the bounding box position {} ", + " what can you find in the bounding box area at position {}? ", + " what is the object occupying this area {} ", + " could you identify the content within the bounding box located at {} ", + ] + def to_dict(self, sample): + + objects = sample[1]["objects"] + boxes = sample[1]["bbox"] + + new_bboxes = [] + + image_size = sample[0].shape[1] + image_size=100 + for index in range(len(boxes)): + box = boxes[index] + x1 = int(box[0]*image_size) + y1 = int(box[1]*image_size) + x2 = x1 + int(box[2]*image_size) + y2 = y1 + int(box[3]*image_size) + assert x1>=0 and x1<=image_size + assert x2>=0 and x2<=image_size + assert y1>=0 and y1<=image_size + assert y2>=0 and y2<=image_size + + new_bbox = "<"+str(x1)+"><"+str(y1)+"><"+str(x2)+"><"+str(y2)+">" + new_bboxes.append(new_bbox) + + bbox, object = bbox_to_object(objects, new_bboxes) + instruction = random.choice(self.instruction_pool).format(bbox) + return { + "image": sample[0], + "instruction_input": instruction, + "answer": self.text_processor(object), + "data_type": "bbox", + "question_split": True + } + + diff --git a/minigpt4/datasets/datasets/paint_dataset.py b/minigpt4/datasets/datasets/paint_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e842b7f486a148224f5235cc9f8e366dc7f4793e --- /dev/null +++ b/minigpt4/datasets/datasets/paint_dataset.py @@ -0,0 +1,600 @@ +import os +import json +import pickle +import math +import random +import glob + +import numpy as np +import torch +import time +import cv2 + +from torch.utils.data import Dataset +from PIL import Image, ImageDraw +import cv2 +from pycocotools.coco import COCO + +from minigpt4.datasets.datasets.base_dataset import BaseDataset + + +def pt_paint(strokes, num_steps=999): + # Create a black canvas + img = Image.new('RGB', (256, 256), color='black') + draw = ImageDraw.Draw(img) + max_steps = len(strokes) + num_steps = min(num_steps, max_steps) + + for i in range(0, num_steps): + stroke = strokes[i] + + x = stroke[0] + y = stroke[1] + w = stroke[2] + h = stroke[3] + theta = stroke[4] * 180 + rgb = tuple(int(val * 255) for val in stroke[5:8]) # Scale RGB values to 0-255 + + # Convert degrees to radians for rotation + angle_rad = theta * (3.141592653589793 / 180.0) + cos_val = math.cos(angle_rad) + sin_val = math.sin(angle_rad) + + # Calculate the coordinates of the rectangle vertices after rotation + x1 = x - w/2 + y1 = y - h/2 + x2 = x + w/2 + y2 = y - h/2 + x3 = x + w/2 + y3 = y + h/2 + x4 = x - w/2 + y4 = y + h/2 + + # Rotate the rectangle coordinates + x1_new = cos_val * (x1 - x) - sin_val * (y1 - y) + x + y1_new = sin_val * (x1 - x) + cos_val * (y1 - y) + y + x2_new = cos_val * (x2 - x) - sin_val * (y2 - y) + x + y2_new = sin_val * (x2 - x) + cos_val * (y2 - y) + y + x3_new = cos_val * (x3 - x) - sin_val * (y3 - y) + x + y3_new = sin_val * (x3 - x) + cos_val * (y3 - y) + y + x4_new = cos_val * (x4 - x) - sin_val * (y4 - y) + x + y4_new = sin_val * (x4 - x) + cos_val * (y4 - y) + y + + # Draw 
the rotated rectangle + draw.polygon([(x1_new, y1_new), (x2_new, y2_new), (x3_new, y3_new), (x4_new, y4_new)], fill=rgb) + + return img + + +def pt_stroke2str(single_stroke): + x, y, w, h, theta, r, g, b = single_stroke + theta = theta * 180 + r, g, b = r * 255, g * 255, b * 255 + param = [x, y, w, h, theta, r, g, b] + param = ','.join([str(int(i)) for i in param]) + + str_stroke = '({})'.format(param) + return str_stroke + + +class PaintPTCOCODataset(Dataset): + def __init__(self, vis_processor, text_processor, img_root, stroke_root, max_step=200): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + self.img_root = img_root + self.stroke_root = stroke_root + self.image_ids = [file.split('/')[-1].split('.')[0] + for file in glob.glob(os.path.join(self.stroke_root, '*.pkl'))] + self.max_step = max_step + self.vis_processor = vis_processor + self.text_processor = text_processor + + def __len__(self): + return len(self.image_ids) + + def preprocess(self, index, step=-1): + image_id = self.image_ids[index] + with open(os.path.join(self.stroke_root, '{}.pkl'.format(image_id)), "rb") as f: + strokes_dict = pickle.load(f) + + strokes = np.concatenate(strokes_dict['strokes'], axis=0) + if step < 0: + step = random.randint(0, min(len(strokes) - 1, self.max_step)) + canvas = pt_paint(strokes, num_steps=step) + next_stroke = strokes[step] + + image_file = '{}.jpg'.format(image_id) + image_path = os.path.join(self.img_root, image_file) + orig_image = Image.open(image_path).convert("RGB") + + return { + "orig_image": orig_image, + "canvas": canvas, + "next_stroke": pt_stroke2str(next_stroke), + "image_id": image_id, + } + + def __getitem__(self, index): + data = self.preprocess(index) + orig_image = self.vis_processor(data['orig_image']) + canvas = self.vis_processor(data['canvas']) + instruction = " Next Stroke: " + + return { + "image": torch.stack([orig_image, canvas], dim=0), + "instruction_input": instruction, + "answer": data['next_stroke'], + "image_id": data['image_id'], + "length": 2 + } + + +def normal(x, width): + return (int)(x * (width - 1) + 0.5) + + +def draw(f, canvas=None, width=128, res=100): + x0, y0, x1, y1, x2, y2, z0, z2, w0, w2, b, g, r = [float(i) for i in f] + x1 = x0 + (x2 - x0) * x1 + y1 = y0 + (y2 - y0) * y1 + x0 = normal(x0, width) + x1 = normal(x1, width) + x2 = normal(x2, width) + y0 = normal(y0, width) + y1 = normal(y1, width) + y2 = normal(y2, width) + z0 = (int)(1 + z0 * width // 4) + z2 = (int)(1 + z2 * width // 4) + if canvas is None: + canvas = np.zeros([width, width, 4]) + tmp = 1. 
/ res + for i in range(res): + t = i * tmp + x = (int)((1-t) * (1-t) * x0 + 2 * t * (1-t) * x1 + t * t * x2) + y = (int)((1-t) * (1-t) * y0 + 2 * t * (1-t) * y1 + t * t * y2) + z = (int)((1-t) * z0 + t * z2) + # w = (1-t) * w0 + t * w2 + w = 1 + + cv2.circle(canvas, (y, x), z, [w, r * w, g * w, b * w], -1) + + return canvas + + +def rl_decode(x, canvas, res=100): + stroke = [] + color_stroke = [] + for step in range(x.shape[1]): + stroke_canvas = np.zeros([canvas.shape[-1], canvas.shape[-1], 4], dtype=np.float32) # alpha, alpha * r, alpha * g, alpha * b + for idx in range(x.shape[0]): + stroke_canvas = draw(x[idx, step], canvas=stroke_canvas, width=canvas.shape[-1], res=res) + stroke_canvas = stroke_canvas.transpose(2, 0, 1) + stroke.append(stroke_canvas[:1]) + color_stroke.append(stroke_canvas[1:]) + + for i in range(len(stroke)): + canvas = canvas * (1 - stroke[i]) + color_stroke[i] + return canvas + + +def rel2abs(strokes, n_d=4): + abs_strokes = [] + for i, stroke in enumerate(strokes): + yi = i % n_d + xi = i // n_d + stroke = np.stack([ + stroke[:, 0] / n_d + xi / n_d, + stroke[:, 1] / n_d + yi / n_d, + stroke[:, 2] / n_d + xi / n_d, + stroke[:, 3] / n_d + yi / n_d, + stroke[:, 4] / n_d + xi / n_d, + stroke[:, 5] / n_d + yi / n_d, + stroke[:, 6] / n_d, + stroke[:, 7] / n_d, + stroke[:, 8], + stroke[:, 9], + stroke[:, 10], + stroke[:, 11], + stroke[:, 12], + ], axis=1) + abs_strokes.append(stroke) + abs_strokes = np.stack(abs_strokes) + return abs_strokes + + +def rl_paint(strokes_dict, step, width=256, single_stroke=False): + canvas = np.zeros([1, 3, width, width], dtype=np.float32) + + if_fine_strokes = [int(len(strokes.shape) > 2) for strokes in strokes_dict['strokes']] + if single_stroke: + n_steps = (len(if_fine_strokes) - sum(if_fine_strokes)) * 5 + 16 * 5 * sum(if_fine_strokes) + else: + n_steps = len(if_fine_strokes) + 4 * sum(if_fine_strokes) + + step = min(step, n_steps-1) + + for strokes in strokes_dict['strokes']: + + strokes = strokes.astype(np.float32) + if len(strokes.shape) < 3: # coarse stage. shape 5, 13 + if single_stroke: # 1 stroke per step + actions_list = [stroke[None, None] for stroke in strokes] + else: # 5 strokes per step + actions_list = [strokes[None]] + else: # fine stage. shape 16, 5, 13 + strokes = rel2abs(strokes) + + if single_stroke: # 1 stroke per step + strokes = strokes.transpose(1, 0, 2) + actions_list = [stroke[None, None] for step_strokes in strokes for stroke in step_strokes] + + else: # 16 strokes per step. each variable strokes contains 5 steps + actions_list = [strokes[:, i:i+1] for i in range(strokes.shape[1])] + + for actions in actions_list: + if step > 0: + canvas = rl_decode(actions, canvas, res=100) + step = step - 1 + else: + next_stroke = actions + return canvas, next_stroke + + raise StopIteration + + +def rl_stroke2str(action): + a, b, _ = action.shape + + if a == 1 and b == 5: # coarse step, contains 5 strokes + action = action[0] # 5 x 13 + tag = '[coarse]' + elif a == 16 and b == 1: # fine step. 
contains 16 strokes + action = action[:, 0] # 16 x 13 + tag = '[detail]' + elif a == 1 and b == 1: + action = action[0] + tag = '' + else: + raise ValueError + + strokes = [] + for i, stroke in enumerate(action): + stroke = [str(int(i * 255)) for i in stroke] + stroke = ",".join(stroke) + stroke = "{}({})".format(i, stroke) + strokes.append(stroke) + strokes = ';'.join(strokes) + strokes = tag + strokes + + return strokes + + +def rlo_stroke2str(action): + a, b, _ = action.shape + + if a == 1 and b == 5: # coarse step, contains 5 strokes + action = action[0] # 5 x 13 + tag = '[coarse]' + elif a == 16 and b == 1: # fine step. contains 16 strokes + action = action[:, 0] # 16 x 13 + tag = '[detail]' + elif a == 1 and b == 1: + action = action[0] + tag = '' + else: + raise ValueError + + strokes = [] + + for i, stroke in enumerate(action): + x0, y0, x1, y1, x2, y2, z0, z2, w0, w2, b, g, r = stroke + stroke = [x0, y0, x1, y1, x2, y2, z0, z2, b, g, r] # remove unused transparancy + stroke = [str(int(i * 255)) for i in stroke] + stroke = ",".join(stroke) + stroke = "{}({})".format(i, stroke) + strokes.append(stroke) + strokes = ';'.join(strokes) + strokes = tag + strokes + + return strokes + + +class PaintRLCOCODataset(Dataset): + def __init__(self, vis_processor, text_processor, img_root, stroke_root, single_stroke=False, max_step=50): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + self.img_root = img_root + self.stroke_root = stroke_root + self.image_ids = [file.split('/')[-1].split('.')[0] + for file in glob.glob(os.path.join(self.stroke_root, '*.pkl'))] + self.max_step = max_step + self.vis_processor = vis_processor + self.text_processor = text_processor + self.single_stroke=single_stroke + self.width = 256 + + def __len__(self): + return len(self.image_ids) + + def preprocess(self, index, step=-1): + image_id = self.image_ids[index] + image_file = '{}.jpg'.format(image_id) + image_path = os.path.join(self.img_root, image_file) + orig_image = Image.open(image_path).convert("RGB") + + with open(os.path.join(self.stroke_root, '{}.pkl'.format(image_id)), "rb") as f: + strokes_dict = pickle.load(f) + + if_fine_strokes = [int(len(strokes.shape) > 2) for strokes in strokes_dict['strokes']] + if self.single_stroke: + n_steps = (len(if_fine_strokes) - sum(if_fine_strokes)) * 5 + 16 * 5 * sum(if_fine_strokes) + else: + n_steps = len(if_fine_strokes) + 4 * sum(if_fine_strokes) + + if step < 0: + step = random.randint(0, min(n_steps - 1, self.max_step)) + + canvas, next_stroke = rl_paint(strokes_dict, step, width=self.width, single_stroke=self.single_stroke) + canvas = Image.fromarray((canvas[0].transpose(1, 2, 0) * 255).astype(np.uint8)) + + return { + "orig_image": orig_image, + "canvas": canvas, + "next_stroke": rl_stroke2str(next_stroke), + "image_id": image_id, + } + + def __getitem__(self, index): + data = self.preprocess(index) + orig_image = self.vis_processor(data['orig_image']) + canvas = self.vis_processor(data['canvas']) + instruction = " Action: " + + return { + "image": torch.stack([orig_image, canvas], dim=0), + "instruction_input": instruction, + "answer": data['next_stroke'], + "image_id": data['image_id'], + "length": 2 + } + + +class PaintLanRLOpaqueCOCODataset(Dataset): + def __init__(self, vis_processor, text_processor, img_root, stroke_root, ann_path, single_stroke=False, max_step=50): + """ + vis_root (string): Root directory of images (e.g. 
coco/images/) + ann_root (string): directory to store the annotation file + """ + self.img_root = img_root + self.stroke_root = stroke_root + self.image_ids = [file.split('/')[-1].split('.')[0] + for file in glob.glob(os.path.join(self.stroke_root, '*.pkl'))] + self.max_step = max_step + self.vis_processor = vis_processor + self.text_processor = text_processor + self.single_stroke = single_stroke + + self.captions = {} + with open(ann_path, 'r') as f: + anns = json.load(f) + for ann in anns['annotations']: + if ann['image_id'] in self.captions: + self.captions[ann['image_id']].append(ann['caption']) + else: + self.captions[ann['image_id']] = [ann['caption']] + for idx in self.image_ids: + assert int(idx) in self.captions + + self.width = 256 + self.instruction = "Task: {}\nCanvas: Action: " + + def __len__(self): + return len(self.image_ids) + + def preprocess(self, index, step=-1): + image_id = self.image_ids[index] + image_file = '{}.jpg'.format(image_id) + image_path = os.path.join(self.img_root, image_file) + orig_image = Image.open(image_path).convert("RGB") + captions = self.captions[int(image_id)] + + with open(os.path.join(self.stroke_root, '{}.pkl'.format(image_id)), "rb") as f: + strokes_dict = pickle.load(f) + + if_fine_strokes = [int(len(strokes.shape) > 2) for strokes in strokes_dict['strokes']] + if self.single_stroke: + n_steps = (len(if_fine_strokes) - sum(if_fine_strokes)) * 5 + 16 * 5 * sum(if_fine_strokes) + else: + n_steps = len(if_fine_strokes) + 4 * sum(if_fine_strokes) + + if step < 0: + step = random.randint(0, min(n_steps - 1, self.max_step)) + + canvas, next_stroke = rl_paint(strokes_dict, step, width=self.width, single_stroke=self.single_stroke) + canvas = Image.fromarray((canvas[0].transpose(1, 2, 0) * 255).astype(np.uint8)) + + return { + "orig_image": orig_image, + "captions": captions, + "canvas": canvas, + "next_stroke": rlo_stroke2str(next_stroke), + "image_id": image_id, + } + + def __getitem__(self, index): + data = self.preprocess(index) + canvas = self.vis_processor(data['canvas']) + instruction = self.instruction.format(random.choice(data['captions'])) + + return { + "image": canvas, + "instruction_input": instruction, + "answer": data['next_stroke'], + "image_id": data['image_id'], + } + + +class PaintPixelCOCODataset(BaseDataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths, res): + """ + vis_root (string): Root directory of images (e.g. 
coco/images/) + ann_root (string): directory to store the annotation file + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + self.res = res + self.img_ids = {} + n = 0 + + self.filter_anntation = [] + + for ann in self.annotation: + if "train" in ann["image"]: + self.filter_anntation.append(ann) + self.annotation = self.filter_anntation + + for ann in self.annotation: + img_id = ann["image_id"] + if img_id not in self.img_ids.keys(): + self.img_ids[img_id] = n + n += 1 + + def __getitem__(self, index): + ann = self.annotation[index] + + img_file = ann["image"].split("/")[-1] + image_path = os.path.join(self.vis_root, img_file) + image = Image.open(image_path).convert("RGB") + + pixelized = np.array(image.resize([self.res, self.res])) + + image = self.vis_processor(image) + + loc_y = random.randint(0, self.res - 1) + loc_x = random.randint(0, self.res - 1) + rgb = pixelized[loc_y, loc_x] + + instruction = " [reconstruct] loc: [{},{}] rgb: ".format(loc_y, loc_x) + answer = '[{},{},{}]'.format(rgb[0], rgb[1], rgb[2]) + + return { + "image": image, + "answer": answer, + "instruction_input": instruction, + } + + +class SegReferCOCODataset(Dataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_path, res, dataset='refcoco', splitBy='unc'): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_path (string): directory to store the annotation file + """ + self.vis_root = vis_root + self.ann_path = ann_path + self.splitBy = splitBy + self.res = res + + self.vis_processor = vis_processor + self.text_processor = text_processor + + self.ann_dir = os.path.join(ann_path, dataset) + ref_file = os.path.join(self.ann_dir, 'refs(' + splitBy + ').p') + + self.data = {} + with open(ref_file, 'rb') as f: + data_refs = pickle.load(f) + data_refs = [ref for ref in data_refs if ref['split'] == 'train'] # only use train split + + for ref in data_refs: + if ref['image_id'] in self.data: + self.data[ref['image_id']].append(ref) + else: + self.data[ref['image_id']] = [ref] + self.img_id_list = list(self.data.keys()) + + # load annotations from data/dataset/instances.json + instances_file = os.path.join(self.ann_dir, 'instances.json') + self.coco = COCO(instances_file) + + def __len__(self): + return len(self.img_id_list) + + def prepare_data(self, index): + image_id = self.img_id_list[index] + raw_anns = self.data[image_id] + anns = [] + for ann in raw_anns: + refers = [sentence['sent'] for sentence in ann['sentences']] + ann_id = ann['ann_id'] + annotations = self.coco.loadAnns([ann_id]) + mask = Image.fromarray(self.coco.annToMask(annotations[0])) + anns.append({'refers': refers, 'mask': mask}) + + img_data = self.coco.loadImgs(image_id)[0] + image_path = os.path.join(self.vis_root, img_data['file_name']) + image = Image.open(image_path).convert("RGB") + + return { + 'image': image, + 'anns': anns, + } + + def __getitem__(self, index): + data = self.prepare_data(index) + image = self.vis_processor(data['image']) + all_masks = [np.array(ann['mask'].resize([self.res, self.res], 0)) for ann in data['anns']] + ann_id = random.randint(0, len(data['anns']) - 1) + + selected_ann = data['anns'][ann_id] + selected_refer = random.choice(selected_ann['refers']) + pixelized_mask = all_masks[ann_id] + all_mask = sum(all_masks) + + pixelized_mask[pixelized_mask != 0] = 1 + all_mask[all_mask != 0] = 1 + + has_other_obj = bool((all_mask != pixelized_mask).sum()) + + if (pixelized_mask == 0).sum() in [0, pixelized_mask.size]: # all black or all white + loc_y 
= random.randint(0, self.res - 1) + loc_x = random.randint(0, self.res - 1) + else: + if random.uniform(0, 1) < 0.4: # in 40% cases we sample object region + # object region + ys, xs = np.where(pixelized_mask != 0) + else: + # background + dice = random.uniform(0, 1) + if dice < 0.1: + # easy background points + ys, xs = np.where(pixelized_mask == 0) + elif has_other_obj and dice < 0.6: + # points on other unrelated objects + other_obj_mask = cv2.bitwise_xor(pixelized_mask, all_mask) + ys, xs = np.where(other_obj_mask != 0) + else: + # contour points around the object + dilate_mask = cv2.dilate(pixelized_mask, np.ones([self.res // 8, self.res // 8], dtype=np.uint8), + iterations=1) + contour_mask = cv2.bitwise_xor(pixelized_mask, dilate_mask) + ys, xs = np.where(contour_mask != 0) + + idx = random.randint(0, len(ys) - 1) + loc_y, loc_x = ys[idx], xs[idx] + + mask_value = pixelized_mask[loc_y, loc_x] + + instruction = " [segmentation] {} loc: [{},{}] mask: ".format( + selected_refer, loc_y, loc_x) + answer = str(mask_value) + + return { + "image": image, + "answer": answer, + "instruction_input": instruction, + } diff --git a/minigpt4/datasets/datasets/reasoning_dataset.py b/minigpt4/datasets/datasets/reasoning_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1ae48ffde60f778b20fdd67e3c413fac0ed00900 --- /dev/null +++ b/minigpt4/datasets/datasets/reasoning_dataset.py @@ -0,0 +1,64 @@ +import os +import json +import pickle +import random +import time +import itertools + +import numpy as np +from PIL import Image +import skimage.io as io +import matplotlib.pyplot as plt +from matplotlib.collections import PatchCollection +from matplotlib.patches import Polygon, Rectangle +from torch.utils.data import Dataset +import webdataset as wds + +from minigpt4.datasets.datasets.base_dataset import BaseDataset +from minigpt4.datasets.datasets.caption_datasets import CaptionDataset + + + +class ReasoningDataset(Dataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_path): + """ + vis_root (string): Root directory of images (e.g. 
coco/images/) + ann_root (string): directory to store the annotation file + """ + self.vis_root = vis_root + + self.vis_processor = vis_processor + self.text_processor = text_processor + self.data = json.load(open(ann_path)) + + # self.data = self.create_data(ann_path) + + # def create_data(self, ann_path): + # # processed_data = [] + # with open(ann_path, 'r') as f: + # data = json.load(f) + + # return processed_data + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + sample = self.data[index] + image_id = sample["image_id"]+".jpg" + question = sample["question"] + answer = sample["answer"] + + + image = Image.open(os.path.join(self.vis_root, image_id)).convert("RGB") + image = self.vis_processor(image) + + instruction = ' {} '.format(question) + + return { + "image": image, + "instruction_input": instruction, + "answer": answer + } + + diff --git a/minigpt4/datasets/datasets/text_caps.py b/minigpt4/datasets/datasets/text_caps.py new file mode 100644 index 0000000000000000000000000000000000000000..271f1b07388abd35f7fed16f20854c608869b817 --- /dev/null +++ b/minigpt4/datasets/datasets/text_caps.py @@ -0,0 +1,179 @@ +import os +import json +import pickle +import random +import time +import itertools + +import numpy as np +from PIL import Image +import skimage.io as io +import matplotlib.pyplot as plt +from matplotlib.collections import PatchCollection +from matplotlib.patches import Polygon, Rectangle +from torch.utils.data import Dataset +import webdataset as wds + +from minigpt4.datasets.datasets.base_dataset import BaseDataset +from minigpt4.datasets.datasets.caption_datasets import CaptionDataset + + + + + +class TextCapDataset(Dataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_path): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + self.vis_root = vis_root + + self.vis_processor = vis_processor + self.text_processor = text_processor + + self.instruction_pool = [ + 'Briefly describe this image.', + 'Provide a concise depiction of this image.', + 'Present a short description of this image.', + 'Summarize this image in a few words.', + 'A short image caption:', + 'A short image description:', + 'A photo of ', + 'An image that shows ', + 'Write a short description for the image. 
', + 'Write a description for the photo.', + 'Provide a description of what is presented in the photo.', + 'Briefly describe the content of the image.', + 'Can you briefly explain what you see in the image?', + 'Could you use a few words to describe what you perceive in the photo?', + 'Please provide a short depiction of the picture.', + 'Using language, provide a short account of the image.', + 'Use a few words to illustrate what is happening in the picture.', + ] + + with open(ann_path, 'r') as f: + self.ann = json.load(f) + + + def __len__(self): + return len(self.ann["data"]) + + + def __getitem__(self, index): + info = self.ann["data"][index] + + image_file = '{}.jpg'.format(info['image_id']) + + image_path = os.path.join(self.vis_root, image_file) + image = Image.open(image_path).convert("RGB") + # image_width,image_length = image.size + image = self.vis_processor(image) + + # ocr_info = self.ann[index]["data"] + caption = info["caption_str"] + caption = self.text_processor(caption) + + # instruction = random.choice(self.instruction_pool).format(word_bbox) + instruction = " [caption] {} ".format(random.choice(self.instruction_pool)) + return { + "image": image, + "instruction_input": instruction, + "answer": caption, + "data_type": "bbox", + "question_split": True + } + +class TextCapBboxToObjectDataset(Dataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_path): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + self.vis_root = vis_root + + self.vis_processor = vis_processor + self.text_processor = text_processor + + # self.instruction_pool = [ + # " What text does it show in {} ", + # " Extract the text from {} ", + # " What is the textual content in {} ", + # " Extract the textual information present in the {} ", + # " What is the text written within this defined region {}", + # " Transcribe the text located inside {}", + # " Can you read and extract the text from this specific area {}", + # ] + + self.instruction_pool = [ + " [OCR] {}" + ] + with open(ann_path, 'r') as f: + self.ann = json.load(f) + + self.new_ann = {"data":[]} + for da in self.ann["data"]: + if da["ocr_info"] !=[]: + ocr_info_filter = [] + for d in da["ocr_info"]: + if (d["bounding_box"]["width"]+d["bounding_box"]["top_left_x"])<=1.0 and (d["bounding_box"]["height"]+d["bounding_box"]["top_left_y"]) <=1.0 \ + and d["bounding_box"]["top_left_x"]>=0 and d["bounding_box"]["top_left_y"]>=0: + ocr_info_filter.append(d) + if ocr_info_filter !=[]: + da["ocr_info"]=ocr_info_filter + self.new_ann["data"].append(da) + self.ann = self.new_ann + + + def __len__(self): + return len(self.ann["data"]) + + + def __getitem__(self, index): + + info = self.ann["data"][index] + + + image_file = '{}.jpg'.format(info['image_id']) + + image_path = os.path.join(self.vis_root, image_file) + image = Image.open(image_path).convert("RGB") + # image_width,image_length = image.size + image = self.vis_processor(image) + + + + image_size = 100 + + ocr_info = info["ocr_info"] + + sampled_ocr = random.sample(ocr_info,1)[0] + + # print("sampled ocr", sampled_ocr) + + word_text = sampled_ocr["word"] + width = sampled_ocr["bounding_box"]["width"] + height = sampled_ocr["bounding_box"]["height"] + top_left_x = sampled_ocr["bounding_box"]["top_left_x"] + top_left_y = sampled_ocr["bounding_box"]["top_left_y"] + + x1 = int(top_left_x*image_size) + y1 = int(top_left_y*image_size) + x2 = x1 + int(width*image_size) + y2 = y1 + 
int(height*image_size) + assert x1>=0 and x1<=image_size + assert x2>=0 and x2<=image_size + assert y1>=0 and y1<=image_size + assert y2>=0 and y2<=image_size + + + word_bbox = "{<"+str(x1)+"><"+str(y1)+"><"+str(x2)+"><"+str(y2)+">}" + + instruction = random.choice(self.instruction_pool).format(word_bbox) + return { + "image": image, + "instruction_input": instruction, + "answer": word_text, + "data_type": "bbox", + "question_split": True + } \ No newline at end of file diff --git a/minigpt4/datasets/datasets/textvqa_datasets.py b/minigpt4/datasets/datasets/textvqa_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..cf15b83ac88a583502c62090c7b305f01538b001 --- /dev/null +++ b/minigpt4/datasets/datasets/textvqa_datasets.py @@ -0,0 +1,82 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import torch + +from PIL import Image + +from minigpt4.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset + +from collections import OrderedDict + + +# class textVQADataset(VQADataset): +# def __init__(self, vis_processor, text_processor, vis_root, ann_paths): +# super().__init__(vis_processor, text_processor, vis_root, ann_paths) + +# def collater(self, samples): +# image_list, question_list, answer_list, weight_list = [], [], [], [] + +# num_answers = [] + +# for sample in samples: +# image_list.append(sample["image"]) +# question_list.append(sample["text_input"]) + +# weight_list.extend(sample["weights"]) + +# answers = sample["answers"] + +# answer_list.extend(answers) +# num_answers.append(len(answers)) + +# return { +# "image": torch.stack(image_list, dim=0), +# "text_input": question_list, +# "answer": answer_list, +# "weight": torch.Tensor(weight_list), +# "n_answers": torch.LongTensor(num_answers), +# } + + + +from minigpt4.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset + +class textVQAEvalDataset(VQADataset): + def __init__(self, vis_processor, text_processor, vis_root=None, ann_paths=None): +# super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + from datasets import load_dataset + self.annotation = load_dataset("textvqa", split="validation") + + def __getitem__(self, index): + ann = self.annotation[index] + image = ann["image"].convert("RGB") + + image = self.vis_processor(image) + question = self.text_processor(ann["question"]) + + instruction = random.choice(self.instruction_pool).format(question) + instruction = " {} ".format(instruction) + print("instruction", instruction) + answers = ann["answers"] + + if "unk" in answers: + print(answers) + return { + "image": image, + "text_input": question, + "answer": answers, + # 'image_path': image_path, + "instruction_input": instruction, + "question_id": ann["question_id"], + "instance_id": ann["instance_id"], + } + + +dataset = textVQAEvalDataset(vis_processor, text_processor) +dataloader = torch.utils.data.DataLoader(dataset, batch_size=1) \ No newline at end of file diff --git a/minigpt4/datasets/datasets/unnatural_instruction.py b/minigpt4/datasets/datasets/unnatural_instruction.py new file mode 100644 index 0000000000000000000000000000000000000000..2abac562650f9a4669b0753e6e8506fb0e721566 --- /dev/null +++ b/minigpt4/datasets/datasets/unnatural_instruction.py @@ -0,0 +1,52 @@ +import os +import json +import pickle +import random +import time +import itertools + +import numpy as np +from 
PIL import Image +import skimage.io as io +import matplotlib.pyplot as plt +from matplotlib.collections import PatchCollection +from matplotlib.patches import Polygon, Rectangle +from torch.utils.data import Dataset +import webdataset as wds + +from minigpt4.datasets.datasets.base_dataset import BaseDataset +from minigpt4.datasets.datasets.caption_datasets import CaptionDataset + + +class UnnaturalDataset(Dataset): + def __init__(self, text_processor, ann_path): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + self.text_processor = text_processor + + with open(ann_path, 'r') as f: + self.ann = json.load(f) + + # with open(ann_path, 'r') as f: + # for data in f.readlines(): + # data = json.loads(data) + # self.ann.append(data) + + def __len__(self): + return len(self.ann) + + def __getitem__(self, index): + info = self.ann[index]["instances"][0] + instruction = info["instruction_with_input"] + constraints = info["constraints"] + answer = info["output"] + if constraints != None: + instruction = instruction+" "+constraints + + return { + # "image":None, + "instruction_input": instruction, + "answer": answer, + } diff --git a/minigpt4/datasets/datasets/vg_dataset.py b/minigpt4/datasets/datasets/vg_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3933fbf865df4dac1e635f381c69324ec9e26cb0 --- /dev/null +++ b/minigpt4/datasets/datasets/vg_dataset.py @@ -0,0 +1,96 @@ +import os +import json +import pickle +import random +import time +import itertools + +import numpy as np +from PIL import Image +from torch.utils.data import Dataset +from visual_genome import local + + +import threading + +# Global lock +lock = threading.Lock() + + +class ReferVisualGenomeDataset(Dataset): + def __init__(self, vis_processor, text_processor, data_dir): + """ + vis_root (string): Root directory of images (e.g. 
coco/images/) + ann_root (string): directory to store the annotation file + """ + self.data_dir = data_dir + + self.vis_processor = vis_processor + self.text_processor = text_processor + + all_regions = local.get_all_region_descriptions(self.data_dir) + all_regions = [region for regions in all_regions for region in regions] + + # follow OFA practice, only regions smaller than 16384 pixels are used for refer + self.regions = [region for region in all_regions if region.width * region.height < 16384] + + print('Visual Genome grounding', len(self.regions)) + + + self.instruction_pool = [ + "[refer] {}", + "[refer] give me the location of {}", + "[refer] where is {} ?", + "[refer] from this image, tell me the location of {}", + "[refer] the location of {} is", + "[refer] could you tell me the location for {} ?", + "[refer] where can I locate the {} ?", + ] + + + def __len__(self): + return len(self.regions) + + def preprocess(self, index): + region = self.regions[index] + image_file = region.image.url.split('/')[-2:] + image_path = os.path.join(self.data_dir, *image_file) + image = Image.open(image_path).convert("RGB") + image_orig_size = image.size + image = self.vis_processor(image) + image_new_size = [100,100] + + sample_sentence = region.phrase + refer_sentence = self.text_processor(sample_sentence) + + bbox = [region.x, region.y, region.width, region.height] + + bbox = [ + bbox[0] / image_orig_size[0] * image_new_size[0], + bbox[1] / image_orig_size[1] * image_new_size[1], + (bbox[0] + bbox[2]) / image_orig_size[0] * image_new_size[0], + (bbox[1] + bbox[3]) / image_orig_size[1] * image_new_size[1] + ] + bbox = [int(x) for x in bbox] + bbox = "{{<{}><{}><{}><{}>}}".format(*bbox) + return { + "image": image, + "refer_sentence": refer_sentence, + "bbox": bbox, + "image_id": region.image.id, + } + + def __getitem__(self, index): + data = self.preprocess(index) + instruction = random.choice(self.instruction_pool).format(data['refer_sentence']) + + instruction = " {} ".format(instruction) + + return { + "image": data['image'], + "instruction_input": instruction, + "answer": data['bbox'], + "image_id": data['image_id'], + } + + diff --git a/minigpt4/datasets/datasets/video_datasets.py b/minigpt4/datasets/datasets/video_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..2128ebef59b635f9fdfc638290eb554f3033fda5 --- /dev/null +++ b/minigpt4/datasets/datasets/video_datasets.py @@ -0,0 +1,1051 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. 
+ SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +from collections import OrderedDict +import sys +from minigpt4.datasets.datasets.base_dataset import BaseDataset +from PIL import Image +import random +import json + +import cv2 +import torch +import torchvision.transforms as transforms + +import numpy as np +import webvtt +import math +from moviepy.editor import VideoFileClip +from minigpt4.processors.blip_processors import Blip2ImageTrainProcessor,BlipCaptionProcessor +import pickle +import time +from decord import VideoReader, cpu, gpu +from tqdm import tqdm +import pysrt +import chardet +import re + +def duration_to_seconds(duration_str): + duration_str = duration_str[2:] # Removing 'PT' prefix + seconds = 0 + if 'H' in duration_str: + hours_str = duration_str.split('H')[0] + seconds += int(hours_str) * 3600 + duration_str = duration_str.split('H')[1] + if 'M' in duration_str: + minutes_str = duration_str.split('M')[0] + seconds += int(minutes_str) * 60 + duration_str = duration_str.split('M')[1] + if 'S' in duration_str: + seconds_str = duration_str.split('S')[0] + seconds += int(seconds_str) + return seconds + +def extract_audio(video_path, audio_path): + video_clip = VideoFileClip(video_path) + audio_clip = video_clip.audio + audio_clip.write_audiofile(audio_path, codec="libmp3lame", bitrate="320k") + +def generate_subtitles(video_path,existed_subtitles): + video_id=video_path.split('/')[-1].split('.')[0] + audio_path = f"workspace/misssing_eval_subtitles/mp3/{video_id}"+'.mp3' + if existed_subtitles.get(video_id,False): + print("subtitle already generated") + return f"workspace/misssing_eval_subtitles/{video_id}"+'.vtt' + try: + extract_audio(video_path,audio_path) + print("successfully extracted") + os.system(f"whisper {audio_path} --language English --model large --output_format vtt --output_dir workspace/misssing_eval_subtitles") + # remove the audio file + os.system(f"rm {audio_path}") + print("subtitle successfully generated") + return f"workspace/misssing_eval_subtitles/{video_id}"+'.vtt' + except Exception as e: + print("error",video_path ,e) + return None + +def read_subtitles(subtitle_path): + # read the subtitle file and detect the encoding + try: + with open(subtitle_path, 'rb') as f: + result = chardet.detect(f.read()) + subs = pysrt.open(subtitle_path, encoding=result['encoding']) + return subs + except: + return [] + + +def srt_time_to_seconds(time): + return time.hours * 3600 + time.minutes * 60 + time.seconds + time.milliseconds / 1000 + + +class __DisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + + return OrderedDict( + { + "file": ann["image"], + "caption": ann["caption"], + "image": sample["image"], + } + ) + + +class CMDVideoDataset(BaseDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths, cc_path): + """ + vis_root (string): Root directory of images (e.g. 
coco/images/) + ann_root (string): directory to store the annotation file + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + self.instruction_pool = [ + 'Describe this video.', + 'Provide a concise depiction of this video.', + 'Present a description of this video.', + 'Summarize this video.', + 'Generate video caption:', + 'Generate video description:', + 'Write a description for the video.', + 'Provide a description of what is presented in the video.', + 'Describe the content of the video.', + 'Can you explain what you see in the video?', + 'Could you describe what you perceive in the video?', + 'Please provide a depiction of the video.', + 'Illustrate what is happening in the video.', + ] + self.img_ids = {} + n = 0 + self.length = 90 + for ann in self.annotation: + img_id = ann["image_id"] + if img_id not in self.img_ids.keys(): + self.img_ids[img_id] = n + n += 1 + + self.cc = json.load(open(cc_path,'r')) + self.image_sep = "" + self.text_sep = "" + + def __getitem__(self, index): + ann = self.annotation[index] + video_id = ann["image_id"] + captions = self.cc[video_id] if video_id in self.cc else None + answer = self.text_processor(ann["caption"]) + instruction = random.choice(self.instruction_pool) + images = [] + img_placeholder = "" + num_of_images=len(os.listdir(os.path.join(self.vis_root, video_id))) + sampling_interval = int(num_of_images / self.length) + if sampling_interval == 0: + sampling_interval = 1 + for frame_id in range(0,num_of_images,sampling_interval): + image_path = os.path.join(self.vis_root, video_id, f'frame_{frame_id}.jpg') + image = Image.open(image_path).convert("RGB") + image = self.vis_processor(image) + images.append(image) + img_placeholder += f"{self.image_sep}" + time_step = str(frame_id * 2) + if captions is not None: + if time_step in captions: + img_placeholder += f"{self.text_sep}{captions[time_step]}" + if len(images) >= self.length: + break + + if len(images) < self.length: + last_item = images[-1] + while len(images) < self.length: + images.append(last_item) + images = torch.stack(images) + instruction = f"{img_placeholder}\n{instruction}" + return { + "image": images, + "answer": answer, + "image_id": video_id, + "instruction_input": instruction, + "length": self.length, + } + + + + +class WebVidDataset(BaseDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths,subtitles_path,add_subtitles=False): + """ + vis_root (string): Root directory of images (e.g. 
coco/images/) + ann_root (string): directory to store the annotation file + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + self.instruction_pool = [ + 'Describe this video.', + 'Provide a concise depiction of this video.', + 'Present a description of this video.', + 'Summarize this video.', + 'Generate video caption:', + 'Generate video description:', + 'Write a description for the video.', + 'Provide a description of what is presented in the video.', + 'Describe the content of the video.', + 'Can you explain what you see in the video?', + 'Could you describe what you perceive in the video?', + 'Please provide a depiction of the video.', + 'Illustrate what is happening in the video.', + ] + self.img_ids = {} + n = 0 + self.length = 90 + self.max_sub_len = 800 + self.add_subtitles = add_subtitles + self.videos_has_subtitles = {} + if self.add_subtitles: + self.subtitle_folder = os.path.join(subtitles_path) + for sub in os.listdir(self.subtitle_folder): + video_id = sub.split('.')[0] + self.videos_has_subtitles[video_id] = True + for ann in self.annotation: + img_id = ann["videoid"] + if img_id not in self.img_ids.keys(): + self.img_ids[img_id] = n + n += 1 + self.transform = transforms.Compose([ + transforms.ToPILImage(), + ]) + + def __getitem__(self, index): + ann = self.annotation[index] + + video_id = ann["videoid"] + images = [] + caption = ann["name"].split('-')[-1].split(':')[-1] + # caption = self.text_processor(caption) + + video_path = os.path.join(self.vis_root, ann['page_dir'], f'{video_id}.mp4') + has_subtitles = self.videos_has_subtitles.get(video_id, False) + if self.add_subtitles and has_subtitles: + subtitle_path = os.path.join(self.subtitle_folder, f'{video_id}.vtt') + # Load the VTT subtitle file + vtt_file = webvtt.read(subtitle_path) + + cap = cv2.VideoCapture(video_path) + clip = VideoFileClip(video_path) + total_num_frames = int(clip.duration * clip.fps) + clip.close() + cap = cv2.VideoCapture(video_path) + images = [] + frame_count = 0 + sampling_interval = int(total_num_frames /self.length) + if sampling_interval == 0: + sampling_interval = 1 + img_placeholder = "" + subtitle_text_in_interval = "" + history_subtitles = {} + number_of_sub_words=0 + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + # Find the corresponding subtitle for the frame and combine the interval subtitles into one subtitle + # we choose 1 frame for every 2 seconds,so we need to combine the subtitles in the interval of 2 seconds + if self.add_subtitles and has_subtitles: + for subtitle in vtt_file: + sub=subtitle.text.replace('\n',' ') + if (subtitle.start_in_seconds <= (frame_count / int(clip.fps)) <= subtitle.end_in_seconds) and sub not in subtitle_text_in_interval: + if not history_subtitles.get(sub,False): + subtitle_text_in_interval+=sub+" " + history_subtitles[sub]=True + break + if frame_count % sampling_interval == 0: + frame = self.transform(frame[:,:,::-1]) + frame = self.vis_processor(frame) + images.append(frame) + img_placeholder += '' + if self.add_subtitles and has_subtitles and subtitle_text_in_interval != "" and number_of_sub_words{subtitle_text_in_interval}' + number_of_sub_words+=len(subtitle_text_in_interval.split(' ')) + subtitle_text_in_interval = "" + frame_count += 1 + if len(images) >= self.length: + break + cap.release() + + if len(images) < self.length: + last_item = images[-1] + while len(images) < self.length: + images.append(last_item) + img_placeholder += '' + + images = torch.stack(images) + instruction = 
random.choice(self.instruction_pool) + instruction = img_placeholder + '\n' + instruction + return { + "image": images, + "answer": caption, + "image_id": video_id, + "instruction_input": instruction, + "length": self.length, + } + +class VideoChatGPTDataset(BaseDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths,add_subtitles=True,llm_name="llama2"): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + self.img_ids = {} + n=0 + self.length = 90 + self.max_sub_len = 800 + self.add_subtitles = add_subtitles + self.videos_has_subtitles = {} + if self.add_subtitles: + self.subtitle_folder = os.path.join(self.vis_root,'subtitles') + for sub in os.listdir(self.subtitle_folder): + video_id = sub.split('.')[0] + self.videos_has_subtitles[video_id] = True + for ann in self.annotation: + img_id = ann["video_id"] + if img_id not in self.img_ids.keys(): + self.img_ids[img_id] = n + n+= 1 + + self.videos_extension={} + for video in os.listdir(os.path.join(self.vis_root,'videos')): + self.videos_extension[video.split('.')[0]]=video.split('.')[1] + + self.transform = transforms.Compose([ + transforms.ToPILImage(), + ]) + def __len__(self): + return len(self.annotation) + def __getitem__(self, index): + ann = self.annotation[index] + video_id = ann["video_id"] + answer=ann["a"] + instruction=ann["q"] + images=[] + img_placeholder = "" + has_subtitles = self.videos_has_subtitles.get(video_id, False) + if self.add_subtitles and has_subtitles: + subtitle_path = os.path.join(self.subtitle_folder, f'{video_id}.vtt') + # Load the VTT subtitle file + vtt_file = webvtt.read(subtitle_path) + + video_path = os.path.join(self.vis_root,'videos',f'{video_id}.{self.videos_extension[video_id]}') + clip = VideoFileClip(video_path) + total_num_frames = int(clip.duration * clip.fps) + clip.close() + cap = cv2.VideoCapture(video_path) + frame_count = 0 + sampling_interval = int(total_num_frames / self.length) + if sampling_interval == 0: + sampling_interval = 1 + img_placeholder = "" + subtitle_text_in_interval = "" + history_subtitles = {} + number_of_sub_words=0 + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + # Find the corresponding subtitle for the frame and combine the interval subtitles into one subtitle + # we choose 1 frame for every 2 seconds,so we need to combine the subtitles in the interval of 2 seconds + if self.add_subtitles and has_subtitles: + for subtitle in vtt_file: + sub=subtitle.text.replace('\n',' ') + if (subtitle.start_in_seconds <= (frame_count / int(clip.fps)) <= subtitle.end_in_seconds) and sub not in subtitle_text_in_interval: + if not history_subtitles.get(sub,False): + subtitle_text_in_interval+=sub+" " + history_subtitles[sub]=True + break + if frame_count % sampling_interval == 0: + frame = self.transform(frame[:,:,::-1])# BGR to RGB + frame = self.vis_processor(frame) + images.append(frame) + img_placeholder += '' + if self.add_subtitles and has_subtitles and number_of_sub_words{subtitle_text_in_interval}' + number_of_sub_words+=len(subtitle_text_in_interval.split(' ')) + subtitle_text_in_interval = "" + frame_count += 1 + if len(images) >= self.length: + break + cap.release() + if len(images) ==0: + print("Video not found",video_path) + + if 0 =self.length: + filtered_annotation.append(ann) + self.annotation = filtered_annotation + self.annotation = self.annotation 
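+        # Descriptive note: the filtering above appears to keep only annotations with
+        # at least `self.length` frames, so each evaluation clip can fill the fixed
+        # frame budget without padding. `self.cc` maps a video id to per-timestep
+        # captions (keyed by the frame timestamp in seconds, with one sampled frame
+        # every 2 s); __getitem__ interleaves these with the frame placeholders when
+        # caption injection (`add_captions`) is enabled.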
+ self.add_caption = add_captions + + self.cc = json.load(open(cc_path,'r')) + self.image_sep = "" + self.text_sep = "" + + def __len__(self): + return len(self.annotation) + + def __getitem__(self, idx): + ann = self.annotation[idx] + video_id = ann["image_id"] + images = [] + subtitles=[] + length = min(self.length, ann['length']) + caption = ann["caption"] + instruction = "Write a detailed description for the video." + interleave = "" + captions = self.cc[video_id] if video_id in self.cc else None + for frame_id in range(length): + image_path = os.path.join(self.root_path, video_id, f'frame_{frame_id}.jpg') + image = Image.open(image_path).convert("RGB") + image = self.vis_processor(image).half().cuda() + images.append(image) + interleave += f"{self.image_sep}" + time_step = str(frame_id* 2) + if captions is not None and self.add_caption: + caption_found=captions.get(time_step,False) + if caption_found: + interleave += f"{self.text_sep}{captions[time_step]}" + subtitles.append(captions[time_step]) + + if 0 < len(images) < self.length: + last_item = images[-1] + while len(images) < self.length: + images.append(last_item) + interleave += f"{self.image_sep}" + instruction = f"{interleave}\n{instruction}" + images = torch.stack(images) + return images, instruction, caption, self.length,video_id + + +class WebVidEvalDataset(torch.utils.data.Dataset): + def __init__(self, vis_processor, root_path, ann_path, length, fix=False,add_captions=False): + self.root_path = root_path + self.vis_processor = vis_processor + self.length = length + with open(ann_path,'r') as f: + self.annotation=json.load(f) + self.fix = fix + if fix: + filtered_annotation = [] + for ann in self.annotation: + if duration_to_seconds(ann['duration']) // 2 >= self.length: + filtered_annotation.append(ann) + self.annotation = filtered_annotation + self.transform = transforms.Compose([ + transforms.ToPILImage(), + ]) + self.annotation = self.annotation + self.add_subtitles = add_captions + self.videos_has_subtitles = {} + if self.add_subtitles: + self.subtitle_folder = os.path.join("datasets/video_text_data/webvid/webvid_val_subtitles") + for sub in os.listdir(self.subtitle_folder): + video_id = sub.split('.')[0] + self.videos_has_subtitles[video_id] = True + + def __len__(self): + return len(self.annotation) + + def __getitem__(self, idx): + ann = self.annotation[idx] + + video_id = ann["videoid"] + length = min(self.length, duration_to_seconds(ann['duration']) // 2) + caption = ann["name"] + + video_path = os.path.join(self.root_path, ann['page_dir'], f'{video_id}.mp4') + has_subtitles = self.videos_has_subtitles.get(video_id, False) + if self.add_subtitles and has_subtitles: + subtitle_path = os.path.join(self.subtitle_folder, f'{video_id}.vtt') + # Load the VTT subtitle file + vtt_file = webvtt.read(subtitle_path) + cap = cv2.VideoCapture(video_path) + clip = VideoFileClip(video_path) + total_num_frames = int(clip.duration * clip.fps) + clip.close() + cap = cv2.VideoCapture(video_path) + images = [] + frame_count = 0 + sampling_interval = int(total_num_frames /self.length) + if sampling_interval == 0: + sampling_interval = 1 + img_placeholder = "" + subtitle_text_in_interval = "" + history_subtitles = {} + number_of_sub_words=0 + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + # Find the corresponding subtitle for the frame and combine the interval subtitles into one subtitle + # we choose 1 frame for every 2 seconds,so we need to combine the subtitles in the interval of 2 seconds + if 
self.add_subtitles and has_subtitles: + for subtitle in vtt_file: + sub=subtitle.text.replace('\n',' ') + if (subtitle.start_in_seconds <= (frame_count / int(cap.get(cv2.CAP_PROP_FPS))) <= subtitle.end_in_seconds) and sub not in subtitle_text_in_interval: + if not history_subtitles.get(sub,False): + subtitle_text_in_interval+=sub+" " + history_subtitles[sub]=True + break + if frame_count % sampling_interval == 0: + frame = self.transform(frame[:,:,::-1]) + frame = self.vis_processor(frame) + images.append(frame) + img_placeholder += '' + if self.add_subtitles and has_subtitles and subtitle_text_in_interval != "" and number_of_sub_words{subtitle_text_in_interval}' + number_of_sub_words+=len(subtitle_text_in_interval.split(' ')) + subtitle_text_in_interval = "" + frame_count += 1 + if len(images) >= self.length: + break + cap.release() + + instruction = "Write a description for the video." + video_found = True + if len(images) == 0: + images = torch.zeros(length, 3, 224, 224) + for i in range(length): + img_placeholder += '' + print("Video not found") + video_found = False + if len(images) < self.length: + last_item = images[-1] + while len(images) < self.length: + images.append(last_item) + img_placeholder += '' + images = torch.stack(images) if video_found else images + instruction = img_placeholder + '\n' + instruction + return images, instruction, caption, self.length,video_id + + + + +class VideoChatGPTEvalDataset(torch.utils.data.Dataset): + def __init__(self, vis_processor, videos_path, ann_path,subtitles_path,annotations_keys,videos_features_path,add_subtitles=True,llm_name="llama2"): + if llm_name=="llama2": + self.length = 45 + self.max_sub_len = 400 + else: + self.length = 90 + self.max_sub_len = 800 + self.add_subtitles = add_subtitles + self.vis_processor=vis_processor + self.videos_path=videos_path + self.question_key=annotations_keys[0] + self.answer_key=annotations_keys[1] + self.video_name_key=annotations_keys[2] + self.videos_extension={} + for video in os.listdir(self.videos_path): + self.videos_extension[video.split('.')[0]]=video.split('.')[1] + self.annotation=json.load(open(ann_path,'r')) + self.videos_has_subtitles = {} + if self.add_subtitles: + self.subtitle_folder = subtitles_path + for sub in os.listdir(self.subtitle_folder): + video_id = sub.split('.')[0] + self.videos_has_subtitles[video_id] = True + self.transform = transforms.Compose([ + transforms.ToPILImage(), + ]) + self.videos_features_path=videos_features_path + def __len__(self): + return len(self.annotation) + def __getitem__(self, index): + ann = self.annotation[index] + video_id = ann[self.video_name_key] + answer=ann[self.answer_key] + instruction=ann[self.question_key] + images=[] + img_placeholder = "" + video_path = os.path.join(self.videos_path,f'{video_id}.{self.videos_extension[video_id]}') + cap = cv2.VideoCapture(video_path) + clip = VideoFileClip(video_path) + total_num_frames = int(clip.duration * clip.fps) + clip.close() + frame_count = 0 + sampling_interval = int(total_num_frames / self.length) + if sampling_interval == 0: + sampling_interval = 1 + subtitle_path=None + if self.add_subtitles : + subtitle_path = generate_subtitles(video_path,self.videos_has_subtitles) + if subtitle_path is not None: + # Load the VTT subtitle file + vtt_file = webvtt.read(subtitle_path) + subtitle_text_in_interval = "" + history_subtitles = {} + number_of_sub_words=0 + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + # Find the corresponding subtitle for the frame and combine the 
interval subtitles into one subtitle + # we choose 1 frame for every 2 seconds,so we need to combine the subtitles in the interval of 2 seconds + if self.add_subtitles and subtitle_path is not None: + for subtitle in vtt_file: + sub=subtitle.text.replace('\n',' ') + if (subtitle.start_in_seconds <= (frame_count / int(cap.get(cv2.CAP_PROP_FPS))) <= subtitle.end_in_seconds) and sub not in subtitle_text_in_interval: + if not history_subtitles.get(sub,False): + subtitle_text_in_interval+=sub+" " + history_subtitles[sub]=True + break + if frame_count % sampling_interval == 0: + frame = self.transform(frame[:,:,::-1]) + frame = self.vis_processor(frame) + images.append(frame) + img_placeholder += '' + if self.add_subtitles and subtitle_path is not None and number_of_sub_words{subtitle_text_in_interval}' + number_of_sub_words+=len(subtitle_text_in_interval.split(' ')) + subtitle_text_in_interval = "" + frame_count += 1 + if len(images) >= self.length: + break + cap.release() + if len(images) == 0: + print("Video not found") + print('Video path',video_path) + return None,None,None,None,None + if 0 {subtitle_text_in_interval}' + number_of_sub_words+=len(subtitle_text_in_interval.split(' ')) + subtitle_text_in_interval = "" + frame_count += 1 + if len(images) >= self.length: + break + cap.release() + if len(images) == 0: + print("Video not found") + print('Video path',video_path) + return None,None,None,None,None + if 0 {subtitle_text_in_interval}' + number_of_sub_words+=len(subtitle_text_in_interval.split(' ')) + subtitle_text_in_interval = "" + frame_count += 1 + if len(images) >= self.length: + break + cap.release() + if len(images) == 0: + print("Video not found") + print('Video path',video_path) + return None,None,None,None,None + if 0 {subtitle_text_in_interval}' + number_of_sub_words+=len(subtitle_text_in_interval.split(' ')) + subtitle_text_in_interval = "" + if len(images) >= self.length: + break + if len(images) ==0: + print("Video not found",video_frames_path) + + if 0 {subtitle_text_in_interval}' + number_of_sub_words+=len(subtitle_text_in_interval.split(' ')) + subtitle_text_in_interval = "" + if len(images) >= self.length: + break + if len(images) ==0: + print("Video not found",video_frames_path) + + if 0
</s>')[0] # remove the stop sign
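+                # strip any leftover special tokens and keep only the text after the final
+                # [/INST] marker, i.e. the assistant turn of the llama-2 chat template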
+ output_texts = output_texts.replace("", "") + output_texts = output_texts.split(r'[/INST]')[-1].strip() + answers.append(output_texts) + if return_video_temporal_features: + return answers, video_temporal_features + else: + return answers + + @torch.no_grad() + def generate_text_only( + self, + images, + seg_tokens, + use_nucleus_sampling=False, + num_beams=1, + max_new_tokens=20, + min_length=1, + top_p=0.9, + repetition_penalty=1.5, + length_penalty=1, + temperature=1, + do_sample=False, + stop_words_ids=[2], + lengths=None, + return_video_temporal_features=False, + img_embeds=None, + ): + ''' + function for generate test use + ''' + + stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub( + stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])]) + + # seg_tokens=[] + # for i, text in enumerate(texts): + # seg_tokens.append(self.llama_tokenizer(text, return_tensors="pt", add_special_tokens=True).to(self.device).input_ids) + + batch_embs = [torch.cat([self.embed_tokens(seg_t)]) for seg_t in seg_tokens] + + # seg_embs = torch.cat(seg_embs, dim=1) + # print("seg_embs shape",seg_embs.shape) + # batch_embs=[seg_embs] + batch_size = len(batch_embs) + max_len = max([emb.shape[1] for emb in batch_embs]) + emb_dim = batch_embs[0].shape[2] + dtype = batch_embs[0].dtype + device = batch_embs[0].device + + embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device) + attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device) + for i, emb in enumerate(batch_embs): + emb_len = emb.shape[1] + embs[i, -emb_len:] = emb[0] + attn_mask[i, -emb_len:] = 1 + + + print("inputs_embeds shape",embs.shape) + print("attention_mask shape",attn_mask.shape) + with self.maybe_autocast(): + outputs = self.llama_model.generate( + inputs_embeds=embs, + attention_mask=attn_mask, + max_new_tokens=max_new_tokens, + num_beams=num_beams, + do_sample=do_sample, + temperature=temperature, + repetition_penalty=repetition_penalty, + # stopping_criteria=stopping_criteria, + ) + + answers = [] + for output_token in outputs: + if output_token[0] == 0: + output_token = output_token[1:] + output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True) + output_texts = output_texts.split('')[0] # remove the stop sign
+ output_texts = output_texts.replace("", "") + output_texts = output_texts.split(r'[/INST]')[-1].strip() + answers.append(output_texts) + return answers + + + + @torch.no_grad() + def multi_select(self, images, texts, answers, num_cand=None): + all_losses = [] + for answer in answers: + choice_samples = { + 'image': images, + 'instruction_input': texts, + 'answer': answer + } + loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1) + all_losses.append(loss) + torch.cuda.empty_cache() + all_losses = torch.cat(all_losses, dim=-1) + if num_cand is not None: + for i in range(all_losses.shape[0]): + all_losses[i, num_cand[i]:] = 9999 + output_class_ranks = torch.argsort(all_losses, dim=-1) + return output_class_ranks.tolist() + + def predict_answers( + self, + samples, + num_beams=5, + inference_method="generate", + max_len=10, + min_len=1, + num_ans_candidates=128, + answer_list=None, + prompt="", + length_penalty=0, + **kwargs + ): + ''' + function for open-ended VQA + ''' + images = samples["image"].cuda() + texts = samples["instruction_input"] + + output_text = self.generate( + images=images, + texts=texts, + num_beams=num_beams, + max_new_tokens=max_len, + min_length=min_len, + length_penalty=length_penalty + ) + + if "apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]: + output_text = self._lemmatize(output_text) + + return output_text + + def predict_class( + self, + samples, + num_beams=5, + inference_method="generate", + max_len=10, + min_len=1, + num_ans_candidates=5, + answer_list=None, + prompt="", + length_penalty=0, + **kwargs + ): + ''' + function for multi-choice VQA + ''' + + image = samples["image"].cuda() + instruction = samples['instruction_input'] + answers = samples["choices"] + num_cand = samples["num_choices"] + + ranks = self.multi_select(image, instruction, answers, num_cand) + + pred_ans = [] + for i, rank in enumerate(ranks): + pred = answers[rank[0]][i] + pred_ans.append(pred) + return pred_ans + + def embed_tokens(self, token_ids): + try: + embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids) + except AttributeError: + embeds = self.llama_model.model.embed_tokens(token_ids) + + return embeds + + @classmethod + def from_config(cls, cfg): + vit_model = cfg.get("vit_model", "eva_clip_g") + q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth") + img_size = cfg.get("image_size") + num_query_token = cfg.get("num_query_token") + llama_model = cfg.get("llama_model") + + drop_path_rate = cfg.get("drop_path_rate", 0) + use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) + vit_precision = cfg.get("vit_precision", "fp16") + freeze_vit = cfg.get("freeze_vit", True) + freeze_qformer = cfg.get("freeze_qformer", True) + low_resource = cfg.get("low_resource", False) + + prompt_path = cfg.get("prompt_path", "") + prompt_template = cfg.get("prompt_template", "") + max_txt_len = cfg.get("max_txt_len", 300) + end_sym = cfg.get("end_sym", '\n') + + lora_r = cfg.get("lora_r",64) + lora_alpha = cfg.get("lora_alpha",16) + chat_template = cfg.get("chat_template",False) + system_prompt = cfg.get("system_prompt", False) + token_pooling = cfg.get("token_pooling",True) + + use_grad_checkpoint_llm = cfg.get("use_grad_checkpoint_llm", False) + max_context_len = cfg.get("max_context_len", 3800) + remove_template = cfg.get("remove_template", False) + + + model = cls( + vit_model=vit_model, + img_size=img_size, + 
drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + llama_model=llama_model, + prompt_path=prompt_path, + prompt_template=prompt_template, + max_txt_len=max_txt_len, + low_resource=low_resource, + end_sym=end_sym, + lora_r = lora_r, + lora_alpha = lora_alpha, + chat_template = chat_template, + system_prompt = system_prompt, + token_pooling = token_pooling, + use_grad_checkpoint_llm=use_grad_checkpoint_llm, + max_context_len=max_context_len, + remove_template = remove_template + ) + + ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4 + if ckpt_path: + print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path)) + ckpt = torch.load(ckpt_path, map_location="cpu") + msg = model.load_state_dict(ckpt['model'], strict=False) + + return model + + +def assign_imgs(batched_instruct_list, batched_img_embeds): + '''this function is used when the data is interleaved. + the interlevaed data is separated, and this function assign + corresponding image embeddings to each segment''' + if len(batched_img_embeds.shape) == 3: + batched_img_embeds = batched_img_embeds[:, None] + + batched_assigned = [] + + for instruct_list, img_embeds in zip(batched_instruct_list, batched_img_embeds): + img_idx = 0 + assigned_img = [] + n_assigned = [] + for instruct in instruct_list: + n_img = instruct.count('') + if n_img > 0: # this instruction include images. + assigned_img.append(img_embeds[None, img_idx:img_idx+n_img]) + img_idx += n_img + n_assigned.append(n_img) + else: # this instruction doesn't include images + assigned_img.append(None) + n_assigned.append(None) + batched_assigned.append(assigned_img) + + return batched_assigned \ No newline at end of file diff --git a/minigpt4/models/mini_gpt4v.py b/minigpt4/models/mini_gpt4v.py new file mode 100644 index 0000000000000000000000000000000000000000..c3d5e4d0e08b99e7083f3b3b52b7f6f88f26616d --- /dev/null +++ b/minigpt4/models/mini_gpt4v.py @@ -0,0 +1,709 @@ +import logging +import random + +import torch +from torch.cuda.amp import autocast as autocast +import torch.nn as nn + +from minigpt4.common.registry import registry +from minigpt4.models.blip2 import Blip2Base, disabled_train +from minigpt4.models.modeling_llama_v2 import LlamaForCausalLM +from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub + +from transformers import LlamaTokenizer, CodeLlamaTokenizer, BitsAndBytesConfig + +from peft import ( + LoraConfig, + get_peft_model, + prepare_model_for_kbit_training +) +import time +import numpy as np + +from minigpt4.models import policies + + +@registry.register_model("mini_gpt4v") +class MiniGPT4v(Blip2Base): + """ + BLIP2 GPT-LLAMA model. 
+ """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "pretrain_vicuna": "configs/models/minigpt4.yaml", + } + + def __init__( + self, + vit_model="eva_clip_g", + img_size=224, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp16", + freeze_vit=True, + llama_model="", + prompt_path="", + prompt_template="", + max_txt_len=32, + low_resource=False, # use 8 bit and put vit in cpu + end_sym='\n', + lora_r = 8, + lora_target_modules = ["q_proj","v_proj"], + lora_alpha=16, + # lora_r = 16, + # lora_target_modules = ["q_proj","v_proj","v_proj"], + lora_dropout= 0.05, + ckpt_path = "", + system_prompt= False, + chat_template=False, + token_pooling=True, + use_grad_checkpoint_llm=False, + max_context_len=3800, + remove_template = False, + + ): + super().__init__() + + self.tokenizer = self.init_tokenizer() + self.low_resource = low_resource + self.token_pooling = token_pooling + self.remove_template = remove_template + + print("token pooling", self.token_pooling) + + + self.use_grad_checkpoint_llm = use_grad_checkpoint_llm + self.max_context_len = max_context_len + self.chat_template = chat_template + + # print('Loading VIT') + # self.visual_encoder, self.ln_vision = self.init_vision_encoder( + # vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision + # ) + + + print("vit precision", vit_precision) + self.visual_encoder, self.ln_vision = self.init_vision_encoder( + vit_model, 224, drop_path_rate, use_grad_checkpoint, vit_precision + ) + for name, param in self.visual_encoder.named_parameters(): + param.requires_grad = False + self.visual_encoder = self.visual_encoder.eval() + self.visual_encoder.train = disabled_train + for name, param in self.ln_vision.named_parameters(): + param.requires_grad = False + self.ln_vision = self.ln_vision.eval() + self.ln_vision.train = disabled_train + logging.info("freeze vision encoder") + print("freeze the vision encoder") + + + print('Loading VIT Done') + + # print("visual encoder shape", self.visual_encoder.pos_embed.shape) + # assert False + + print('Loading LLAMA') + + + self.B_SYS, self.E_SYS = "<>\n", "\n<>\n\n" + + if 'CodeLlama' in llama_model: + self.llama_tokenizer = CodeLlamaTokenizer.from_pretrained(llama_model, use_fast=False) # + self.llama_tokenizer.pad_token = "$$" + else: + self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False) # + self.llama_tokenizer.pad_token = "$$" + + self.system_prompt = system_prompt + + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16 + ) + + + + self.llama_model = LlamaForCausalLM.from_pretrained( + llama_model, + quantization_config=bnb_config, + device_map={"": 0} + ) + + # self.llama_model.gradient_checkpointing_enable() + self.llama_model = prepare_model_for_kbit_training(self.llama_model) + + # self.llama_model.print_trainable_parameters() + + + print('Loading LLAMA Done') + + self.merge_n = 3 + + self.llama_proj = nn.Linear( + 1408 * self.merge_n**2, self.llama_model.config.hidden_size + ) + + self.max_txt_len = max_txt_len + self.end_sym = end_sym + + if prompt_path: + with open(prompt_path, 'r') as f: + raw_prompts = f.read().splitlines() + filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "" in raw_prompt] + self.prompt_list = [prompt_template.format(p) for p in filted_prompts] + print('Load {} training prompts'.format(len(self.prompt_list))) + print('Prompt Example \n{}'.format(random.choice(self.prompt_list))) + else: + self.prompt_list = 
[] + + def encode_img(self, image): + device = image.device + if len(image.shape) > 4: + image = image.reshape(-1, *image.shape[-3:]) + + bs, ch, w, h = image.shape + assert w % 224 == 0 + bw = w // 224 + assert h % 224 == 0 + bh = h // 224 + image_patches = image.view(bs, ch, bw, 224, bh, 224).permute(0, 2, 4, 1, 3, 5) # bs, bw, bh, ch, 224, 224 + image_patches = image_patches.reshape(bs * bw * bh, ch, 224, 224) + + with self.maybe_autocast(): + image_patch_embeds = self.ln_vision(self.visual_encoder(image_patches)).to(device) + + image_patch_embeds = image_patch_embeds[:,1:,:].reshape(bs, bw, bh, 16, 16, image_patch_embeds.shape[-1]) + image_patch_embeds = image_patch_embeds.permute(0, 1, 3, 2, 4, 5) # bs, bw, 16, bh, 16, hs + image_embeds = image_patch_embeds.reshape(bs, bw * 16 * bh * 16, image_patch_embeds.shape[-1]) + + bs, pn, hs = image_embeds.shape + + image_embeds = image_embeds.view(bs, int(pn/self.merge_n**2), int(hs*self.merge_n**2)) + + inputs_llama = self.llama_proj(image_embeds) + atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device) + return inputs_llama, atts_llama + + def get_context_emb(self, prompt, img_list): + img_device = img_list[0].device + prompt_segs = prompt.split('') + assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images." + seg_tokens = [ + self.llama_tokenizer( + seg, return_tensors="pt", add_special_tokens=i==0).to(img_device).input_ids # only add bos to the first seg + for i, seg in enumerate(prompt_segs) + ] + + seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens] + + mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] + + mixed_embs = torch.cat(mixed_embs, dim=1) + return mixed_embs + + def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None): + if prompts is None or len(prompts) == 0: + # prompts is not provided, just return the original image embedding + return img_embeds, atts_img + elif img_embeds is None: + # prompt is provided but there is no image embedding. 
return the prompt embedding in right padding + self.llama_tokenizer.padding_side = "right" + prompt_tokens = self.llama_tokenizer( + prompts, + return_tensors="pt", + padding="longest", + add_special_tokens=False + ).to(self.device) + prompt_embeds = self.embed_tokens(prompt_tokens.input_ids) + atts_prompt = prompt_tokens.attention_mask + return prompt_embeds, atts_prompt + + else: + # return the multi-modal embedding in right padding + emb_lists = [] + + for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)): + pn = each_img_embed.shape[-2] + if lengths is not None: + each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1]) + each_img_embed = each_img_embed[:lengths[idx] * pn] + + p_segs = each_prompt.split('') + interleave_emb = [] + for idx, seg in enumerate(p_segs[:-1]): + p_tokens = self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) + p_embed = self.embed_tokens(p_tokens.input_ids) + interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx*pn:(idx+1)*pn]], dim=1)) + + wrapped_emb = torch.cat(interleave_emb, dim=1) + p_tokens = self.llama_tokenizer(p_segs[-1], return_tensors="pt", add_special_tokens=False).to(img_embeds.device) + p_embed = self.embed_tokens(p_tokens.input_ids) + wrapped_emb = torch.cat([wrapped_emb,p_embed], dim=1) + emb_lists.append(wrapped_emb) + + emb_lens = [emb.shape[1] for emb in emb_lists] + pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device)) + + max_length = max(emb_lens) if max(emb_lens) < self.max_context_len else self.max_context_len + wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone() + wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device) + + for i, emb in enumerate(emb_lists): + length = emb_lens[i] if emb_lens[i] < self.max_context_len else self.max_context_len + wrapped_embs[i, :length] = emb[:, :length] + wrapped_atts[i, :length] = 1 + + return wrapped_embs, wrapped_atts + + def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts): + """ + Concatenate the batched input embedding and batched output embedding together. + Both the input and the output embedding should be right padded. 
+ """ + + input_lens = [] + cat_embs = [] + cat_atts = [] + + for i in range(input_embs.size(0)): + input_len = input_atts[i].sum() + input_lens.append(input_len) + + cat_embs.append( + torch.cat([ + input_embs[i][:input_len], + output_embs[i], + input_embs[i][input_len:] + ]) + ) + cat_atts.append( + torch.cat([ + input_atts[i][:input_len], + output_atts[i], + input_atts[i][input_len:] + ]) + ) + # print('===================================') + # print('check input emb: ', input_embs[i][this_input_ones-2:this_input_ones]) + # print('check pad emb: ', input_embs[i][this_input_ones:this_input_ones+2]) + # print('check out emb: ', output_embs[i][:2]) + # print('check out pad emb: ', output_embs[i][-2:]) + # print('+++++++++++++++++++++++++++++++++++') + # + # print('check attn before: ', input_atts[i][:this_input_ones]) + # print('check attn after: ', input_atts[i][this_input_ones:]) + # print('check attn gt before: ', output_atts[i][:3]) + # print('check attn gt after: ', output_atts[i][-3:]) + + cat_embs = torch.stack(cat_embs) + cat_atts = torch.stack(cat_atts) + return cat_embs, cat_atts, input_lens + + def get_conv_emb(self, conv_q, conv_a, conv_img): + """concatenate conversation and make sure the model is only trained to regress the answer""" + + regress_embs_list = [] + targets_list = [] + + batch_size = len(conv_q) + for batch_idx in range(batch_size): + questions, answers = conv_q[batch_idx], conv_a[batch_idx] + assigned_imgs = conv_img[batch_idx] + questions = [self.prompt_wrap( + img_embeds=img, + atts_img=None, + prompts=[q], + lengths=[img.shape[1]] if img is not None else None) for q, img in zip(questions, assigned_imgs)] + q_embs = [emb for emb, _ in questions] + + answers = [self.llama_tokenizer(a, return_tensors="pt", add_special_tokens=False).to(self.device) for a in answers] + cur_emb = [] + cur_target = [] + for i in range(len(questions)): + cur_emb.append(q_embs[i]) + cur_target.append(torch.ones_like(q_embs[i][..., 0], dtype=torch.int) * -100) + + cur_emb.append(self.embed_tokens(answers[i].input_ids)) + cur_target.append(answers[i].input_ids) + + cur_emb = torch.cat(cur_emb, dim=1) + cur_target = torch.cat(cur_target, dim=1) + + regress_embs_list.append(cur_emb) + targets_list.append(cur_target) + + max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len) + + regress_embeds = torch.zeros([batch_size, max_len, cur_emb.shape[-1]], device=self.device) + regress_attn = torch.zeros([batch_size, max_len], dtype=torch.int, device=self.device) + targets = torch.ones([batch_size, max_len], dtype=torch.long, device=self.device) * -100 + + for batch_idx in range(batch_size): + cur_len = regress_embs_list[batch_idx].shape[1] + regress_embeds[batch_idx, :cur_len] = regress_embs_list[batch_idx][0, :max_len] + regress_attn[batch_idx, :cur_len] = 1 + targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :max_len] + + return regress_embeds, regress_attn, targets + + def preparing_embedding(self, samples): + def remove_special_tokens(data): + + # if "instruction_input" in data: + data = [instruct.replace(" [caption]","") for instruct in data] + data = [instruct.replace(" [vqa]","") for instruct in data] + data = [instruct.replace(" [grounding]","") for instruct in data] + data = [instruct.replace(" [identify]","") for instruct in data] + data = [instruct.replace(" [refer]","") for instruct in data] + return data + + ### prepare input tokens + if 'image' in samples: + img_embeds, img_atts = self.encode_img(samples["image"]) + else: + img_embeds = img_atts = 
None + + if 'conv_q' in samples: + # handeling conversation datasets + conv_q, conv_a = samples['conv_q'], samples['conv_a'] + + connect_sym = samples['connect_sym'][0] + conv_q = [q.split(connect_sym)for q in conv_q] + conv_a = [a.split(connect_sym) for a in conv_a] + conv_img = assign_imgs(conv_q, img_embeds) + + if self.chat_template: + conv_q = [["[INST] " + item + "[/INST]" for item in items] for items in conv_q] + + regress_embeds, regress_atts, part_targets = self.get_conv_emb(conv_q, conv_a, conv_img) + cond_embeds, cond_atts = regress_embeds[:, :0], regress_atts[:, :0] + + else: + instruction = samples["instruction_input"] if "instruction_input" in samples else None + + # print("instruction before", instruction) + if self.remove_template: + instruction = remove_special_tokens(instruction) + # print("instruction after", instruction) + + if self.chat_template: + instruction = ["[INST] " + instruct + "[/INST]" for instruct in instruction] + + if 'length' in samples: + # the input is a image train (like videos) + bsz, pn, hs = img_embeds.shape + img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs) + cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length']) + else: + cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction) + + ### prepare target tokens + self.llama_tokenizer.padding_side = "right" + text = [t + self.end_sym for t in samples["answer"]] + + regress_tokens = self.llama_tokenizer( + text, + return_tensors="pt", + padding="longest", + truncation=True, + max_length=self.max_txt_len, + add_special_tokens=False + ).to(self.device) + + regress_token_ids = regress_tokens.input_ids + regress_atts = regress_tokens.attention_mask + part_targets = regress_token_ids.masked_fill( + regress_token_ids == self.llama_tokenizer.pad_token_id, -100 + ) + + regress_embeds = self.embed_tokens(regress_token_ids) + + return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets + + def forward(self, samples, reduction="mean"): + # prepare the embedding to condition and the embedding to regress + cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets = \ + self.preparing_embedding(samples) + + # concat the embedding to condition and the embedding to regress + inputs_embeds, attention_mask, input_lens = \ + self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts) + + # get bos token embedding + bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id + bos_embeds = self.embed_tokens(bos) + bos_atts = attention_mask[:, :1] + + # add bos token at the begining + inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([bos_atts, attention_mask], dim=1) + + # ensemble the final targets + targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]], + dtype=torch.long).to(self.device).fill_(-100) + for i, target in enumerate(part_targets): + targets[i, input_lens[i]+1:input_lens[i]+len(target)+1] = target # plus 1 for bos + + with self.maybe_autocast(): + outputs = self.llama_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + return_dict=True, + labels=targets, + reduction=reduction + ) + loss = outputs.loss + + return {"loss": loss} + + @torch.no_grad() + def generate( + self, + images, + texts, + use_nucleus_sampling=False, + num_beams=1, + max_new_tokens=20, + min_length=1, + top_p=0.9, + repetition_penalty=1, + length_penalty=1, + temperature=1, + do_sample=False, + stop_words_ids=[2], + 
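+        # stop_words_ids defaults to [2], the llama </s> (EOS) token id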
lengths=None, + ): + ''' + function for generate test use + ''' + + stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub( + stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])]) + + img_embeds, atts_img = self.encode_img(images.to(self.device)) + if lengths is not None: + image_lists = [] + img_embeds = img_embeds.reshape(len(lengths), -1, img_embeds.shape[-2], img_embeds.shape[-1]) + for idx, img_embed in enumerate(img_embeds): + image_lists.append([img_embed[i][None] for i in range(lengths[idx])]) + else: + image_lists = [[image_emb[None]] for image_emb in img_embeds] + assert len(texts) == len(image_lists) + batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)] + + batch_size = len(batch_embs) + max_len = max([emb.shape[1] for emb in batch_embs]) + emb_dim = batch_embs[0].shape[2] + dtype = batch_embs[0].dtype + device = batch_embs[0].device + + embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device) + attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device) + for i, emb in enumerate(batch_embs): + emb_len = emb.shape[1] + embs[i, -emb_len:] = emb[0] + attn_mask[i, -emb_len:] = 1 + + with self.maybe_autocast(): + outputs = self.llama_model.generate( + inputs_embeds=embs, + attention_mask=attn_mask, + max_new_tokens=max_new_tokens, + num_beams=num_beams, + do_sample=do_sample, + # stopping_criteria=stopping_criteria, + ) + + answers = [] + for output_token in outputs: + if output_token[0] == 0: + output_token = output_token[1:] + output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True) + output_texts = output_texts.split('')[0] # remove the stop sign + output_texts = output_texts.replace("", "") + output_texts = output_texts.split(r'[/INST]')[-1].strip() + answers.append(output_texts) + + return answers + + @torch.no_grad() + def multi_select(self, images, texts, answers, num_cand=None): + all_losses = [] + for answer in answers: + choice_samples = { + 'image': images, + 'instruction_input': texts, + 'answer': answer + } + loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1) + all_losses.append(loss) + torch.cuda.empty_cache() + all_losses = torch.cat(all_losses, dim=-1) + if num_cand is not None: + for i in range(all_losses.shape[0]): + all_losses[i, num_cand[i]:] = 9999 + output_class_ranks = torch.argsort(all_losses, dim=-1) + return output_class_ranks.tolist() + + def predict_answers( + self, + samples, + num_beams=5, + inference_method="generate", + max_len=10, + min_len=1, + num_ans_candidates=128, + answer_list=None, + prompt="", + length_penalty=0, + **kwargs + ): + ''' + function for open-ended VQA + ''' + images = samples["image"].cuda() + texts = samples["instruction_input"] + + output_text = self.generate( + images=images, + texts=texts, + num_beams=num_beams, + max_new_tokens=max_len, + min_length=min_len, + length_penalty=length_penalty + ) + + if "apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]: + output_text = self._lemmatize(output_text) + + return output_text + + def predict_class( + self, + samples, + num_beams=5, + inference_method="generate", + max_len=10, + min_len=1, + num_ans_candidates=5, + answer_list=None, + prompt="", + length_penalty=0, + **kwargs + ): + ''' + function for multi-choice VQA + ''' + + image = samples["image"].cuda() + instruction = samples['instruction_input'] + answers = samples["choices"] + num_cand = samples["num_choices"] + + ranks = self.multi_select(image, 
instruction, answers, num_cand) + + pred_ans = [] + for i, rank in enumerate(ranks): + pred = answers[rank[0]][i] + pred_ans.append(pred) + return pred_ans + + def embed_tokens(self, token_ids): + try: + embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids) + except AttributeError: + embeds = self.llama_model.model.embed_tokens(token_ids) + + return embeds + + @classmethod + def from_config(cls, cfg): + vit_model = cfg.get("vit_model", "eva_clip_g") + q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth") + img_size = cfg.get("image_size") + num_query_token = cfg.get("num_query_token") + llama_model = cfg.get("llama_model") + + drop_path_rate = cfg.get("drop_path_rate", 0) + use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) + vit_precision = cfg.get("vit_precision", "fp16") + freeze_vit = cfg.get("freeze_vit", True) + freeze_qformer = cfg.get("freeze_qformer", True) + low_resource = cfg.get("low_resource", False) + + prompt_path = cfg.get("prompt_path", "") + prompt_template = cfg.get("prompt_template", "") + max_txt_len = cfg.get("max_txt_len", 300) + end_sym = cfg.get("end_sym", '\n') + + lora_r = cfg.get("lora_r",64) + lora_alpha = cfg.get("lora_alpha",16) + chat_template = cfg.get("chat_template",False) + system_prompt = cfg.get("system_prompt", False) + token_pooling = cfg.get("token_pooling",True) + + use_grad_checkpoint_llm = cfg.get("use_grad_checkpoint_llm", False) + max_context_len = cfg.get("max_context_len", 3800) + remove_template = cfg.get("remove_template", False) + + + model = cls( + vit_model=vit_model, + img_size=img_size, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + llama_model=llama_model, + prompt_path=prompt_path, + prompt_template=prompt_template, + max_txt_len=max_txt_len, + low_resource=low_resource, + end_sym=end_sym, + lora_r = lora_r, + lora_alpha = lora_alpha, + chat_template = chat_template, + system_prompt = system_prompt, + token_pooling = token_pooling, + use_grad_checkpoint_llm=use_grad_checkpoint_llm, + max_context_len=max_context_len, + remove_template = remove_template + ) + + ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4 + if ckpt_path: + print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path)) + ckpt = torch.load(ckpt_path, map_location="cpu") + msg = model.load_state_dict(ckpt['model'], strict=False) + + return model + + +def assign_imgs(batched_instruct_list, batched_img_embeds): + '''this function is used when the data is interleaved. + the interlevaed data is separated, and this function assign + corresponding image embeddings to each segment''' + if len(batched_img_embeds.shape) == 3: + batched_img_embeds = batched_img_embeds[:, None] + + batched_assigned = [] + + for instruct_list, img_embeds in zip(batched_instruct_list, batched_img_embeds): + img_idx = 0 + assigned_img = [] + n_assigned = [] + for instruct in instruct_list: + n_img = instruct.count('') + if n_img > 0: # this instruction include images. 
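+            # take the next n_img image embeddings for this text segment and advance the cursor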
+ assigned_img.append(img_embeds[None, img_idx:img_idx+n_img]) + img_idx += n_img + n_assigned.append(n_img) + else: # this instruction doesn't include images + assigned_img.append(None) + n_assigned.append(None) + batched_assigned.append(assigned_img) + + return batched_assigned \ No newline at end of file diff --git a/minigpt4/models/mistral.py b/minigpt4/models/mistral.py new file mode 100644 index 0000000000000000000000000000000000000000..43095ff1bcf084f9f4946b066510dc0100cb235f --- /dev/null +++ b/minigpt4/models/mistral.py @@ -0,0 +1,25 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer + +device = "cuda" # the device to load the model onto + +model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2") +tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2") + +messages = [ + {"role": "user", "content": "What is your favourite condiment?"}, + {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"}, + {"role": "user", "content": "Do you have mayonnaise recipes?"} +] +p="Well, I'm quite partial to a good squeeze of fresh lemon juice." +encoded_input = tokenizer(p, return_tensors='pt') +embeds = model.model.embed_tokens(encoded_input.input_ids) +print(embeds.shape) + + +encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt") +model_inputs = encodeds.to(device) +model.to(device) + +generated_ids = model.generate(model_inputs, max_new_tokens=1000, do_sample=True) +decoded = tokenizer.batch_decode(generated_ids) +print(decoded[0]) diff --git a/minigpt4/models/modeling_llama_v2.py b/minigpt4/models/modeling_llama_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..6d2802050ff5fe65b7bbef22a4b648c4109c90bc --- /dev/null +++ b/minigpt4/models/modeling_llama_v2.py @@ -0,0 +1,111 @@ +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss + +from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC +from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig + + +class LlamaForCausalLM(LlamaForCausalLMOrig): + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + reduction: Optional[str] = "mean", + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(reduction=reduction) + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + if reduction == "none": + loss = loss.view(logits.size(0), -1).mean(1) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/minigpt4/models/modeling_mistral.py b/minigpt4/models/modeling_mistral.py new file mode 100644 index 0000000000000000000000000000000000000000..3a98c7de70bd0b13192fd4114fb3cd162953fcb0 --- /dev/null +++ b/minigpt4/models/modeling_mistral.py @@ -0,0 +1,1388 @@ +# coding=utf-8 +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Mistral model.""" +import inspect +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.models.mistral.configuration_mistral import MistralConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MistralConfig" + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral +class MistralRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MistralRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral +# TODO @Arthur no longer copied from LLama after static cache +class MistralRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + 
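+        # RoPE inverse frequencies: inv_freq[i] = 1 / base**(2*i/dim), one value per pair of head dims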
self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +# TODO @Arthur no longer copied from LLama after static cache +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class MistralMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class MistralAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = MistralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class MistralFlashAttention2(MistralAttention): + """ + Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ): + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + use_sliding_windows = ( + _flash_supports_window_size + and getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + ) + + if not _flash_supports_window_size: + logger.warning_once( + "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" + " make sure to upgrade flash-attn library." + ) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. 
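+        # The recast below is a no-op when the activations are already in fp16/bf16; it only
+        # triggers when the hidden states arrive in float32 (e.g. layer norms upcasted by PEFT),
+        # since the flash-attention kernels only accept half-precision inputs.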
+ input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_sliding_windows=use_sliding_windows, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + use_sliding_windows=False, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_sliding_windows (`bool`, *optional*): + Whether to activate sliding window attention. + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
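+            # With flash_attn<2.1 the causal mask is top-left aligned, so the causal flag is
+            # disabled for single-token (query_length == 1) decoding, where a top-left mask
+            # would wrongly hide all but the first key position.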
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + if not use_sliding_windows: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + if not use_sliding_windows: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + + # On the first iteration we need to properly re-create the padding mask + # by slicing it on the proper place + if kv_seq_len != attention_mask.shape[-1]: + attention_mask_num_tokens = attention_mask.shape[-1] + attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
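+            # With left padding the newest `query_length` tokens sit at the right edge of the
+            # mask, so keeping only the last `query_length` columns isolates the query positions.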
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral +# TODO @Arthur no longer copied from LLama after static cache +class MistralSdpaAttention(MistralAttention): + """ + Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `MistralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from MistralAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "MistralModel is using MistralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. 
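+        # Unlike the eager path, the fused SDPA kernel never materializes the full attention
+        # matrix, which is why this implementation always returns `None` for the attention weights.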
+ if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +MISTRAL_ATTENTION_CLASSES = { + "eager": MistralAttention, + "flash_attention_2": MistralFlashAttention2, + "sdpa": MistralSdpaAttention, +} + + +class MistralDecoderLayer(nn.Module): + def __init__(self, config: MistralConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.mlp = MistralMLP(config) + self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +MISTRAL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MistralConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Mistral Model outputting raw hidden-states without any specific head on top.", + MISTRAL_START_DOCSTRING, +) +class MistralPreTrainedModel(PreTrainedModel): + config_class = MistralConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MistralDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +MISTRAL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Mistral Model outputting raw hidden-states without any specific head on top.", + MISTRAL_START_DOCSTRING, +) +class MistralModel(MistralPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`MistralDecoderLayer`] + + Args: + config: MistralConfig + """ + + def __init__(self, config: MistralConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + past_key_values_length = 0 + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + + if self._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._attn_implementation == "sdpa" and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + 
attentions=all_self_attns, + ) + + +class MistralForCausalLM(MistralPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = MistralModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + reduction: Optional[str] = "mean", + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MistralForCausalLM + + >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(reduction=reduction) + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + if reduction == "none": + loss = loss.view(logits.size(0), -1).mean(1) + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + # Omit tokens covered by past_key_values + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
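+            # Keeping only the trailing `max_cache_length` mask positions mirrors a bounded cache
+            # (e.g. a sliding-window cache) that drops its oldest entries.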
+ if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The Mistral Model transformer with a sequence classification head on top (linear layer). + + [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + MISTRAL_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL +class MistralForSequenceClassification(MistralPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MistralModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) \ No newline at end of file diff --git a/minigpt4/models/policies/__init__.py b/minigpt4/models/policies/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6d03d7c49eaf465dec6f3c37a6e0684762b5efd9 --- /dev/null +++ b/minigpt4/models/policies/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
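+#
+# Convenience re-exports for the FSDP training utilities bundled in this package:
+# mixed-precision presets, auto-wrap policies, activation checkpointing, and the
+# AnyPrecisionAdamW optimizer.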
+
+from .mixed_precision import *
+from .wrapping import *
+from .activation_checkpointing_functions import apply_fsdp_checkpointing
+from .anyprecision_optimizer import AnyPrecisionAdamW
+from .fsdp_utils import fsdp_auto_wrap_policy
\ No newline at end of file
diff --git a/minigpt4/models/policies/activation_checkpointing_functions.py b/minigpt4/models/policies/activation_checkpointing_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a1e31f427d1bedc6e7b3eb905e6614f2441be87
--- /dev/null
+++ b/minigpt4/models/policies/activation_checkpointing_functions.py
@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import torch
+import os
+import torch.distributed as dist
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    checkpoint_wrapper,
+    CheckpointImpl,
+    apply_activation_checkpointing,
+)
+
+from transformers.models.t5.modeling_t5 import T5Block
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from functools import partial
+
+non_reentrant_wrapper = partial(
+    checkpoint_wrapper,
+    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
+)
+
+check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)
+
+
+def apply_fsdp_checkpointing(model):
+    """apply activation checkpointing to model
+    returns None as model is updated directly
+    """
+    print(f"--> applying fsdp activation checkpointing...")
+
+    apply_activation_checkpointing(
+        model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
+    )
diff --git a/minigpt4/models/policies/anyprecision_optimizer.py b/minigpt4/models/policies/anyprecision_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..22b0ca00173bd8b40c8982c615a3a04a697d6484
--- /dev/null
+++ b/minigpt4/models/policies/anyprecision_optimizer.py
@@ -0,0 +1,179 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+# AnyPrecisionAdamW: a flexible precision AdamW optimizer
+# with optional Kahan summation for high precision weight updates.
+# Allows direct control over momentum, variance and auxiliary compensation
+# buffer dtypes.
+# Optional Kahan summation is used to offset precision reduction for
+# the weight updates. This allows full training in BFloat16 (equal or
+# better than FP32 results in many cases) due to high precision weight updates.
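+#
+# Illustrative usage (a sketch only; the argument values are examples and are not taken
+# from any training script in this repo):
+#
+#   optimizer = AnyPrecisionAdamW(
+#       model.parameters(),
+#       lr=1e-4,
+#       weight_decay=0.0,
+#       momentum_dtype=torch.bfloat16,
+#       variance_dtype=torch.bfloat16,
+#       use_kahan_summation=True,
+#   )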
+
+import torch
+from torch.optim.optimizer import Optimizer
+
+
+class AnyPrecisionAdamW(Optimizer):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=0.0,
+        use_kahan_summation=False,
+        momentum_dtype=torch.bfloat16,
+        variance_dtype=torch.bfloat16,
+        compensation_buffer_dtype=torch.bfloat16,
+    ):
+        """
+        Args:
+            params (iterable): iterable of parameters to optimize or dicts defining
+                parameter groups
+            lr (float, optional): learning rate (default: 1e-3)
+            betas (Tuple[float, float], optional): coefficients used for computing
+                running averages of gradient and its square (default: (0.9, 0.999))
+            eps (float, optional): term added to the denominator to improve
+                numerical stability (default: 1e-8)
+            weight_decay (float, optional): weight decay coefficient (default: 0.0)
+
+            # Any Precision specific
+            use_kahan_summation = creates auxiliary buffer to ensure high precision
+                model param updates (default: False)
+            momentum_dtype = dtype for momentum (default: torch.bfloat16)
+            variance_dtype = dtype for uncentered variance (default: torch.bfloat16)
+            compensation_buffer_dtype = dtype for Kahan summation
+                buffer (default: torch.bfloat16)
+
+            # Usage
+            This optimizer keeps the optimizer states, and optionally a Kahan summation
+            buffer for high precision updates, in user controlled dtypes.
+            The defaults keep both momentum and variance in BF16.
+            This can be run in FSDP mixed precision, amp, or full precision,
+            depending on what training pipeline you wish to work with.
+
+            Setting use_kahan_summation = False and changing the momentum and
+            variance dtypes to FP32 reverts this to a standard AdamW optimizer.
+
+        """
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            use_kahan_summation=use_kahan_summation,
+            momentum_dtype=momentum_dtype,
+            variance_dtype=variance_dtype,
+            compensation_buffer_dtype=compensation_buffer_dtype,
+        )
+
+        super().__init__(params, defaults)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+        Args:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+
+        if closure is not None:
+            with torch.enable_grad():
+                # to fix linter, we do not keep the returned loss for use atm.
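+                # (the closure is still executed for its side effects; only its return value is dropped)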
+ closure() + + for group in self.param_groups: + + beta1, beta2 = group["betas"] + lr = group["lr"] + weight_decay = group["weight_decay"] + eps = group["eps"] + use_kahan_summation = group["use_kahan_summation"] + + momentum_dtype = group["momentum_dtype"] + variance_dtype = group["variance_dtype"] + compensation_buffer_dtype = group["compensation_buffer_dtype"] + + for p in group["params"]: + if p.grad is None: + continue + + if p.grad.is_sparse: + raise RuntimeError( + "AnyPrecisionAdamW does not support sparse gradients" + ) + + state = self.state[p] + + # State initialization + if len(state) == 0: + + state["step"] = torch.tensor(0.0) + + # momentum - EMA of gradient values + state["exp_avg"] = torch.zeros_like( + p, + dtype=momentum_dtype, + ) + + # variance uncentered - EMA of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, + dtype=variance_dtype, + ) + + # optional Kahan summation - accumulated error tracker + if use_kahan_summation: + state["compensation"] = torch.zeros_like( + p, + dtype=compensation_buffer_dtype, + ) + + # main processing ------------------------- + + # update the steps for each param group update + state["step"] += 1 + step = state["step"] + + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + + grad = p.grad + + # weight decay, AdamW style + if weight_decay: + p.data.mul_(1 - lr * weight_decay) + + # update momentum + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + # update uncentered variance + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # adjust using bias1 + bias_correction1 = 1 - beta1**step + + step_size = lr / bias_correction1 + + # adjust using bias2 + denom_correction = (1 - beta2**step) ** 0.5 # avoids math import + + centered_variance = (exp_avg_sq.sqrt() / denom_correction).add_( + eps, alpha=1 + ) + + # lr update to compensation + if use_kahan_summation: + compensation = state["compensation"] + + compensation.addcdiv_(exp_avg, centered_variance, value=-step_size) + + # update weights with compensation (Kahan summation) + # save error back to compensation for next iteration + temp_buffer = p.detach().clone() + p.data.add_(compensation) + compensation.add_(temp_buffer.sub_(p.data)) + + else: + # usual AdamW updates + p.data.addcdiv_(exp_avg, centered_variance, value=-step_size) \ No newline at end of file diff --git a/minigpt4/models/policies/fsdp_utils.py b/minigpt4/models/policies/fsdp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e7ed13d2a3f7614ee12e03ff585d0ac91d17a824 --- /dev/null +++ b/minigpt4/models/policies/fsdp_utils.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
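+#
+# fsdp_auto_wrap_policy combines two FSDP wrapping rules: a lambda policy that turns any leaf
+# module with a trainable weight (e.g. PEFT adapter layers) into its own FSDP unit, and a
+# transformer policy keyed on the PEFT prompt/prefix encoders plus the supplied
+# transformer_layer_name class.
+#
+# Illustrative call (a sketch; `model` and the chosen layer class are assumptions for the example):
+#
+#   from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+#   wrap_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
+#   sharded_model = FSDP(model, auto_wrap_policy=wrap_policy)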
+ +def fsdp_auto_wrap_policy(model, transformer_layer_name): + import functools + import os + + from accelerate import FullyShardedDataParallelPlugin + from transformers.models.t5.modeling_t5 import T5Block + from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy + + from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder + + def lambda_policy_fn(module): + if ( + len(list(module.named_children())) == 0 + and getattr(module, "weight", None) is not None + and module.weight.requires_grad + ): + return True + return False + + lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) + transformer_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls=( + PrefixEncoder, + PromptEncoder, + PromptEmbedding, + transformer_layer_name, + # FullyShardedDataParallelPlugin.get_module_class_from_name( + # model, transformer_layer_name + # ), + ), + ) + + auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy]) + return auto_wrap_policy \ No newline at end of file diff --git a/minigpt4/models/policies/mixed_precision.py b/minigpt4/models/policies/mixed_precision.py new file mode 100644 index 0000000000000000000000000000000000000000..410ee392edf846da59318bdc80fdd9ab3951cf0f --- /dev/null +++ b/minigpt4/models/policies/mixed_precision.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import torch + +from torch.distributed.fsdp import ( + # FullyShardedDataParallel as FSDP, + # CPUOffload, + MixedPrecision, + # BackwardPrefetch, + # ShardingStrategy, +) + +# requires grad scaler in main loop +fpSixteen = MixedPrecision( + param_dtype=torch.float16, + # Gradient communication precision. + reduce_dtype=torch.float16, + # Buffer precision. + buffer_dtype=torch.float16, +) + +bfSixteen = MixedPrecision( + param_dtype=torch.bfloat16, + # Gradient communication precision. + reduce_dtype=torch.bfloat16, + # Buffer precision. + buffer_dtype=torch.bfloat16, + cast_forward_inputs=True, +) + +bfSixteen_mixed = MixedPrecision( + param_dtype=torch.float32, + reduce_dtype=torch.bfloat16, + buffer_dtype=torch.bfloat16, +) + +fp32_policy = MixedPrecision( + param_dtype=torch.float32, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32, +) diff --git a/minigpt4/models/policies/wrapping.py b/minigpt4/models/policies/wrapping.py new file mode 100644 index 0000000000000000000000000000000000000000..d9fadc3347add4974ab57b858288c489e23463d3 --- /dev/null +++ b/minigpt4/models/policies/wrapping.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
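+#
+# Two ready-made FSDP auto-wrap policies: get_size_policy() wraps modules once they exceed a
+# parameter-count threshold (size_based_auto_wrap_policy), while get_llama_wrapper() wraps at
+# the LlamaDecoderLayer boundary via transformer_auto_wrap_policy.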
+ +import torch.distributed as dist +import torch.nn as nn +import torch + +from transformers.models.llama.modeling_llama import LlamaDecoderLayer + +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullyShardedDataParallel as FSDP, + CPUOffload, + BackwardPrefetch, + MixedPrecision, +) +from torch.distributed.fsdp.wrap import ( + transformer_auto_wrap_policy, + size_based_auto_wrap_policy, + enable_wrap, + wrap, +) + +import functools +from typing import Type + + +def get_size_policy(min_params=1e8): + num_wrap_policy = functools.partial( + size_based_auto_wrap_policy, min_num_params=min_params + ) + return num_wrap_policy + + +def get_llama_wrapper(): + """we register our main layer class and use the fsdp transformer wrapping policy + ensures embedding layers are in the root fsdp unit for shared access and that fsdp units map to transformer layers + """ + # ==== use new transformer wrapper + + llama_auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls={ + LlamaDecoderLayer, + }, + ) + + return llama_auto_wrap_policy diff --git a/minigpt4/processors/__init__.py b/minigpt4/processors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e560eaa15f3266dbc1ffbca70bdc791901737a60 --- /dev/null +++ b/minigpt4/processors/__init__.py @@ -0,0 +1,33 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from minigpt4.processors.base_processor import BaseProcessor +from minigpt4.processors.blip_processors import ( + Blip2ImageTrainProcessor, + Blip2ImageEvalProcessor, + BlipCaptionProcessor, +) + +from minigpt4.common.registry import registry + +__all__ = [ + "BaseProcessor", + "Blip2ImageTrainProcessor", + "Blip2ImageEvalProcessor", + "BlipCaptionProcessor", +] + + +def load_processor(name, cfg=None): + """ + Example + + >>> processor = load_processor("alpro_video_train", cfg=None) + """ + processor = registry.get_processor_class(name).from_config(cfg) + + return processor diff --git a/minigpt4/processors/base_processor.py b/minigpt4/processors/base_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..39b33cdf8fcd97cfd3e4a5fbece6593357af9d41 --- /dev/null +++ b/minigpt4/processors/base_processor.py @@ -0,0 +1,26 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from omegaconf import OmegaConf + + +class BaseProcessor: + def __init__(self): + self.transform = lambda x: x + return + + def __call__(self, item): + return self.transform(item) + + @classmethod + def from_config(cls, cfg=None): + return cls() + + def build(self, **kwargs): + cfg = OmegaConf.create(kwargs) + + return self.from_config(cfg) diff --git a/minigpt4/processors/blip_processors.py b/minigpt4/processors/blip_processors.py new file mode 100644 index 0000000000000000000000000000000000000000..c633ed3408d05414072375cc951f7d72f840dd28 --- /dev/null +++ b/minigpt4/processors/blip_processors.py @@ -0,0 +1,164 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. 
+ SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import re + +from minigpt4.common.registry import registry +from minigpt4.processors.base_processor import BaseProcessor +from minigpt4.processors.randaugment import RandomAugment +from omegaconf import OmegaConf +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode + + +class BlipImageBaseProcessor(BaseProcessor): + def __init__(self, mean=None, std=None): + if mean is None: + mean = (0.48145466, 0.4578275, 0.40821073) + if std is None: + std = (0.26862954, 0.26130258, 0.27577711) + + + segment_mean = (0.485, 0.456, 0.406) + segment_std = (0.229, 0.224, 0.225) + + self.normalize = transforms.Normalize(segment_mean, segment_std) + + +@registry.register_processor("blip_caption") +class BlipCaptionProcessor(BaseProcessor): + def __init__(self, prompt="", max_words=50): + self.prompt = prompt + self.max_words = max_words + + def __call__(self, caption): + caption = self.prompt + self.pre_caption(caption) + + return caption + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + prompt = cfg.get("prompt", "") + max_words = cfg.get("max_words", 50) + + return cls(prompt=prompt, max_words=max_words) + + def pre_caption(self, caption): + caption = re.sub( + r"([.!\"()*#:;~])", + " ", + caption.lower(), + ) + caption = re.sub( + r"\s{2,}", + " ", + caption, + ) + caption = caption.rstrip("\n") + caption = caption.strip(" ") + + # truncate caption + caption_words = caption.split(" ") + if len(caption_words) > self.max_words: + caption = " ".join(caption_words[: self.max_words]) + + return caption + + +@registry.register_processor("blip2_image_train") +class Blip2ImageTrainProcessor(BlipImageBaseProcessor): + def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0): + super().__init__(mean=mean, std=std) + + # self.transform = transforms.Compose( + # [ + # transforms.RandomResizedCrop( + # image_size, + # scale=(min_scale, max_scale), + # interpolation=InterpolationMode.BICUBIC, + # ), + # transforms.ToTensor(), + # self.normalize, + # ] + # ) + self.transform = transforms.Compose([ + transforms.Resize( + (image_size, image_size), interpolation=InterpolationMode.BICUBIC + ), + transforms.ToTensor(), + self.normalize, + ] + ) + + # ### segment anything + # ''' + # x = (x - self.pixel_mean) / self.pixel_std + + # # Pad + # h, w = x.shape[-2:] + # padh = self.image_encoder.img_size - h + # padw = self.image_encoder.img_size - w + # x = F.pad(x, (0, padw, 0, padh)) + # ''' + + def __call__(self, item): + return self.transform(item) + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + image_size = cfg.get("image_size", 224) + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + min_scale = cfg.get("min_scale", 0.5) + max_scale = cfg.get("max_scale", 1.0) + + return cls( + image_size=image_size, + mean=mean, + std=std, + min_scale=min_scale, + max_scale=max_scale, + ) + + +@registry.register_processor("blip2_image_eval") +class Blip2ImageEvalProcessor(BlipImageBaseProcessor): + def __init__(self, image_size=224, mean=None, std=None): + super().__init__(mean=mean, std=std) + + self.transform = transforms.Compose( + [ + transforms.Resize( + (image_size, image_size), interpolation=InterpolationMode.BICUBIC + ), + transforms.ToTensor(), + self.normalize, + ] + ) + + def 
__call__(self, item): + return self.transform(item) + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + image_size = cfg.get("image_size", 224) + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + return cls(image_size=image_size, mean=mean, std=std) \ No newline at end of file diff --git a/minigpt4/processors/randaugment.py b/minigpt4/processors/randaugment.py new file mode 100644 index 0000000000000000000000000000000000000000..7034a49ad5fc63b97910790017432617ff4c6d7b --- /dev/null +++ b/minigpt4/processors/randaugment.py @@ -0,0 +1,398 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import cv2 +import numpy as np + +import torch + + +## aug functions +def identity_func(img): + return img + + +def autocontrast_func(img, cutoff=0): + """ + same output as PIL.ImageOps.autocontrast + """ + n_bins = 256 + + def tune_channel(ch): + n = ch.size + cut = cutoff * n // 100 + if cut == 0: + high, low = ch.max(), ch.min() + else: + hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) + low = np.argwhere(np.cumsum(hist) > cut) + low = 0 if low.shape[0] == 0 else low[0] + high = np.argwhere(np.cumsum(hist[::-1]) > cut) + high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] + if high <= low: + table = np.arange(n_bins) + else: + scale = (n_bins - 1) / (high - low) + offset = -low * scale + table = np.arange(n_bins) * scale + offset + table[table < 0] = 0 + table[table > n_bins - 1] = n_bins - 1 + table = table.clip(0, 255).astype(np.uint8) + return table[ch] + + channels = [tune_channel(ch) for ch in cv2.split(img)] + out = cv2.merge(channels) + return out + + +def equalize_func(img): + """ + same output as PIL.ImageOps.equalize + PIL's implementation is different from cv2.equalize + """ + n_bins = 256 + + def tune_channel(ch): + hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) + non_zero_hist = hist[hist != 0].reshape(-1) + step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) + if step == 0: + return ch + n = np.empty_like(hist) + n[0] = step // 2 + n[1:] = hist[:-1] + table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) + return table[ch] + + channels = [tune_channel(ch) for ch in cv2.split(img)] + out = cv2.merge(channels) + return out + + +def rotate_func(img, degree, fill=(0, 0, 0)): + """ + like PIL, rotate by degree, not radians + """ + H, W = img.shape[0], img.shape[1] + center = W / 2, H / 2 + M = cv2.getRotationMatrix2D(center, degree, 1) + out = cv2.warpAffine(img, M, (W, H), borderValue=fill) + return out + + +def solarize_func(img, thresh=128): + """ + same output as PIL.ImageOps.posterize + """ + table = np.array([el if el < thresh else 255 - el for el in range(256)]) + table = table.clip(0, 255).astype(np.uint8) + out = table[img] + return out + + +def color_func(img, factor): + """ + same output as PIL.ImageEnhance.Color + """ + ## implementation according to PIL definition, quite slow + # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] + # out = blend(degenerate, img, factor) + # M = ( + # np.eye(3) * factor + # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. 
- factor) + # )[np.newaxis, np.newaxis, :] + M = np.float32( + [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]] + ) * factor + np.float32([[0.114], [0.587], [0.299]]) + out = np.matmul(img, M).clip(0, 255).astype(np.uint8) + return out + + +def contrast_func(img, factor): + """ + same output as PIL.ImageEnhance.Contrast + """ + mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) + table = ( + np.array([(el - mean) * factor + mean for el in range(256)]) + .clip(0, 255) + .astype(np.uint8) + ) + out = table[img] + return out + + +def brightness_func(img, factor): + """ + same output as PIL.ImageEnhance.Contrast + """ + table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) + out = table[img] + return out + + +def sharpness_func(img, factor): + """ + The differences the this result and PIL are all on the 4 boundaries, the center + areas are same + """ + kernel = np.ones((3, 3), dtype=np.float32) + kernel[1][1] = 5 + kernel /= 13 + degenerate = cv2.filter2D(img, -1, kernel) + if factor == 0.0: + out = degenerate + elif factor == 1.0: + out = img + else: + out = img.astype(np.float32) + degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] + out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) + out = out.astype(np.uint8) + return out + + +def shear_x_func(img, factor, fill=(0, 0, 0)): + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, factor, 0], [0, 1, 0]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR + ).astype(np.uint8) + return out + + +def translate_x_func(img, offset, fill=(0, 0, 0)): + """ + same output as PIL.Image.transform + """ + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, -offset], [0, 1, 0]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR + ).astype(np.uint8) + return out + + +def translate_y_func(img, offset, fill=(0, 0, 0)): + """ + same output as PIL.Image.transform + """ + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, 0], [0, 1, -offset]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR + ).astype(np.uint8) + return out + + +def posterize_func(img, bits): + """ + same output as PIL.ImageOps.posterize + """ + out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) + return out + + +def shear_y_func(img, factor, fill=(0, 0, 0)): + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, 0], [factor, 1, 0]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR + ).astype(np.uint8) + return out + + +def cutout_func(img, pad_size, replace=(0, 0, 0)): + replace = np.array(replace, dtype=np.uint8) + H, W = img.shape[0], img.shape[1] + rh, rw = np.random.random(2) + pad_size = pad_size // 2 + ch, cw = int(rh * H), int(rw * W) + x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) + y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) + out = img.copy() + out[x1:x2, y1:y2, :] = replace + return out + + +### level to args +def enhance_level_to_args(MAX_LEVEL): + def level_to_args(level): + return ((level / MAX_LEVEL) * 1.8 + 0.1,) + + return level_to_args + + +def shear_level_to_args(MAX_LEVEL, replace_value): + def level_to_args(level): + level = (level / MAX_LEVEL) * 0.3 + if np.random.random() > 0.5: + level = -level + return (level, replace_value) + + return level_to_args + + +def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): + def level_to_args(level): + level = (level / MAX_LEVEL) * 
float(translate_const) + if np.random.random() > 0.5: + level = -level + return (level, replace_value) + + return level_to_args + + +def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): + def level_to_args(level): + level = int((level / MAX_LEVEL) * cutout_const) + return (level, replace_value) + + return level_to_args + + +def solarize_level_to_args(MAX_LEVEL): + def level_to_args(level): + level = int((level / MAX_LEVEL) * 256) + return (level,) + + return level_to_args + + +def none_level_to_args(level): + return () + + +def posterize_level_to_args(MAX_LEVEL): + def level_to_args(level): + level = int((level / MAX_LEVEL) * 4) + return (level,) + + return level_to_args + + +def rotate_level_to_args(MAX_LEVEL, replace_value): + def level_to_args(level): + level = (level / MAX_LEVEL) * 30 + if np.random.random() < 0.5: + level = -level + return (level, replace_value) + + return level_to_args + + +func_dict = { + "Identity": identity_func, + "AutoContrast": autocontrast_func, + "Equalize": equalize_func, + "Rotate": rotate_func, + "Solarize": solarize_func, + "Color": color_func, + "Contrast": contrast_func, + "Brightness": brightness_func, + "Sharpness": sharpness_func, + "ShearX": shear_x_func, + "TranslateX": translate_x_func, + "TranslateY": translate_y_func, + "Posterize": posterize_func, + "ShearY": shear_y_func, +} + +translate_const = 10 +MAX_LEVEL = 10 +replace_value = (128, 128, 128) +arg_dict = { + "Identity": none_level_to_args, + "AutoContrast": none_level_to_args, + "Equalize": none_level_to_args, + "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value), + "Solarize": solarize_level_to_args(MAX_LEVEL), + "Color": enhance_level_to_args(MAX_LEVEL), + "Contrast": enhance_level_to_args(MAX_LEVEL), + "Brightness": enhance_level_to_args(MAX_LEVEL), + "Sharpness": enhance_level_to_args(MAX_LEVEL), + "ShearX": shear_level_to_args(MAX_LEVEL, replace_value), + "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), + "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), + "Posterize": posterize_level_to_args(MAX_LEVEL), + "ShearY": shear_level_to_args(MAX_LEVEL, replace_value), +} + + +class RandomAugment(object): + def __init__(self, N=2, M=10, isPIL=False, augs=[]): + self.N = N + self.M = M + self.isPIL = isPIL + if augs: + self.augs = augs + else: + self.augs = list(arg_dict.keys()) + + def get_random_ops(self): + sampled_ops = np.random.choice(self.augs, self.N) + return [(op, 0.5, self.M) for op in sampled_ops] + + def __call__(self, img): + if self.isPIL: + img = np.array(img) + ops = self.get_random_ops() + for name, prob, level in ops: + if np.random.random() > prob: + continue + args = arg_dict[name](level) + img = func_dict[name](img, *args) + return img + + +class VideoRandomAugment(object): + def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]): + self.N = N + self.M = M + self.p = p + self.tensor_in_tensor_out = tensor_in_tensor_out + if augs: + self.augs = augs + else: + self.augs = list(arg_dict.keys()) + + def get_random_ops(self): + sampled_ops = np.random.choice(self.augs, self.N, replace=False) + return [(op, self.M) for op in sampled_ops] + + def __call__(self, frames): + assert ( + frames.shape[-1] == 3 + ), "Expecting last dimension for 3-channels RGB (b, h, w, c)." 
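For reference, a minimal usage sketch of RandomAugment as defined above: each call samples N ops from the op table and applies each with probability 0.5 at magnitude M. The op list, N/M values, and the synthetic input image below are illustrative, not taken from this repo's configs:

```python
# Minimal, hypothetical usage sketch for RandomAugment; values are illustrative.
import numpy as np
from PIL import Image
from torchvision import transforms

from minigpt4.processors.randaugment import RandomAugment

augment = RandomAugment(
    N=2, M=5, isPIL=True,  # sample 2 ops per image at magnitude 5, accept PIL input
    augs=["Identity", "AutoContrast", "Equalize", "Brightness", "Sharpness"],
)

transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.5, 1.0)),
    transforms.RandomHorizontalFlip(),
    augment,                # returns a uint8 HxWxC numpy array
    transforms.ToTensor(),  # accepts the numpy array and yields a CxHxW float tensor
])

img = Image.fromarray(np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8))
print(transform(img).shape)  # torch.Size([3, 224, 224])
```

VideoRandomAugment reuses the same op table, but operates frame-wise on a (b, h, w, c) uint8 video tensor and applies each sampled op with probability 1 - p.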
+ + if self.tensor_in_tensor_out: + frames = frames.numpy().astype(np.uint8) + + num_frames = frames.shape[0] + + ops = num_frames * [self.get_random_ops()] + apply_or_not = num_frames * [np.random.random(size=self.N) > self.p] + + frames = torch.stack( + list(map(self._aug, frames, ops, apply_or_not)), dim=0 + ).float() + + return frames + + def _aug(self, img, ops, apply_or_not): + for i, (name, level) in enumerate(ops): + if not apply_or_not[i]: + continue + args = arg_dict[name](level) + img = func_dict[name](img, *args) + return torch.from_numpy(img) + + +if __name__ == "__main__": + a = RandomAugment() + img = np.random.randn(32, 32, 3) + a(img) diff --git a/minigpt4/runners/__init__.py b/minigpt4/runners/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..64e7a4d643a8b5a1714687f42d43347a94b72373 --- /dev/null +++ b/minigpt4/runners/__init__.py @@ -0,0 +1,10 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from minigpt4.runners.runner_base import RunnerBase + +__all__ = ["RunnerBase"] diff --git a/minigpt4/runners/runner_base.py b/minigpt4/runners/runner_base.py new file mode 100644 index 0000000000000000000000000000000000000000..0be9020aa351f94d67ff5069f3896b584d3149e5 --- /dev/null +++ b/minigpt4/runners/runner_base.py @@ -0,0 +1,724 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import datetime +import json +import logging +import os +import time +from pathlib import Path + +import torch +import torch.distributed as dist +import webdataset as wds +import wandb +from minigpt4.common.dist_utils import ( + download_cached_file, + get_rank, + get_world_size, + is_main_process, + main_process, +) +from minigpt4.common.registry import registry +from minigpt4.common.utils import is_url +from minigpt4.datasets.data_utils import concat_datasets, reorg_datasets_by_split, ChainDataset +from minigpt4.datasets.datasets.dataloader_utils import ( + IterLoader, + MultiIterLoader, + PrefetchLoader, +) +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader, DistributedSampler +from minigpt4.processors.blip_processors import Blip2ImageTrainProcessor,BlipCaptionProcessor +from minigpt4.datasets.datasets.video_datasets import Video_validation_Dataset +from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser +from minigpt4.conversation.conversation import CONV_VISION +from tqdm import tqdm +from omegaconf import OmegaConf + +@registry.register_runner("runner_base") +class RunnerBase: + """ + A runner class to train and evaluate a model given a task and datasets. + + The runner uses pytorch distributed data parallel by default. Future release + will support other distributed frameworks. 
+ """ + + def __init__(self, cfg, task, model, datasets, job_id): + self.config = cfg + self.job_id = job_id + + self.task = task + self.datasets = datasets + + self._model = model + + self._wrapped_model = None + self._device = None + self._optimizer = None + self._scaler = None + self._dataloaders = None + self._lr_sched = None + + self.start_epoch = 0 + + # self.setup_seeds() + self.setup_output_dir() + + @property + def device(self): + if self._device is None: + self._device = torch.device(self.config.run_cfg.device) + + return self._device + + @property + def use_distributed(self): + return self.config.run_cfg.distributed + + @property + def model(self): + """ + A property to get the DDP-wrapped model on the device. + """ + # move model to device + # print("self device",self.device) + # print("self model device",self._model.device) + + # print(self._model.device, self.device) + + if self._model.device != self.device: + self._model = self._model.to(self.device) + + # distributed training wrapper + if self.use_distributed: + if self._wrapped_model is None: + self._wrapped_model = DDP( + self._model, device_ids=[self.config.run_cfg.gpu],find_unused_parameters=False + ) + # + else: + self._wrapped_model = self._model + + return self._wrapped_model + + @property + def optimizer(self): + # TODO make optimizer class and configurations + if self._optimizer is None: + num_parameters = 0 + p_wd, p_non_wd = [], [] + for n, p in self.model.named_parameters(): + if not p.requires_grad: + continue # frozen weights + print(n) + if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n: + p_non_wd.append(p) + else: + p_wd.append(p) + num_parameters += p.data.nelement() + logging.info("number of trainable parameters: %d" % num_parameters) + optim_params = [ + { + "params": p_wd, + "weight_decay": float(self.config.run_cfg.weight_decay), + }, + {"params": p_non_wd, "weight_decay": 0}, + ] + beta2 = self.config.run_cfg.get("beta2", 0.999) + self._optimizer = torch.optim.AdamW( + optim_params, + lr=float(self.config.run_cfg.init_lr), + weight_decay=float(self.config.run_cfg.weight_decay), + betas=(0.9, beta2), + ) + + return self._optimizer + + @property + def scaler(self): + amp = self.config.run_cfg.get("amp", False) + # print("amp", amp) + # assert False + + + if amp: + if self._scaler is None: + self._scaler = torch.cuda.amp.GradScaler() + + return self._scaler + + @property + def lr_scheduler(self): + """ + A property to get and create learning rate scheduler by split just in need. 
+ """ + if self._lr_sched is None: + lr_sched_cls = registry.get_lr_scheduler_class(self.config.run_cfg.lr_sched) + + # max_epoch = self.config.run_cfg.max_epoch + max_epoch = self.max_epoch + # min_lr = self.config.run_cfg.min_lr + min_lr = self.min_lr + # init_lr = self.config.run_cfg.init_lr + init_lr = self.init_lr + + # optional parameters + decay_rate = self.config.run_cfg.get("lr_decay_rate", None) + warmup_start_lr = self.config.run_cfg.get("warmup_lr", -1) + warmup_steps = self.config.run_cfg.get("warmup_steps", 0) + iters_per_epoch = self.config.run_cfg.get("iters_per_epoch", None) + + if iters_per_epoch is None: + try: + iters_per_epoch = len(self.dataloaders['train']) + except (AttributeError, TypeError): + iters_per_epoch = 10000 + + self._lr_sched = lr_sched_cls( + optimizer=self.optimizer, + max_epoch=max_epoch, + iters_per_epoch=iters_per_epoch, + min_lr=min_lr, + init_lr=init_lr, + decay_rate=decay_rate, + warmup_start_lr=warmup_start_lr, + warmup_steps=warmup_steps, + ) + + return self._lr_sched + + @property + def dataloaders(self) -> dict: + """ + A property to get and create dataloaders by split just in need. + + If no train_dataset_ratio is provided, concatenate map-style datasets and + chain wds.DataPipe datasets separately. Training set becomes a tuple + (ConcatDataset, ChainDataset), both are optional but at least one of them is + required. The resultant ConcatDataset and ChainDataset will be sampled evenly. + + If train_dataset_ratio is provided, create a MultiIterLoader to sample + each dataset by ratios during training. + + Currently do not support multiple datasets for validation and test. + + Returns: + dict: {split_name: (tuples of) dataloader} + """ + if self._dataloaders is None: + + # concatenate map-style datasets and chain wds.DataPipe datasets separately + # training set becomes a tuple (ConcatDataset, ChainDataset), both are + # optional but at least one of them is required. The resultant ConcatDataset + # and ChainDataset will be sampled evenly. + logging.info( + "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)." + ) + + batch_sizes = {dataset_name: getattr(self.config.datasets_cfg, dataset_name).batch_size + for dataset_name in self.datasets.keys()} + datasets, batch_sizes = reorg_datasets_by_split(self.datasets, batch_sizes) + self.datasets = datasets + # self.datasets = concat_datasets(datasets) + + # print dataset statistics after concatenation/chaining + for split_name in self.datasets: + if isinstance(self.datasets[split_name], tuple) or isinstance( + self.datasets[split_name], list + ): + # mixed wds.DataPipeline and torch.utils.data.Dataset + num_records = sum( + [ + len(d) + if not type(d) in [wds.DataPipeline, ChainDataset] + else 0 + for d in self.datasets[split_name] + ] + ) + + else: + if hasattr(self.datasets[split_name], "__len__"): + # a single map-style dataset + num_records = len(self.datasets[split_name]) + else: + # a single wds.DataPipeline + num_records = -1 + logging.info( + "Only a single wds.DataPipeline dataset, no __len__ attribute." 
+ ) + + if num_records >= 0: + logging.info( + "Loaded {} records for {} split from the dataset.".format( + num_records, split_name + ) + ) + + # create dataloaders + split_names = sorted(self.datasets.keys()) + + datasets = [self.datasets[split] for split in split_names] + batch_sizes = [batch_sizes[split] for split in split_names] + is_trains = [split in self.train_splits for split in split_names] + + # batch_sizes = [ + # self.config.run_cfg.batch_size_train + # if split == "train" + # else self.config.run_cfg.batch_size_eval + # for index, split in enumerate(split_names) + # ] + + # print(split_names) + print("batch sizes", batch_sizes) + + collate_fns = [] + for dataset in datasets: + if isinstance(dataset, tuple) or isinstance(dataset, list): + collate_fns.append([getattr(d, "collater", None) for d in dataset]) + else: + collate_fns.append(getattr(dataset, "collater", None)) + + dataloaders = self.create_loaders( + datasets=datasets, + num_workers=self.config.run_cfg.num_workers, + batch_sizes=batch_sizes, + is_trains=is_trains, + collate_fns=collate_fns, + ) + + self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)} + + return self._dataloaders + + @property + def cuda_enabled(self): + return self.device.type == "cuda" + + @property + def max_epoch(self): + return int(self.config.run_cfg.max_epoch) + + @property + def log_freq(self): + log_freq = self.config.run_cfg.get("log_freq", 50) + return int(log_freq) + + @property + def init_lr(self): + return float(self.config.run_cfg.init_lr) + + @property + def min_lr(self): + return float(self.config.run_cfg.min_lr) + + @property + def accum_grad_iters(self): + return int(self.config.run_cfg.get("accum_grad_iters", 1)) + + @property + def valid_splits(self): + valid_splits = self.config.run_cfg.get("valid_splits", []) + + if len(valid_splits) == 0: + logging.info("No validation splits found.") + + return valid_splits + + @property + def test_splits(self): + test_splits = self.config.run_cfg.get("test_splits", []) + + return test_splits + + @property + def train_splits(self): + train_splits = self.config.run_cfg.get("train_splits", []) + + if len(train_splits) == 0: + logging.info("Empty train splits.") + + return train_splits + + @property + def evaluate_only(self): + """ + Set to True to skip training. 
+ """ + return self.config.run_cfg.evaluate + + @property + def use_dist_eval_sampler(self): + return self.config.run_cfg.get("use_dist_eval_sampler", True) + + @property + def resume_ckpt_path(self): + return self.config.run_cfg.get("resume_ckpt_path", None) + + @property + def train_loader(self): + train_dataloader = self.dataloaders["train"] + + return train_dataloader + + def setup_output_dir(self): + lib_root = Path(registry.get_path("library_root")) + + output_dir = lib_root / self.config.run_cfg.output_dir / self.job_id + # output_dir = lib_root / self.config.run_cfg.output_dir + result_dir = output_dir / "result" + + output_dir.mkdir(parents=True, exist_ok=True) + result_dir.mkdir(parents=True, exist_ok=True) + + registry.register_path("result_dir", str(result_dir)) + registry.register_path("output_dir", str(output_dir)) + + self.result_dir = result_dir + self.output_dir = output_dir + + def train(self): + start_time = time.time() + best_agg_metric = 0 + best_epoch = 0 + + self.log_config() + + # resume from checkpoint if specified + if not self.evaluate_only and self.resume_ckpt_path is not None: + self._load_checkpoint(self.resume_ckpt_path) + + for cur_epoch in range(self.start_epoch, self.max_epoch): + # training phase + if not self.evaluate_only: + logging.info("Start training") + train_stats = self.train_epoch(cur_epoch) + self.log_stats(split_name="train", stats=train_stats) + + # evaluation phase + # if len(self.valid_splits) > 0 and self.config.run_cfg.video_instruction_eval: + # self._save_checkpoint(cur_epoch, is_best=False) + # for split_name in self.valid_splits: + # logging.info("Evaluating on {}.".format(split_name)) + # ## Add validation + # val_log=self.custom_eval_epoch(cur_epoch) + # # val_log = self.eval_epoch( + # # split_name=split_name,cur_epoch=cur_epoch + # # ) + # print("val log",val_log) + # if val_log is not None: + # if is_main_process(): + # assert ( + # "agg_metrics" in val_log + # ), "No agg_metrics found in validation log." + + # agg_metrics = val_log["agg_metrics"] + # if agg_metrics > best_agg_metric and split_name == "val": + # best_epoch, best_agg_metric = cur_epoch, agg_metrics + + # self._save_checkpoint(cur_epoch, is_best=True) + + # val_log.update({"best_epoch": best_epoch}) + # self.log_stats(val_log, split_name) + # wandb.log({"epoch": cur_epoch, "GPT4_Accuracy": val_log['agg_metrics']}) + # print("Validation finished") + + else: + # if no validation split is provided, we just save the checkpoint at the end of each epoch. 
+ if not self.evaluate_only: + self._save_checkpoint(cur_epoch, is_best=False) + + if self.evaluate_only: + break + + if self.config.run_cfg.distributed: + dist.barrier() + + # testing phase + test_epoch = "best" if len(self.valid_splits) > 0 else cur_epoch + self.evaluate(cur_epoch=test_epoch, skip_reload=self.evaluate_only) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logging.info("Training time {}".format(total_time_str)) + + def evaluate(self, cur_epoch="best", skip_reload=False): + test_logs = dict() + + if len(self.test_splits) > 0: + for split_name in self.test_splits: + test_logs[split_name] = self.eval_epoch( + split_name=split_name, cur_epoch=cur_epoch, skip_reload=skip_reload + ) + + return test_logs + + def train_epoch(self, epoch): + # train + self.model.train() + + return self.task.train_epoch( + epoch=epoch, + model=self.model, + data_loader=self.train_loader, + optimizer=self.optimizer, + scaler=self.scaler, + lr_scheduler=self.lr_scheduler, + cuda_enabled=self.cuda_enabled, + log_freq=self.log_freq, + accum_grad_iters=self.accum_grad_iters, + ) + + @torch.no_grad() + def eval_epoch(self, split_name, cur_epoch, skip_reload=False): + """ + Evaluate the model on a given split. + + Args: + split_name (str): name of the split to evaluate on. + cur_epoch (int): current epoch. + skip_reload_best (bool): whether to skip reloading the best checkpoint. + During training, we will reload the best checkpoint for validation. + During testing, we will use provided weights and skip reloading the best checkpoint . + """ + data_loader = self.dataloaders.get(split_name, None) + assert data_loader, "data_loader for split {} is None.".format(split_name) + + # TODO In validation, you need to compute loss as well as metrics + # TODO consider moving to model.before_evaluation() + model = self.unwrap_dist_model(self.model) + if not skip_reload and cur_epoch == "best": + model = self._reload_best_model(model) + model.eval() + + self.task.before_evaluation( + model=model, + dataset=self.datasets[split_name], + ) + results = self.task.evaluation(model, data_loader) + + if results is not None: + return self.task.after_evaluation( + val_result=results, + split_name=split_name, + epoch=cur_epoch, + ) + def get_validation_loader(self): + # TODO make the path configurable + dataset_congif="minigpt4/configs/datasets/video_chatgpt/default.yaml" + # read the dataset config using omegaconf + config = OmegaConf.load(dataset_congif).datasets + config = config[list(config.keys())[0]] + vis_processor=Blip2ImageTrainProcessor() + validation_data = Video_validation_Dataset(vis_processor, + videos_path=config.valid['videos_path'], + ann_path=config.valid['ann_path'], + subtitles_path=config.valid['subtitles_path'], + annotations_keys=config.valid['annotations_keys'], + add_subtitles=config.valid['add_subtitles'],) + validation_dataloader = DataLoader(validation_data, batch_size=1, shuffle=False) + return validation_dataloader + @torch.no_grad() + def custom_eval_epoch(self, cur_epoch): + validation_dataloader=self.get_validation_loader() + model = self.unwrap_dist_model(self.model) + model.eval() + conv_temp = CONV_VISION.copy() + conv_temp.system = "" + results = [] + for images, texts, gt_answers, lengths,videos_ids in tqdm(validation_dataloader): + texts = prepare_texts(texts, conv_temp, template='', lengths=lengths) # warp the texts with conversation template + models_answers = model.generate(images, texts, max_new_tokens=512, do_sample=False, 
lengths=lengths,num_beams=1) + for video_id,model_answer, gt_answer,text in zip(videos_ids,models_answers, gt_answers,texts): + result = dict() + result['video_name'] = video_id + result['Q'] = text.split('\n')[-1].replace('[/INST]','') + result['A'] = gt_answer + result['pred'] = model_answer + results.append(result) + val_log= self.task.after_evaluation( + val_result=results, + epoch=cur_epoch, + ) + return val_log + def unwrap_dist_model(self, model): + if self.use_distributed: + return model.module + else: + return model + + def create_loaders( + self, + datasets, + num_workers, + batch_sizes, + is_trains, + collate_fns, + dataset_ratios=None, + ): + """ + Create dataloaders for training and validation. + """ + + def _create_loader(dataset, num_workers, bsz, is_train, collate_fn): + # create a single dataloader for each split + if isinstance(dataset, ChainDataset) or isinstance( + dataset, wds.DataPipeline + ): + # wds.WebdDataset instance are chained together + # webdataset.DataPipeline has its own sampler and collate_fn + loader = iter( + DataLoader( + dataset, + batch_size=bsz, + num_workers=num_workers, + pin_memory=True, + ) + ) + else: + # map-style dataset are concatenated together + # setup distributed sampler + + if self.use_distributed: + sampler = DistributedSampler( + dataset, + shuffle=is_train, + num_replicas=get_world_size(), + rank=get_rank(), + ) + if not self.use_dist_eval_sampler: + # e.g. retrieval evaluation + sampler = sampler if is_train else None + else: + sampler = None + + loader = DataLoader( + dataset, + batch_size=bsz, + num_workers=num_workers, + pin_memory=True, + sampler=sampler, + shuffle=sampler is None and is_train, + collate_fn=collate_fn, + drop_last=True if is_train else False, + ) + loader = PrefetchLoader(loader) + + if is_train: + loader = IterLoader(loader, use_distributed=self.use_distributed) + + return loader + + loaders = [] + + for dataset, bsz, is_train, collate_fn in zip( + datasets, batch_sizes, is_trains, collate_fns + ): + if isinstance(dataset, list) or isinstance(dataset, tuple): + if hasattr(dataset[0], 'sample_ratio') and dataset_ratios is None: + dataset_ratios = [d.sample_ratio for d in dataset] + loader = MultiIterLoader( + loaders=[ + _create_loader(d, num_workers, bsz[i], is_train, collate_fn[i]) + for i, d in enumerate(dataset) + ], + ratios=dataset_ratios, + ) + else: + loader = _create_loader(dataset, num_workers, bsz, is_train, collate_fn) + + loaders.append(loader) + + return loaders + + @main_process + def _save_checkpoint(self, cur_epoch, is_best=False): + """ + Save the checkpoint at the current epoch. + """ + model_no_ddp = self.unwrap_dist_model(self.model) + param_grad_dic = { + k: v.requires_grad for (k, v) in model_no_ddp.named_parameters() + } + state_dict = model_no_ddp.state_dict() + for k in list(state_dict.keys()): + if k in param_grad_dic.keys() and not param_grad_dic[k]: + # delete parameters that do not require gradient + del state_dict[k] + save_obj = { + "model": state_dict, + "optimizer": self.optimizer.state_dict(), + "config": self.config.to_dict(), + "scaler": self.scaler.state_dict() if self.scaler else None, + "epoch": cur_epoch, + } + save_to = os.path.join( + self.output_dir, + "checkpoint_{}.pth".format("best" if is_best else cur_epoch), + ) + logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to)) + torch.save(save_obj, save_to) + + def _reload_best_model(self, model): + """ + Load the best checkpoint for evaluation. 
+ """ + checkpoint_path = os.path.join(self.output_dir, "checkpoint_best.pth") + + logging.info("Loading checkpoint from {}.".format(checkpoint_path)) + checkpoint = torch.load(checkpoint_path, map_location="cpu") + try: + model.load_state_dict(checkpoint["model"]) + except RuntimeError as e: + logging.warning( + """ + Key mismatch when loading checkpoint. This is expected if only part of the model is saved. + Trying to load the model with strict=False. + """ + ) + model.load_state_dict(checkpoint["model"], strict=False) + return model + + def _load_checkpoint(self, url_or_filename): + """ + Resume from a checkpoint. + """ + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True + ) + checkpoint = torch.load(cached_file, map_location=self.device) + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location=self.device) + else: + raise RuntimeError("checkpoint url or path is invalid") + + state_dict = checkpoint["model"] + message = self.unwrap_dist_model(self.model).load_state_dict(state_dict,strict=False) + + self.optimizer.load_state_dict(checkpoint["optimizer"]) + if self.scaler and "scaler" in checkpoint: + self.scaler.load_state_dict(checkpoint["scaler"]) + + self.start_epoch = checkpoint["epoch"] + 1 + print("resume the checkpoint") + logging.info("Resume checkpoint from {}".format(url_or_filename)) + + @main_process + def log_stats(self, stats, split_name): + if isinstance(stats, dict): + log_stats = {**{f"{split_name}_{k}": v for k, v in stats.items()}} + with open(os.path.join(self.output_dir, "log.txt"), "a") as f: + f.write(json.dumps(log_stats) + "\n") + elif isinstance(stats, list): + pass + + @main_process + def log_config(self): + with open(os.path.join(self.output_dir, "log.txt"), "a") as f: + f.write(json.dumps(self.config.to_dict(), indent=4) + "\n") diff --git a/minigpt4/tasks/__init__.py b/minigpt4/tasks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9f975ab2a59e7ddebc6c1232e29d9de854551d66 --- /dev/null +++ b/minigpt4/tasks/__init__.py @@ -0,0 +1,33 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from minigpt4.common.registry import registry +from minigpt4.tasks.base_task import BaseTask +from minigpt4.tasks.image_text_pretrain import ImageTextPretrainTask + +from minigpt4.tasks.vqa import VQATask, GQATask +from minigpt4.tasks.vqa_reading_comprehension import VQARCTask, GQARCTask + + +def setup_task(cfg): + assert "task" in cfg.run_cfg, "Task name must be provided." + + task_name = cfg.run_cfg.task + task = registry.get_task_class(task_name).setup_task(cfg=cfg) + assert task is not None, "Task {} not properly registered.".format(task_name) + + return task + + +__all__ = [ + "BaseTask", + "ImageTextPretrainTask", + "VQATask", + "GQATask", + "VQARCTask", + "GQARCTask", +] diff --git a/minigpt4/tasks/base_task.py b/minigpt4/tasks/base_task.py new file mode 100644 index 0000000000000000000000000000000000000000..95d0c0dccc67608515a9334c5d21a8a932568d97 --- /dev/null +++ b/minigpt4/tasks/base_task.py @@ -0,0 +1,368 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. 
+ SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import os + +import torch +import torch.distributed as dist +from minigpt4.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized +from minigpt4.common.logger import MetricLogger, SmoothedValue +from minigpt4.common.registry import registry +from minigpt4.datasets.data_utils import prepare_sample + +import wandb +import openai +import ast +openai.api_key_path = "/home/ataallka/chatgpt_api.txt" + +class BaseTask: + def __init__(self, **kwargs): + super().__init__() + + self.inst_id_key = "instance_id" + self.cfg = "" + + + + @classmethod + def setup_task(cls, **kwargs): + + return cls() + + + def build_model(self, cfg): + self.cfg = cfg + model_config = cfg.model_cfg + + model_cls = registry.get_model_class(model_config.arch) + return model_cls.from_config(model_config) + + def build_datasets(self, cfg): + """ + Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'. + Download dataset and annotations automatically if not exist. + + Args: + cfg (common.config.Config): _description_ + + Returns: + dict: Dictionary of torch.utils.data.Dataset objects by split. + """ + + datasets = dict() + + datasets_config = cfg.datasets_cfg + + assert len(datasets_config) > 0, "At least one dataset has to be specified." + + for name in datasets_config: + dataset_config = datasets_config[name] + + builder = registry.get_builder_class(name)(dataset_config) + dataset = builder.build_datasets() + + dataset['train'].name = name + if 'sample_ratio' in dataset_config: + dataset['train'].sample_ratio = dataset_config.sample_ratio + + datasets[name] = dataset + + return datasets + + def train_step(self, model, samples): + loss = model(samples)["loss"] + return loss + + def valid_step(self, model, samples): + answers = model(samples)['answers'] + return answers + + def before_evaluation(self, model, dataset, **kwargs): + model.before_evaluation(dataset=dataset, task_type=type(self)) + def chatgpt_eval(self,question, answer,pred): + try: + # Compute the correctness score + completion = openai.ChatCompletion.create( + # model="gpt-3.5-turbo", + model='gpt-4', + messages=[ + { + "role": "system", + "content": + "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. " + "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:" + "------" + "##INSTRUCTIONS: " + "- Focus on the meaningful match between the predicted answer and the correct answer.\n" + "- Consider synonyms or paraphrases as valid matches.\n" + "- Evaluate the correctness of the prediction compared to the answer." + }, + { + "role": "user", + "content": + "Please evaluate the following video-based question-answer pair:\n\n" + f"Question: {question}\n" + f"Correct Answer: {answer}\n" + f"Predicted Answer: {pred}\n\n" + "Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. " + "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING." + "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. 
Only provide the Python dictionary string. " + "For example, your response should look like this: {'pred': 'yes', 'score': 4.8}." + } + ] + ) + # Convert response to a Python dictionary. + response_message = completion["choices"][0]["message"]["content"] + response_dict = ast.literal_eval(response_message) + return response_dict + except Exception as e: + print(f"Error : {e}") + return None + def after_evaluation(self, val_result,epoch,**kwargs): + scores=[] + yes_count=0 + no_count=0 + for res in val_result: + gpt_response=self.chatgpt_eval(res['Q'],res['A'],res['pred']) + if gpt_response is None: + continue + try: + scores.append(float(gpt_response['score'])) + if 'yes' in gpt_response['pred'].lower(): + yes_count+=1 + elif 'no' in gpt_response['pred'].lower(): + no_count+=1 + except: + continue + avg_score=sum(scores)/len(scores) + accuracy=(yes_count/(yes_count+no_count))*100 + print(f"Epoch {epoch} chatgpt score: {avg_score} accuracy: {accuracy}") + val_accuracy={"agg_metrics":accuracy,"best_epoch":epoch} + # val_accuracy={"agg_metrics":50.2,"best_epoch":epoch} + return val_accuracy + + def inference_step(self): + raise NotImplementedError + + def evaluation(self, model, data_loader, cuda_enabled=True): + metric_logger = MetricLogger(delimiter=" ") + header = "Evaluation" + # TODO make it configurable + print_freq = 10 + results = [] + for samples in metric_logger.log_every(data_loader, print_freq, header): + samples = prepare_sample(samples, cuda_enabled=cuda_enabled) + eval_output = self.valid_step(model=model, samples=samples) + for i,pred in enumerate(eval_output): + res={} + res['video_name'] = samples['image_id'][i] + res['Q'] = samples['instruction_input'][i].split('\n')[-1] + res['A'] = samples['answer'][i] + res['pred'] = pred + results.append(res) + if is_dist_avail_and_initialized(): + dist.barrier() + + return results + + def train_epoch( + self, + epoch, + model, + data_loader, + optimizer, + lr_scheduler, + scaler=None, + cuda_enabled=False, + log_freq=50, + accum_grad_iters=1, + ): + return self._train_inner_loop( + epoch=epoch, + iters_per_epoch=lr_scheduler.iters_per_epoch, + model=model, + data_loader=data_loader, + optimizer=optimizer, + scaler=scaler, + lr_scheduler=lr_scheduler, + log_freq=log_freq, + cuda_enabled=cuda_enabled, + accum_grad_iters=accum_grad_iters, + ) + + def train_iters( + self, + epoch, + start_iters, + iters_per_inner_epoch, + model, + data_loader, + optimizer, + lr_scheduler, + scaler=None, + cuda_enabled=False, + log_freq=50, + accum_grad_iters=1, + ): + return self._train_inner_loop( + epoch=epoch, + start_iters=start_iters, + iters_per_epoch=iters_per_inner_epoch, + model=model, + data_loader=data_loader, + optimizer=optimizer, + scaler=scaler, + lr_scheduler=lr_scheduler, + log_freq=log_freq, + cuda_enabled=cuda_enabled, + accum_grad_iters=accum_grad_iters, + ) + + def _train_inner_loop( + self, + epoch, + iters_per_epoch, + model, + data_loader, + optimizer, + lr_scheduler, + scaler=None, + start_iters=None, + log_freq=50, + cuda_enabled=False, + accum_grad_iters=1, + ): + """ + An inner training loop compatible with both epoch-based and iter-based training. + + When using epoch-based, training stops after one epoch; when using iter-based, + training stops after #iters_per_epoch iterations. 
+ """ + use_amp = scaler is not None + + if not hasattr(data_loader, "__next__"): + # convert to iterator if not already + data_loader = iter(data_loader) + + metric_logger = MetricLogger(delimiter=" ") + metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}")) + metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}")) + + # if iter-based runner, schedule lr based on inner epoch. + logging.info( + "Start training epoch {}, {} iters per inner epoch.".format( + epoch, iters_per_epoch + ) + ) + header = "Train: data epoch: [{}]".format(epoch) + if start_iters is None: + # epoch-based runner + inner_epoch = epoch + else: + # In iter-based runner, we schedule the learning rate based on iterations. + inner_epoch = start_iters // iters_per_epoch + header = header + "; inner epoch [{}]".format(inner_epoch) + + for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header): + # if using iter-based runner, we stop after iters_per_epoch iterations. + if i >= iters_per_epoch: + break + + samples = next(data_loader) + + samples = prepare_sample(samples, cuda_enabled=cuda_enabled) + samples.update( + { + "epoch": inner_epoch, + "num_iters_per_epoch": iters_per_epoch, + "iters": i, + } + ) + + lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i) + + with torch.cuda.amp.autocast(enabled=use_amp): + loss = self.train_step(model=model, samples=samples) + + # after_train_step() + if use_amp: + scaler.scale(loss).backward() + else: + loss.backward() + + # update gradients every accum_grad_iters iterations + if (i + 1) % accum_grad_iters == 0: + if hasattr(model, 'visual_encoder'): + visual_encoder_params = model.visual_encoder.parameters() + else: + visual_encoder_params = model.module.visual_encoder.parameters() + + if use_amp: + scaler.unscale_(optimizer) + # torch.nn.utils.clip_grad_norm_(visual_encoder_params, + # max_norm=0.3) # apply gradient clipping on vit + scaler.step(optimizer) + scaler.update() + else: + # torch.nn.utils.clip_grad_norm_(visual_encoder_params, + # max_norm=0.3) # apply gradient clipping on vit + optimizer.step() + optimizer.zero_grad() + if self.cfg.run_cfg.rank==0: + wandb.log({"epoch": inner_epoch, "loss": loss}) + metric_logger.update(loss=loss.item()) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + + # after train_epoch() + # gather the stats from all processes + metric_logger.synchronize_between_processes() + logging.info("Averaged stats: " + str(metric_logger.global_avg())) + return { + k: "{:.3f}".format(meter.global_avg) + for k, meter in metric_logger.meters.items() + } + + @staticmethod + def save_result(result, result_dir, filename, remove_duplicate=""): + import json + + result_file = os.path.join( + result_dir, "%s_rank%d.json" % (filename, get_rank()) + ) + final_result_file = os.path.join(result_dir, "%s.json" % filename) + + json.dump(result, open(result_file, "w")) + + if is_dist_avail_and_initialized(): + dist.barrier() + + if is_main_process(): + logging.warning("rank %d starts merging results." 
% get_rank()) + # combine results from all processes + result = [] + + for rank in range(get_world_size()): + result_file = os.path.join( + result_dir, "%s_rank%d.json" % (filename, rank) + ) + res = json.load(open(result_file, "r")) + result += res + + if remove_duplicate: + result_new = [] + id_list = [] + for res in result: + if res[remove_duplicate] not in id_list: + id_list.append(res[remove_duplicate]) + result_new.append(res) + result = result_new + + json.dump(result, open(final_result_file, "w")) + print("result file saved to %s" % final_result_file) + + return final_result_file diff --git a/minigpt4/tasks/image_text_pretrain.py b/minigpt4/tasks/image_text_pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..c5cfaf2e9583a1636fe2a9b0249c203b08bf07dd --- /dev/null +++ b/minigpt4/tasks/image_text_pretrain.py @@ -0,0 +1,18 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from minigpt4.common.registry import registry +from minigpt4.tasks.base_task import BaseTask + + +@registry.register_task("image_text_pretrain") +class ImageTextPretrainTask(BaseTask): + def __init__(self): + super().__init__() + + # def evaluation(self, model, data_loader, cuda_enabled=True): + # pass diff --git a/minigpt4/tasks/vqa.py b/minigpt4/tasks/vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..5cc2ed9db0ca7b8673882f987c1f5d8949d0d9fe --- /dev/null +++ b/minigpt4/tasks/vqa.py @@ -0,0 +1,343 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import json +import os + +import minigpt4.common.dist_utils as dist_utils +from minigpt4.common.registry import registry +from minigpt4.common.vqa_tools.vqa import VQA +from minigpt4.common.vqa_tools.vqa_eval import VQAEval +from minigpt4.tasks.base_task import BaseTask + + +@registry.register_task("vqa") +class VQATask(BaseTask): + def __init__( + self, + num_beams, + max_len, + min_len, + evaluate, + num_ans_candidates, + inference_method="rank", + prompt="", + ): + super().__init__() + + self.num_beams = num_beams + self.max_len = max_len + self.min_len = min_len + + self.evaluate = evaluate + self.inference_method = inference_method + self.num_ans_candidates = num_ans_candidates + self.prompt = prompt + + self.answer_list = None + + self.ques_files = dict() + self.anno_files = dict() + + @classmethod + def setup_task(cls, cfg): + run_cfg = cfg.run_cfg + + num_beams = run_cfg.get("num_beams", 3) + max_len = run_cfg.get("max_len", 10) + min_len = run_cfg.get("min_len", 1) + + evaluate = run_cfg.get("evaluate", False) + + inference_method = run_cfg.get("inference_method", "rank") + num_ans_candidates = run_cfg.get("num_ans_candidates", 128) + prompt = run_cfg.get("prompt", "") + + return cls( + num_beams=num_beams, + max_len=max_len, + min_len=min_len, + evaluate=evaluate, + num_ans_candidates=num_ans_candidates, + inference_method=inference_method, + prompt=prompt, + ) + + def build_datasets(self, cfg): + datasets = super().build_datasets(cfg) + + # get question file, annotation file and anwser list in COCO format + for dataset in datasets.values(): + for split in dataset: + if ( + hasattr(dataset[split], "coco_fmt_qust_file") + and 
dataset[split].coco_fmt_qust_file is not None + ): + self.ques_files[split] = dataset[split].coco_fmt_qust_file + self.anno_files[split] = dataset[split].coco_fmt_anno_file + + try: + self.answer_list = dataset[split].answer_list + except AttributeError: + # if answer_list is not provided, then set it to None + pass + + if len(self.ques_files) > 0: + assert len(self.ques_files) == len( + self.anno_files + ), "Only support one split for evaluation." + + return datasets + + def valid_step(self, model, samples): + answers = model.predict_answers( + samples=samples, + answer_list=self.answer_list, + inference_method=self.inference_method, + num_beams=self.num_beams, + max_len=self.max_len, + min_len=self.min_len, + num_ans_candidates=self.num_ans_candidates, + prompt=self.prompt, + ) + pred_qa_pairs = [] + + question_id = samples["question_id"] + for answer, ques_id in zip(answers, question_id): + ques_id = int(ques_id.item()) + pred_qa_pairs.append({"question_id": ques_id, "answer": answer}) + + return pred_qa_pairs + + def after_evaluation(self, val_result, split_name, result_dir): + + result_file = self.save_result( + val_result, + result_dir=result_dir, #registry.get_path("result_dir"), + filename=split_name, + remove_duplicate="question_id", + ) + +# metrics = self._report_metrics(result_file=result_file, split=split_name) + +# return metrics + + @dist_utils.main_process + def _report_metrics(self, result_file, split): + """ + Use official VQA evaluation script to report metrics. + """ + metrics = {} + + if split in self.ques_files and split in self.anno_files: + vqa = VQA(self.anno_files[split], self.ques_files[split]) + vqa_result = vqa.loadRes( + resFile=result_file, quesFile=self.ques_files[split] + ) + + # create vqaEval object by taking vqa and vqaRes + # n is precision of accuracy (number of places after decimal), default is 2 + vqa_scorer = VQAEval(vqa, vqa_result, n=2) + logging.info("Start VQA evaluation.") + vqa_scorer.evaluate() + + # print accuracies + overall_acc = vqa_scorer.accuracy["overall"] + metrics["agg_metrics"] = overall_acc + + + logging.info("Overall Accuracy is: %.02f\n" % overall_acc) + logging.info("Per Answer Type Accuracy is the following:") + + for ans_type in vqa_scorer.accuracy["perAnswerType"]: + logging.info( + "%s : %.02f" + % (ans_type, vqa_scorer.accuracy["perAnswerType"][ans_type]) + ) + metrics[ans_type] = vqa_scorer.accuracy["perAnswerType"][ans_type] + + with open( + os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a" + ) as f: + f.write(json.dumps(metrics) + "\n") + + return metrics + +@registry.register_task("gqa") +class GQATask(VQATask): + def valid_step(self, model, samples): + answers = model.predict_answers( + samples=samples, + answer_list=self.answer_list, + inference_method=self.inference_method, + num_beams=self.num_beams, + max_len=self.max_len, + min_len=self.min_len, + num_ans_candidates=self.num_ans_candidates, + prompt=self.prompt, + ) + pred_qa_pairs = [] + + question_id = samples["question_id"] + gt_answers = samples["answer"] + + for answer, ques_id, gt_answer in zip(answers, question_id, gt_answers): + ques_id = int(ques_id.item()) + pred_qa_pairs.append({"question_id": ques_id, "pred_ans": answer, "gt_ans": gt_answer}) + + return pred_qa_pairs + + @dist_utils.main_process + def _report_metrics(self, result_file, split): + """ + TODO: add other evaluation metrics for GQA + """ + + results = json.load(open(result_file, "r")) + acc = [] + vqa_tool = VQAEval() + + for res in results: + if res["gt_ans"] is None: + # 
prepare test results for leaderboard evaluation + self._save_result_leaderboard(results) + return + + gt_ans = res["gt_ans"] + pred = res["pred_ans"] + + if self.inference_method == "generate": + pred = vqa_tool.processPunctuation(pred) + pred = vqa_tool.processDigitArticle(pred) + + vqa_acc = 1 if pred == gt_ans else 0 + + acc.append(vqa_acc) + + accuracy = sum(acc) / len(acc) * 100 + metrics = {"agg_metrics": accuracy, "acc": accuracy} + + with open( + os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a" + ) as f: + f.write(json.dumps(metrics) + "\n") + + logging.info(metrics) + + return metrics + + +@registry.register_task("scienceqa") +class ScienceQATask(GQATask): + def valid_step(self, model, samples): + answers = model.predict_class( + samples=samples, + answer_list=self.answer_list, + inference_method=self.inference_method, + num_beams=self.num_beams, + max_len=self.max_len, + min_len=self.min_len, + num_ans_candidates=self.num_ans_candidates, + prompt=self.prompt, + ) + pred_qa_pairs = [] + + question_id = samples["question_id"] + gt_answers = samples["answer"] + + for answer, ques_id, gt_answer in zip(answers, question_id, gt_answers): + ques_id = int(ques_id.item()) + pred_qa_pairs.append({"question_id": ques_id, "pred_ans": answer, "gt_ans": gt_answer}) + + return pred_qa_pairs + + +@registry.register_task("aok_vqa") +class AOKVQATask(VQATask): + def valid_step(self, model, samples): + answers = model.predict_answers( + samples=samples, + answer_list=self.answer_list, + inference_method=self.inference_method, + num_beams=self.num_beams, + max_len=self.max_len, + min_len=self.min_len, + num_ans_candidates=self.num_ans_candidates, + ) + + pred_qa_pairs = [] + + question_id = samples["question_id"] + gt_answers = samples["direct_answers"] + + for pred_answer, ques_id, gt_answer in zip(answers, question_id, gt_answers): + pred_qa_pairs.append( + {"question_id": ques_id, "pred_ans": pred_answer, "gt_ans": gt_answer} + ) + + return pred_qa_pairs + + @dist_utils.main_process + def _report_metrics(self, result_file, split): + """ + Implementing accuracy computation for AOKVQA, see + https://github.com/allenai/aokvqa/blob/main/evaluation/eval_predictions.py#L45 for details. + """ + # TODO add evaluation for multi-choice + + results = json.load(open(result_file, "r")) + acc = [] + + for res in results: + if res["gt_ans"] is None: + # prepare test results for leaderboard evaluation + self._save_result_leaderboard(results) + return + + pred = res["pred_ans"] + gt_ans = res["gt_ans"] + + num_match = sum([pred == gt for gt in gt_ans]) + vqa_acc = min(1.0, num_match / 3.0) + + acc.append(vqa_acc) + + accuracy = sum(acc) / len(acc) * 100 + metrics = {"agg_metrics": accuracy, "acc": accuracy} + + with open( + os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a" + ) as f: + f.write(json.dumps(metrics) + "\n") + + logging.info(metrics) + + return metrics + + @dist_utils.main_process + def _save_result_leaderboard(self, results): + """ + Saving the results in the format required for leaderboard evaluation. + + [TODO] add support for multi-choice. 
+ """ + result_leaderboard = dict() + for res in results: + result_leaderboard[res["question_id"]] = { + "direct_answer": res["pred_ans"], + "multiple_choice": "", + } + + result_file = registry.get_path("result_dir") + "_leaderboard.json" + + with open(result_file, "w") as f: + json.dump(result_leaderboard, f) + + + logging.info(f"Saved results for leaderboard evaluation at {result_file}") + diff --git a/minigpt4/tasks/vqa_reading_comprehension.py b/minigpt4/tasks/vqa_reading_comprehension.py new file mode 100644 index 0000000000000000000000000000000000000000..c67b3b759b4e3081acdc888c58783754bfa5f8f3 --- /dev/null +++ b/minigpt4/tasks/vqa_reading_comprehension.py @@ -0,0 +1,248 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import json +import os +import torch +import torch.distributed as dist +from itertools import chain + +import minigpt4.common.dist_utils as dist_utils +from minigpt4.common.dist_utils import get_rank, get_world_size, is_main_process +from minigpt4.common.registry import registry +from minigpt4.common.vqa_tools.vqa_eval import VQAEval as VQATool +from minigpt4.tasks.vqa import VQATask + + +@registry.register_task("vqa_reading_comprehension") +class VQARCTask(VQATask): + def __init__( + self, + num_beams, + max_len, + min_len, + evaluate, + num_ans_candidates, + inference_method="rank", + **kwargs, + ): + super().__init__(num_beams, max_len, min_len, evaluate, num_ans_candidates, inference_method) + + self.config = kwargs.get('config') + + @classmethod + def setup_task(cls, cfg): + run_cfg = cfg.run_cfg + + num_beams = run_cfg.get("num_beams", 3) + max_len = run_cfg.get("max_len", 10) + min_len = run_cfg.get("min_len", 1) + + evaluate = run_cfg.get("evaluate", False) + + inference_method = run_cfg.get("inference_method", "rank") + num_ans_candidates = run_cfg.get("num_ans_candidates", 128) + + return cls( + num_beams=num_beams, + max_len=max_len, + min_len=min_len, + evaluate=evaluate, + num_ans_candidates=num_ans_candidates, + inference_method=inference_method, + config=run_cfg, + ) + + def valid_step(self, model, samples): + answers, captions, gradcams = model.predict_answers( + samples=samples, + inference_method=self.inference_method, + num_beams=self.num_beams, + max_len=self.max_len, + min_len=self.min_len, + internal_bsz_fid=self.config['internal_bsz_fid'], + num_captions=self.config['num_captions'], + num_captions_fid=self.config['num_captions_fid'], + cap_max_length=self.config['cap_max_length'], + cap_min_length=self.config['cap_min_length'], + top_k=self.config['top_k'], + top_p=self.config['top_p'], + repetition_penalty=self.config['repetition_penalty'], + num_patches=self.config['num_patches'], + block_num=self.config['block_num'], + ) + + pred_qa_pairs = [] + sample_captions = [] + sample_gradcams = [] + + question_id = samples["question_id"] + for answer, caption, gradcam, ques_id in zip(answers, captions, gradcams, question_id): + ques_id = int(ques_id.item()) + pred_qa_pairs.append({"question_id": ques_id, "answer": answer}) + sample_captions.append({"question_id": ques_id, "caption": caption}) + sample_gradcams.append({"question_id": ques_id, "gradcam": gradcam}) + + return [sample_gradcams, sample_captions, pred_qa_pairs] + + def after_evaluation(self, val_result, split_name, **kwargs): + result_ = list(chain(*val_result[0::3])) + result_file = 
self.save_gradcam( + result_, + result_dir=registry.get_path("result_dir"), + filename=f"{split_name}_gradcam_result", + remove_duplicate="question_id", + ) + + result_ = list(chain(*val_result[1::3])) + result_file = self.save_result( + result_, + result_dir=registry.get_path("result_dir"), + filename=f"{split_name}_caption_result", + remove_duplicate="question_id", + ) + + result_ = list(chain(*val_result[2::3])) + result_file = self.save_result( + result_, + result_dir=registry.get_path("result_dir"), + filename=f"{split_name}_vqa_result", + remove_duplicate="question_id", + ) + + metrics = self._report_metrics(result_file=result_file, split=split_name) + + return metrics + + def save_gradcam(self, result, result_dir, filename, remove_duplicate=""): + result_file = os.path.join(result_dir, '%s_rank%d.pth' % (filename, get_rank())) + final_result_file = os.path.join(result_dir, '%s.pth' % filename) + torch.save({'result': result}, result_file) + + dist.barrier() + + if is_main_process(): + logging.warning("rank %d starts merging results." % get_rank()) + # combine results from all processes + result = [] + + for rank in range(get_world_size()): + result_file = os.path.join(result_dir, '%s_rank%d.pth' % (filename, rank)) + res_ckpt = torch.load(result_file, map_location='cpu') + res = res_ckpt['result'] + + result += res + + if remove_duplicate: + result_new = [] + id_list = [] + for res in result: + if res[remove_duplicate] not in id_list: + id_list.append(res[remove_duplicate]) + result_new.append(res) + result = result_new + + torch.save({'result': result}, final_result_file) + print("result file saved to %s" % final_result_file) + + return final_result_file + + +@registry.register_task("gqa_reading_comprehension") +class GQARCTask(VQARCTask): + def valid_step(self, model, samples): + answers, captions, gradcams = model.predict_answers( + samples=samples, + inference_method=self.inference_method, + num_beams=self.num_beams, + max_len=self.max_len, + min_len=self.min_len, + internal_bsz_fid=self.config['internal_bsz_fid'], + num_captions=self.config['num_captions'], + num_captions_fid=self.config['num_captions_fid'], + cap_max_length=self.config['cap_max_length'], + cap_min_length=self.config['cap_min_length'], + top_k=self.config['top_k'], + top_p=self.config['top_p'], + repetition_penalty=self.config['repetition_penalty'], + num_patches=self.config['num_patches'], + block_num=self.config['block_num'], + ) + + pred_qa_pairs = [] + sample_captions = [] + sample_gradcams = [] + + question_id = samples["question_id"] + gt_answers = samples["answer"] + + for pred_answer, caption, gradcam, ques_id, gt_answer in zip(answers, captions, gradcams, question_id, gt_answers): + ques_id = int(ques_id.item()) + pred_qa_pairs.append({"question_id": ques_id, "pred_ans": pred_answer, "gt_ans": gt_answer}) + sample_captions.append({"question_id": ques_id, "caption": caption}) + sample_gradcams.append({"question_id": ques_id, "gradcam": gradcam}) + + return [sample_gradcams, sample_captions, pred_qa_pairs] + + @dist_utils.main_process + def _report_metrics(self, result_file, split): + """ + TODO: add other evaluation metrics for GQA + """ + + results = json.load(open(result_file, "r")) + acc = [] + vqa_tool = VQATool() + + for res in results: + if res["gt_ans"] is None: + # prepare test results for leaderboard evaluation + self._save_result_leaderboard(results) + return + + gt_ans = res["gt_ans"] + pred = res["pred_ans"] + + if self.inference_method == "generate": + pred = 
vqa_tool.processPunctuation(pred) + pred = vqa_tool.processDigitArticle(pred) + + vqa_acc = 1 if pred == gt_ans else 0 + + acc.append(vqa_acc) + + accuracy = sum(acc) / len(acc) * 100 + metrics = {"agg_metrics": accuracy, "acc": accuracy} + + with open( + os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a" + ) as f: + f.write(json.dumps(metrics) + "\n") + + logging.info(metrics) + + return metrics + + @dist_utils.main_process + def _save_result_leaderboard(self, results): + """ + Saving the results in the format required for leaderboard evaluation. + """ + result_leaderboard = [] + for res in results: + result_leaderboard.append({ + "questionId": str(res['question_id']), + "prediction": str(res["pred_ans"]), + }) + + result_file = registry.get_path("result_dir") + "_leaderboard.json" + + with open(result_file, "w") as f: + json.dump(result_leaderboard, f) + + logging.info(f"Saved results for leaderboard evaluation at {result_file}") \ No newline at end of file diff --git a/minigpt4_video_demo.py b/minigpt4_video_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..967668e43a0b7f74caead7df8e244e076628a2c8 --- /dev/null +++ b/minigpt4_video_demo.py @@ -0,0 +1,348 @@ +import torch +import webvtt +import os +import cv2 +from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser, eval_bleu,eval_cider,chat_gpt_eval +from minigpt4.conversation.conversation import CONV_VISION +from torchvision import transforms +import json +from tqdm import tqdm +import soundfile as sf +import argparse +import moviepy.editor as mp +import gradio as gr +from pytubefix import YouTube +import shutil +from PIL import Image +from moviepy.editor import VideoFileClip +from theme import minigptlv_style, custom_css,text_css + +from huggingface_hub import login +hf_token = os.environ.get('HF_TKN') +login(token=hf_token) + +def create_video_grid(images, rows, cols,save_path): + image_width, image_height = images[0].size + grid_width = cols * image_width + grid_height = rows * image_height + + new_image = Image.new("RGB", (grid_width, grid_height)) + + for i in range(rows): + for j in range(cols): + index = i * cols + j + if index < len(images): + image = images[index] + x_offset = j * image_width + y_offset = i * image_height + new_image.paste(image, (x_offset, y_offset)) + # new_image.save(save_path) + return new_image + +def prepare_input(vis_processor,video_path,subtitle_path,instruction): + cap = cv2.VideoCapture(video_path) + if subtitle_path is not None: + # Load the VTT subtitle file + vtt_file = webvtt.read(subtitle_path) + print("subtitle loaded successfully") + clip = VideoFileClip(video_path) + total_num_frames = int(clip.duration * clip.fps) + # print("Video duration = ",clip.duration) + clip.close() + else : + # calculate the total number of frames in the video using opencv + total_num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + if "mistral" in args.ckpt : + max_images_length=90 + max_sub_len = 800 + else: + max_images_length = 45 + max_sub_len = 400 + images = [] + frame_count = 0 + sampling_interval = int(total_num_frames / max_images_length) + if sampling_interval == 0: + sampling_interval = 1 + img_placeholder = "" + subtitle_text_in_interval = "" + history_subtitles = {} + # raw_frames=[] + number_of_words=0 + transform=transforms.Compose([ + transforms.ToPILImage(), + ]) + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + # Find the corresponding subtitle for the frame and combine the interval subtitles into one subtitle 
+ # we choose 1 frame for every 2 seconds,so we need to combine the subtitles in the interval of 2 seconds + if subtitle_path is not None: + for subtitle in vtt_file: + sub=subtitle.text.replace('\n',' ') + if (subtitle.start_in_seconds <= (frame_count / int(clip.fps)) <= subtitle.end_in_seconds) and sub not in subtitle_text_in_interval: + if not history_subtitles.get(sub,False): + subtitle_text_in_interval+=sub+" " + history_subtitles[sub]=True + break + if frame_count % sampling_interval == 0: + # raw_frames.append(Image.fromarray(cv2.cvtColor(frame.copy(), cv2.COLOR_BGR2RGB))) + frame = transform(frame[:,:,::-1]) # convert to RGB + frame = vis_processor(frame) + images.append(frame) + img_placeholder += '' + if subtitle_path is not None and subtitle_text_in_interval != "" and number_of_words< max_sub_len: + img_placeholder+=f'{subtitle_text_in_interval}' + number_of_words+=len(subtitle_text_in_interval.split(' ')) + subtitle_text_in_interval = "" + frame_count += 1 + + if len(images) >= max_images_length: + break + cap.release() + cv2.destroyAllWindows() + if len(images) == 0: + # skip the video if no frame is extracted + return None,None + # video_grid_image=create_video_grid(raw_frames,8,len(raw_frames)//8,"concatenated.jpg") + images = torch.stack(images) + instruction = img_placeholder + '\n' + instruction + return images,instruction +def extract_audio(video_path, audio_path): + video_clip = mp.VideoFileClip(video_path) + audio_clip = video_clip.audio + audio_clip.write_audiofile(audio_path, codec="libmp3lame", bitrate="320k") + +def generate_subtitles(video_path): + video_id=video_path.split('/')[-1].split('.')[0] + audio_path = f"workspace/inference_subtitles/mp3/{video_id}"+'.mp3' + os.makedirs("workspace/inference_subtitles/mp3",exist_ok=True) + if existed_subtitles.get(video_id,False): + return f"workspace/inference_subtitles/{video_id}"+'.vtt' + try: + extract_audio(video_path,audio_path) + print("successfully extracted") + os.system(f"whisper {audio_path} --language English --model large --output_format vtt --output_dir workspace/inference_subtitles") + # remove the audio file + os.system(f"rm {audio_path}") + print("subtitle successfully generated") + return f"workspace/inference_subtitles/{video_id}"+'.vtt' + except Exception as e: + print("error",e) + print("error",video_path) + return None + + +def run (video_path,instruction,model,vis_processor,gen_subtitles=True): + if gen_subtitles: + subtitle_path=generate_subtitles(video_path) + else : + subtitle_path=None + prepared_images,prepared_instruction=prepare_input(vis_processor,video_path,subtitle_path,instruction) + if prepared_images is None: + return "Video cann't be open ,check the video path again" + length=len(prepared_images) + prepared_images=prepared_images.unsqueeze(0) + conv = CONV_VISION.copy() + conv.system = "" + # if you want to make conversation comment the 2 lines above and make the conv is global variable + conv.append_message(conv.roles[0], prepared_instruction) + conv.append_message(conv.roles[1], None) + prompt = [conv.get_prompt()] + answers = model.generate(prepared_images, prompt, max_new_tokens=args.max_new_tokens, do_sample=True, lengths=[length],num_beams=2) + # remove the subtitle file and the video file + if subtitle_path: + os.system(f"rm {subtitle_path}") + #if video_path.split('.')[-1] == 'mp4' or video_path.split('.')[-1] == 'mkv' or video_path.split('.')[-1] == 'avi': + # os.system(f"rm {video_path}") + return answers[0] + +def run_single_image (image_path,instruction,model,vis_processor): 
+ image=Image.open(image_path) + image = vis_processor(image) + prepared_images=torch.stack([image]) + prepared_instruction=''+instruction + length=len(prepared_images) + prepared_images=prepared_images.unsqueeze(0) + conv = CONV_VISION.copy() + conv.system = "" + # if you want to make conversation comment the 2 lines above and make the conv is global variable + conv.append_message(conv.roles[0], prepared_instruction) + conv.append_message(conv.roles[1], None) + prompt = [conv.get_prompt()] + answers = model.generate(prepared_images, prompt, max_new_tokens=args.max_new_tokens, do_sample=False, lengths=[length],num_beams=1) + return answers[0] + +def download_video(youtube_url, download_finish): + video_id=youtube_url.split('v=')[-1].split('&')[0] + # Create a YouTube object + youtube = YouTube(youtube_url) + # Get the best available video stream + video_stream = youtube.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first() + # if has_subtitles: + # Download the video to the workspace folder + print('Downloading video') + video_stream.download(output_path="workspace",filename=f"{video_id}.mp4") + print('Video downloaded successfully') + processed_video_path= f"workspace/{video_id}.mp4" + download_finish = gr.State(value=True) + return processed_video_path, download_finish + +def get_video_url(url,has_subtitles): + # get video id from url + video_id=url.split('v=')[-1].split('&')[0] + # Create a YouTube object + youtube = YouTube(url) + # Get the best available video stream + video_stream = youtube.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first() + # if has_subtitles: + # Download the video to the workspace folder + print('Downloading video') + video_stream.download(output_path="workspace",filename=f"{video_id}.mp4") + print('Video downloaded successfully') + return f"workspace/{video_id}.mp4" + # else: + # return video_stream.url + + +def get_arguments(): + parser = argparse.ArgumentParser(description="Inference parameters") + parser.add_argument("--cfg-path", help="path to configuration file.",default="test_configs/llama2_test_config.yaml") + parser.add_argument("--ckpt", type=str,default='checkpoints/video_llama_checkpoint_last.pth', help="path to checkpoint") + parser.add_argument("--max_new_tokens", type=int, default=512, help="max number of generated tokens") + parser.add_argument("--lora_r", type=int, default=64, help="lora rank of the model") + parser.add_argument("--lora_alpha", type=int, default=16, help="lora alpha") + parser.add_argument( + "--options", + nargs="+", + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file (deprecate), " + "change to --cfg-options instead.", + ) + return parser.parse_args() +args=get_arguments() +model, vis_processor = init_model(args) +conv = CONV_VISION.copy() +conv.system = "" +inference_subtitles_folder="workspace/inference_subtitles" +os.makedirs(inference_subtitles_folder,exist_ok=True) +existed_subtitles={} +for sub in os.listdir(inference_subtitles_folder): + existed_subtitles[sub.split('.')[0]]=True + +def gradio_demo_local(video_path,has_sub,instruction): + pred=run(video_path,instruction,model,vis_processor,gen_subtitles=has_sub) + return pred + +def gradio_demo_youtube(youtube_url,has_sub,instruction): + video_path=get_video_url(youtube_url,has_sub) + pred=run(video_path,instruction,model,vis_processor,gen_subtitles=has_sub) + return pred + +def use_example(url,has_sub_1,q): + # set the 
youtube link and the question with the example values + youtube_link.value=url + has_subtitles.value=has_sub_1 + question.value=q + + +title = """

MiniGPT4-video 🎞️🍿

""" +description = """
This is the demo of the MiniGPT4-video model.
""" +project_page = """

""" +code_link="""

""" +paper_link="""

""" +#video_path="" +with gr.Blocks(title="MiniGPT4-video 🎞️🍿",css=text_css ) as demo : + # with gr.Row(): + # with gr.Column(scale=2): + gr.Markdown(title) + gr.Markdown(description) + # gr.Image("repo_imgs/Designer_2_new.jpeg",scale=1,show_download_button=False,show_label=False) + # with gr.Row(): + # gr.Markdown(project_page) + # gr.Markdown(code_link) + # gr.Markdown(paper_link) + + with gr.Tab("Local videos"): + # local_interface=gr.Interface( + # fn=gradio_demo_local, + # inputs=[gr.Video(sources=["upload"]),gr.Checkbox(label='Use subtitles'),gr.Textbox(label="Write any Question")], + # outputs=["text", + # ], + + # # title="

Local videos

", + # description="Upload your videos with length from one to two minutes", + # examples=[ + # ["example_videos/sample_demo_1.mp4", True, "Why is this video funny"], + # ["example_videos/sample_demo_2.mp4", False, "Generate a creative advertisement for this product."], + # ["example_videos/sample_demo_3.mp4", False, "Write a poem inspired by this video."], + # ], + # css=custom_css, # Apply custom CSS + # allow_flagging='auto' + # ) + with gr.Row(): + with gr.Column(): + video_player_local = gr.Video(sources=["upload"]) + question_local = gr.Textbox(label="Your Question", placeholder="Default: What's this video talking about?") + has_subtitles_local = gr.Checkbox(label="Use subtitles", value=True) + process_button_local = gr.Button("Answer the Question (QA)") + + with gr.Column(): + answer_local=gr.Text("Answer will be here",label="MiniGPT4-video Answer") + + process_button_local.click(fn=gradio_demo_local, inputs=[video_player_local, has_subtitles_local, question_local], outputs=[answer_local]) + + with gr.Tab("Youtube videos"): + # youtube_interface=gr.Interface( + # fn=gradio_demo_youtube, + # inputs=[gr.Textbox(label="Enter the youtube link"),gr.Checkbox(label='Use subtitles'),gr.Textbox(label="Write any Question")], + # outputs=["text", + # ], + # # title="

YouTube videos

", + # description="Videos length should be from one to two minutes", + # examples=[ + # ["https://www.youtube.com/watch?v=8kyg5u6o21k", True, "What happens in this video?"], + # ["https://www.youtube.com/watch?v=zWfX5jeF6k4", True, "what is the main idea in this video?"], + # ["https://www.youtube.com/watch?v=W5PRZuaQ3VM", True, "Inspired by this video content suggest a creative advertisement about the same content."], + # ["https://www.youtube.com/watch?v=W8jcenQDXYg", True, "Describe what happens in this video."], + # ["https://www.youtube.com/watch?v=u3ybWiEUaUU", True, "what is creative in this video ?"], + # ["https://www.youtube.com/watch?v=nEwfSZfz7pw", True, "What Monica did in this video ?"], + # ], + # css=custom_css, # Apply custom CSS + # allow_flagging='auto', + # ) + with gr.Row(): + with gr.Column(): + youtube_link = gr.Textbox(label="Enter the youtube link", placeholder="Paste YouTube URL here") + video_player = gr.Video(autoplay=False) + download_finish = gr.State(value=False) + youtube_link.change( + fn=download_video, + inputs=[youtube_link, download_finish], + outputs=[video_player, download_finish] + ) + question = gr.Textbox(label="Your Question", placeholder="Default: What's this video talking about?") + has_subtitles = gr.Checkbox(label="Use subtitles", value=True) + process_button = gr.Button("Answer the Question (QA)") + + with gr.Column(): + answer=gr.Text("Answer will be here",label="MiniGPT4-video Answer") + + process_button.click(fn=gradio_demo_youtube, inputs=[youtube_link, has_subtitles, question], outputs=[answer]) + ## Add examples to make the demo more interactive and user-friendly + # with gr.Row(): + # url_1=gr.Text("https://www.youtube.com/watch?v=8kyg5u6o21k") + # has_sub_1=True + # q_1=gr.Text("What happens in this video?") + # # add button to change the youtube link and the question with the example values + # use_example_1_btn=gr.Button("Use this example") + # use_example_1_btn.click(use_example,inputs=[url_1,has_sub_1,q_1]) + + + + +if __name__ == "__main__": + demo.queue(max_size=10).launch(share=False,show_error=True, show_api=False) + + + diff --git a/minigpt4_video_inference.py b/minigpt4_video_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..d6da7d2fc72bb16ff038b9158e69e0e74e193486 --- /dev/null +++ b/minigpt4_video_inference.py @@ -0,0 +1,94 @@ +import json +from tqdm import tqdm +from pytubefix import YouTube + +import xml.etree.ElementTree as ET +import os + +with open ('VideoInstruct100K.json','r') as f : + data=json.load(f) + +# Usage +existed_video_id={} +for video_name in os.listdir('videos'): + video_id = video_name.split('.')[0] + existed_video_id[video_id]=True + + + +def download_video_with_subtitles(video_id): + # Create a YouTube object. + yt = YouTube(f'https://www.youtube.com/watch?v={video_id}') + + video_filename = f"{video_id}.mp4" + video_downloaded=False + try : + # Get the video stream with the highest resolution and download the video. + stream = yt.streams.get_highest_resolution() + stream.download(output_path='videos', filename=video_filename) + video_downloaded=True + except Exception as e: + print(f"Error downloading video {video_id}: {str(e)}") + video_downloaded=False + if not video_downloaded: + return False,False + + # Get the video's available captions (subtitles). + captions = yt.captions.all() + + # Download the captions if available in xml format. 
+ caption_downloaded = False + for caption in captions: + caption_code = caption.code + # select only english captions + if 'en' in caption_code: + caption.download(title=f"{video_id}", output_path='subtitles_xml',srt=False) + caption_downloaded = True + return video_downloaded,caption_downloaded +def convert_xml_vtt(xml_path, vtt_path): + # Parse the XML subtitle file + tree = ET.parse(xml_path) + root = tree.getroot() + + # Initialize a list to store VTT subtitle entries + vtt_subtitle = [] + + # Function to convert time in milliseconds to WebVTT format + def ms_to_vtt_time(milliseconds): + seconds, milliseconds = divmod(milliseconds, 1000) + minutes, seconds = divmod(seconds, 60) + return f"{minutes:02d}:{seconds:02d}.{milliseconds:03d}" + + # Iterate through subtitle elements + toggle = True + for p in root.findall(".//p"): + if toggle: + start_time = int(p.get("t")) + subtitle_text = " ".join(s.text.strip() for s in p.findall(".//s")) + # duration = int(p.get("d")) if p.get("d") is not None else 0 + if not toggle: + end_time = int(p.get("t")) + # Format and append the VTT entry to the list + vtt_subtitle.append(f"{ms_to_vtt_time(start_time)} --> {ms_to_vtt_time(end_time)}\n{subtitle_text}\n") + toggle = not toggle + # Join the VTT entries into a single string + vtt_content = "WEBVTT\n\n" + "\n".join(vtt_subtitle) + + # Save the VTT content to a file + with open(vtt_path, "w", encoding="utf-8") as vtt_file: + vtt_file.write(vtt_content) +import os +os.makedirs('videos', exist_ok=True) +os.makedirs('subtitles_vtt', exist_ok=True) +os.makedirs('subtitles_xml', exist_ok=True) +for video_path in tqdm(data,desc='Downloading videos') : + video_id=video_path.split('/')[-1].split('.')[0] + if existed_video_id.get(video_id,False): + continue + video_downloaded,caption_downloaded=download_video_with_subtitles(video_id) + if caption_downloaded: + # convert xml to vtt + xml_file_path=f'subtitles_xml/{video_id} (a.en).xml' + convert_xml_vtt(xml_file_path,f'subtitles_vtt/{video_id}.vtt') + + diff --git a/repo_imgs/MiniGPT4-video_fig.jpg b/repo_imgs/MiniGPT4-video_fig.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b8077df893678b602c487c04b74cf3c2a674de1f Binary files /dev/null and b/repo_imgs/MiniGPT4-video_fig.jpg differ diff --git a/repo_imgs/sample_1.gif b/repo_imgs/sample_1.gif new file mode 100644 index 0000000000000000000000000000000000000000..068d4a91f4ed88d3623788b6d495decc94b20914 --- /dev/null +++ b/repo_imgs/sample_1.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee2db1948b2abb62f3935dd8f077747055bb77cdb277a584e40ef1ce653e08c6 +size 4546850 diff --git a/repo_imgs/sample_2.gif b/repo_imgs/sample_2.gif new file mode 100644 index 0000000000000000000000000000000000000000..8807ed142f23ffcdc6659cc51063f7da1c480b92 --- /dev/null +++ b/repo_imgs/sample_2.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e661c29d150bc1b9d8cc1c3499a93a5594327dcea9ded21ae322b14b9a30504c +size 3829382 diff --git a/repo_imgs/sample_3.gif b/repo_imgs/sample_3.gif new file mode 100644 index 0000000000000000000000000000000000000000..3452272f1c76b83d8baa371a5b1c2769c79d03a3 --- /dev/null +++ b/repo_imgs/sample_3.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fab764d8ed82d13d7431e73f59644fd947fc418d6b2e020abf2e8a181b465281 +size 5628166 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..dd6760aaf91ffbf6374fb47f1a9745731d615a3c --- /dev/null +++ 
b/requirements.txt @@ -0,0 +1,232 @@ +accelerate==0.25.0 +aiofiles==23.2.1 +aiohttp==3.9.1 +aiosignal==1.3.1 +altair==5.2.0 +annotated-types==0.6.0 +antlr4-python3-runtime==4.9.3 +anyio==4.2.0 +appdirs==1.4.4 +asgiref==3.7.2 +async-timeout==4.0.3 +attrs==23.2.0 +backoff==2.2.1 +bcrypt==4.1.2 +beautifulsoup4==4.12.2 +bitarray==2.9.2 +bitsandbytes==0.42.0 +bleach==6.1.0 +blinker==1.7.0 +braceexpand==0.1.7 +build==1.0.3 +cachetools==5.3.2 +chardet==5.2.0 +chroma-hnswlib==0.7.3 +chromadb==0.4.22 +click==8.1.7 +cmake==3.25.0 +colbert-ai==0.2.18 +coloredlogs==15.0.1 +contourpy==1.2.0 +cycler==0.12.1 +datasets==2.17.0 +decorator==4.4.2 +decord==0.6.0 +deprecated==1.2.14 +dill==0.3.8 +docker-pycreds==0.4.0 +docopt==0.6.2 +einops==0.7.0 +exceptiongroup==1.2.0 +faiss-gpu==1.7.2 +fastapi==0.108.0 +ffmpeg==1.4 +ffmpeg-python==0.2.0 +ffmpy==0.3.1 +filelock==3.13.1 +flash-attn==2.5.4 +flask==3.0.2 +flatbuffers==23.5.26 +fonttools==4.47.0 +frozenlist==1.4.1 +fsspec==2023.10.0 +ftfy==6.1.3 +future==0.18.3 +gdown==4.7.1 +git-python==1.0.3 +gitdb==4.0.11 +gitpython==3.1.40 +google-auth==2.26.1 +googleapis-common-protos==1.62.0 +gradio +gradio-client +h11==0.14.0 +h5py==3.10.0 +httpcore==1.0.2 +httptools==0.6.1 +httpx==0.26.0 +huggingface-hub==0.21.1 +humanfriendly==10.0 +imageio==2.33.1 +imageio-ffmpeg==0.4.9 +importlib-metadata==6.11.0 +importlib-resources==6.1.1 +inquirerpy==0.3.4 +iopath==0.1.10 +itsdangerous==2.1.2 +jinja2==3.1.2 +joblib==1.3.2 +jsonschema==4.20.0 +jsonschema-specifications==2023.12.1 +kaggle==1.6.0 +kiwisolver==1.4.5 +kubernetes==29.0.0 +lazy-loader==0.3 +lit==15.0.7 +llvmlite==0.41.1 +markdown-it-py==3.0.0 +matplotlib==3.8.2 +mdurl==0.1.2 +mmh3==4.1.0 +monotonic==1.6 +more-itertools==10.1.0 +moviepy==1.0.3 +mpmath==1.3.0 +multidict==6.0.4 +multiprocess==0.70.16 +mutagen==1.47.0 +networkx==3.2.1 +ninja==1.11.1.1 +nltk==3.8.1 +numba==0.58.1 +nvidia-cublas-cu11==11.10.3.66 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu11==11.7.99 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu11==11.7.99 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu11==8.5.0.96 +nvidia-cudnn-cu12==8.9.2.26 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu12==2.18.1 +nvidia-nvjitlink-cu12==12.3.101 +nvidia-nvtx-cu12==12.1.105 +omegaconf==2.3.0 +onnxruntime==1.16.3 +openai==0.28.0 +openai-whisper==20231117 +opencv-python==4.7.0.72 +opentelemetry-api==1.22.0 +opentelemetry-exporter-otlp-proto-common==1.22.0 +opentelemetry-exporter-otlp-proto-grpc==1.22.0 +opentelemetry-instrumentation==0.43b0 +opentelemetry-instrumentation-asgi==0.43b0 +opentelemetry-instrumentation-fastapi==0.43b0 +opentelemetry-proto==1.22.0 +opentelemetry-sdk==1.22.0 +opentelemetry-semantic-conventions==0.43b0 +opentelemetry-util-http==0.43b0 +orjson==3.9.10 +overrides==7.4.0 +pandas==2.0.0 +pathtools==0.1.2 +peft==0.2.0 +pfzy==0.3.4 +pillow==10.2.0 +plotly==5.18.0 +portalocker==2.8.2 +posthog==3.3.0 +proglog==0.1.10 +progressbar2==4.3.2 +prompt-toolkit==3.0.43 +protobuf==4.25.1 +psutil==5.9.7 +pulsar-client==3.4.0 +pyarrow==15.0.0 +pyarrow-hotfix==0.6 +pyasn1==0.5.1 +pyasn1-modules==0.3.0 +pycocoevalcap==1.2 +pycocotools==2.0.6 +pycryptodomex==3.19.1 +pydantic==2.5.3 +pydantic-core==2.14.6 +pydub==0.25.1 +pygments==2.17.2 +pyparsing==3.1.1 +pypika==0.48.9 +pyproject-hooks==1.0.0 +pysrt==1.1.2 +python-dateutil==2.8.2 +python-dotenv==1.0.0 +python-multipart==0.0.6 +python-slugify==8.0.1 
+python-utils==3.8.1 +pytubefix +pytz==2023.3.post1 +pyyaml==6.0.1 +referencing==0.32.0 +regex==2023.12.25 +rich==13.7.0 +rouge==1.0.1 +rpds-py==0.16.2 +rsa==4.9 +safetensors==0.4.1 +scikit-image==0.22.0 +scikit-learn==1.3.2 +scipy==1.11.4 +seaborn==0.13.1 +semantic-version==2.10.0 +sentence-transformers==2.2.2 +sentencepiece==0.1.97 +sentry-sdk==1.39.1 +setproctitle==1.3.3 +setuptools==69.0.3 +shellingham==1.5.4 +six==1.16.0 +smmap==5.0.1 +sniffio==1.3.0 +soundfile==0.12.1 +soupsieve==2.5 +starlette==0.32.0.post1 +sympy==1.12 +tenacity==8.2.3 +text-unidecode==1.3 +threadpoolctl==3.2.0 +tifffile==2023.12.9 +tiktoken==0.5.2 +timm==0.6.13 +tokenizers==0.15.2 +tomli==2.0.1 +tomlkit==0.12.0 +toolz==0.12.0 +torch==2.0.1 +torchaudio==2.0.2 +torchvision==0.15.2 +transformers==4.37.2 +triton==2.0.0 +typer==0.9.0 +typing-extensions==4.9.0 +tzdata==2023.4 +ujson==5.9.0 +uvicorn==0.25.0 +uvloop==0.19.0 +visual-genome==1.1.1 +wandb==0.14.2 +watchfiles==0.21.0 +wcwidth==0.2.13 +webdataset==0.2.48 +webencodings==0.5.1 +websocket-client==1.7.0 +websockets +webvtt-py==0.4.6 +wrapt==1.16.0 +xxhash==3.4.1 +yarl==1.9.4 +youtube-dl==2021.12.17 +yt-dlp +zipp diff --git a/test_benchmark/quantitative_evaluation/benchmark_dataset_generation/generate_consistency_qa.py b/test_benchmark/quantitative_evaluation/benchmark_dataset_generation/generate_consistency_qa.py new file mode 100644 index 0000000000000000000000000000000000000000..4ddb0666ed8ca34d768624bc64f02285308c0f11 --- /dev/null +++ b/test_benchmark/quantitative_evaluation/benchmark_dataset_generation/generate_consistency_qa.py @@ -0,0 +1,140 @@ +import openai +import os +import argparse +import warnings +import json +import ast +from multiprocessing.pool import Pool + +# Disable warnings. +warnings.filterwarnings('ignore') + + +def parse_args(): + parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3") + parser.add_argument("--gt_caption_folder", required=True, help="The path to captions") + parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.") + parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.") + parser.add_argument("--api_key", required=True, help="OpenAI API key.") + parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.") + args = parser.parse_args() + return args + + +def annotate(gt_file, caption_files, output_dir): + """ + Generate questions and answers for each caption file using GPT-3. + """ + for file in caption_files: + key = file[:-5] # Strip file extension. + caption = gt_file[key] + try: + # Generate GPT-3 response. + completion = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[ + { + "role": "system", + "content": + "Your primary task is to formulate two distinct but conceptually similar questions, such that when asked about the same video-information, they correspond to the same answer. " + "------" + "##TASK:" + "When given details about a video, your task is to generate two questions asked in different ways. The crucial aspect is to frame these questions so that they are conceptually alike but phrased differently, leading to the exact same answer. " + "The questions should be cleverly designed to extract the same information directly from the video details given, so that the provided information or parts of it can serve as the answer. It's important that both questions yield the SAME answer. " + "- Generate TWO questions and ONE answer. 
The purpose is to extract identical information from both questions. Therefore, formulate your questions in a way that the given details can serve directly as the answer. " + "------" + "##SAMPLE QUESTIONS:" + "- {'Q1': 'What is the colour of the cycle the boy rides?', 'Q2': 'Can you describe the cycle the boy is riding?', 'A': 'The boy is riding a red bicycle with a basket.'}" + "- {'Q1': 'What is the baby girl doing in the video?', 'Q2': 'Can you see the baby girl engaged in an activity in the video?', 'A': 'The baby girl is reading a book in the video.'}" + }, + { + "role": "user", + "content": + f"The user input is: {caption}. " + f"Please generate the response in the form of a Python dictionary string with keys 'Q1', 'Q2', and 'A', where value of 'Q1' is first question, 'Q2' for second question and 'A' is the answer to both questions. Each corresponding value should be the question or answer text respectively. " + "For example, your response should look like this: {'Q1': 'Your first question here...', 'Q2': 'Your second question here...', 'A': 'Your answer to both questions here...'}. " + "Remember, it's critical to ensure that both questions are designed to extract the same details from the video, leading to the same answer." + } + ] + ) + # Convert response to a Python dictionary. + response_message = completion["choices"][0]["message"]["content"] + response_dict = ast.literal_eval(response_message) + + # Save the question-answer pairs to a json file. + with open(f"{output_dir}/{key}.json", "w") as f: + json.dump(response_dict, f) + except Exception as e: + print(f"Error processing file '{key}': {e}") + + +def main(): + """ + Main function to control the flow of the program. + """ + # Parse arguments. + args = parse_args() + + # Read ground truth captions. + gt_captions = {} + gt_files = os.listdir(args.gt_caption_folder) + for file in gt_files: + # Read human-assisted annotations from individual text files. + with open(os.path.join(args.gt_caption_folder, file), mode='r', encoding='utf-8-sig') as f: + caption = f.read().replace('\n', '').replace('‘', "'").replace('’', "'") + video_id = file[:-4] + gt_captions[video_id] = caption + + caption_files = [f"{video_id}.json" for video_id in gt_captions.keys()] + + output_dir = args.output_dir + # Generate output directory if not exists. + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # Set the OpenAI API key. + openai.api_key = args.api_key + num_tasks = args.num_tasks + + # While loop to ensure that all captions are processed. + while True: + try: + # Files that have already been completed. + completed_files = os.listdir(output_dir) + print(f"completed_files: {len(completed_files)}") + + # Files that have not been processed yet. + incomplete_files = [f for f in caption_files if f not in completed_files] + print(f"incomplete_files: {len(incomplete_files)}") + + if len(incomplete_files) == 0: + break + if len(incomplete_files) <= num_tasks: + num_tasks = 1 + + # Split tasks into parts. + part_len = len(incomplete_files) // num_tasks + all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)] + task_args = [(gt_captions, part, args.output_dir) for part in all_parts] + + # Use a pool of workers to process the files in parallel. 
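All of these QA-generation and evaluation scripts share the same fan-out pattern: the not-yet-processed files are split into num_tasks chunks, each chunk is packed into a (captions, chunk, output_dir) tuple, and Pool.starmap runs annotate once per chunk in a separate process; the enclosing while-loop then re-lists the output directory and retries whatever is still missing. A minimal standalone sketch of that pattern, using a toy worker and toy data rather than the project's annotate:

    from multiprocessing import Pool

    def toy_annotate(shared, chunk, output_dir):
        # each worker receives the shared lookup table, its own slice of
        # file names, and the common output directory
        for name in chunk:
            print(f"{output_dir}/{name}: {shared[name]}")

    if __name__ == "__main__":
        shared = {f"video_{i}": f"caption {i}" for i in range(10)}
        incomplete_files = sorted(shared)
        num_tasks = 3
        part_len = max(1, len(incomplete_files) // num_tasks)
        all_parts = [incomplete_files[i:i + part_len]
                     for i in range(0, len(incomplete_files), part_len)]
        task_args = [(shared, part, "toy_output") for part in all_parts]
        with Pool() as pool:
            pool.starmap(toy_annotate, task_args)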
+ with Pool() as pool: + pool.starmap(annotate, task_args) + + except Exception as e: + print(f"Error: {e}") + + # Combine qa pairs into single file when individual qa generation completes + all_data = {} + for filename in os.listdir(output_dir): + if filename.endswith(".json"): + with open(os.path.join(output_dir, filename)) as f: + key = filename[:-5] + all_data[key] = json.load(f) + + with open(args.output_json, 'w') as f: + json.dump(all_data, f, indent=4) + + +if __name__ == "__main__": + main() diff --git a/test_benchmark/quantitative_evaluation/benchmark_dataset_generation/generate_correctness_detailed_context_qa.py b/test_benchmark/quantitative_evaluation/benchmark_dataset_generation/generate_correctness_detailed_context_qa.py new file mode 100644 index 0000000000000000000000000000000000000000..12c01a892f76436677681012ff94d176044dbf5c --- /dev/null +++ b/test_benchmark/quantitative_evaluation/benchmark_dataset_generation/generate_correctness_detailed_context_qa.py @@ -0,0 +1,134 @@ +import openai +import os +import argparse +import warnings +import json +import ast +from multiprocessing.pool import Pool + +warnings.filterwarnings('ignore') + + +def parse_args(): + parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3") + parser.add_argument("--gt_caption_folder", required=True, help="The path to captions") + parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.") + parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.") + parser.add_argument("--api_key", required=True, help="OpenAI API key.") + parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.") + args = parser.parse_args() + return args + + +def annotate(gt_file, caption_files, output_dir): + """ + Generate generic descriptive type questions and answers for each caption file using GPT-3. + """ + for file in caption_files: + key = file[:-5] # Strip file extension. + caption = gt_file[key] + try: + # Generate GPT-3 response. + completion = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[ + { + "role": "system", + "content": + "You will play two roles: a human asking questions related to describing a video and an intelligent chatbot designed for video description and dense captioning. " + "Your task is to generate a detailed and descriptive paragraph based on the provided fragmented information about a video. " + "------" + "##TASK:" + "Users will provide a descriptions of a video, and you will generate ONE conversation-like question and answer related to describing the video in detail. " + "The question should ask to describe the video content in detail. " + "The answer should be a paraphrased and well-structured paragraph based on the provided description, as detailed as possible. " + }, + { + "role": "user", + "content": + f"The user input is: {caption}. " + f"Please generate the response in the form of a Python dictionary string with keys 'Q' for question and 'A' for answer. Each corresponding value should be the question and answer text respectively. " + "For example, your response should look like this: {'Q': 'Your question here...', 'A': 'Your answer here...'}. " + f"Emphasize that the answer should focus on describing the video content as detailed as possible." + } + ] + ) + # Convert response to a Python dictionary. 
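These scripts expect the model to answer with a bare Python dict literal (single quotes, no surrounding prose), which is why the reply is parsed with ast.literal_eval rather than json.loads. A small illustration with a made-up reply string:

    import ast

    reply = "{'Q': 'What is the man doing?', 'A': 'He is fixing a bicycle.'}"
    qa = ast.literal_eval(reply)
    assert qa["A"] == "He is fixing a bicycle."
    # A reply that is not a valid literal (extra prose, code fences, ...)
    # raises ValueError/SyntaxError; the surrounding try/except skips that
    # file, and the outer while-loop retries it on the next pass.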
+ response_message = completion["choices"][0]["message"]["content"] + response_dict = ast.literal_eval(response_message) + + # Save the question-answer pairs to a json file. + with open(f"{output_dir}/{key}.json", "w") as f: + json.dump(response_dict, f) + except Exception as e: + print(f"Error processing file '{key}': {e}") + + +def main(): + """ + Main function to control the flow of the program. + """ + # Parse arguments. + args = parse_args() + + # Read ground truth captions. + gt_captions = {} + gt_files = os.listdir(args.gt_caption_folder) + for file in gt_files: + with open(os.path.join(args.gt_caption_folder, file), mode='r', encoding='utf-8-sig') as f: + caption = f.read().replace('\n', '').replace('‘', "'").replace('’', "'") + video_id = file[:-4] + gt_captions[video_id] = caption + + caption_files = [f"{video_id}.json" for video_id in gt_captions.keys()] + output_dir = args.output_dir + # Generate output directory if not exists. + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # Set the OpenAI API key. + openai.api_key = args.api_key + num_tasks = args.num_tasks + + # While loop to ensure that all captions are processed. + while True: + try: + # Files that have not been processed yet. + completed_files = os.listdir(output_dir) + print(f"completed_files: {len(completed_files)}") + + # Files that have not been processed yet. + incomplete_files = [f for f in caption_files if f not in completed_files] + print(f"incomplete_files: {len(incomplete_files)}") + + if len(incomplete_files) == 0: + break + if len(incomplete_files) <= num_tasks: + num_tasks = 1 + + # Split tasks into parts. + part_len = len(incomplete_files) // num_tasks + all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)] + task_args = [(gt_captions, part, args.output_dir) for part in all_parts] + + # Use a pool of workers to process the files in parallel. 
+ with Pool() as pool: + pool.starmap(annotate, task_args) + + except Exception as e: + print(f"Error: {e}") + + # Combine qa pairs into single file when individual qa generation completes + all_data = {} + for filename in os.listdir(output_dir): + if filename.endswith(".json"): + with open(os.path.join(output_dir, filename)) as f: + key = filename[:-5] + all_data[key] = json.load(f) + + with open(args.output_json, 'w') as f: + json.dump(all_data, f, indent=4) + + +if __name__ == "__main__": + main() diff --git a/test_benchmark/quantitative_evaluation/benchmark_dataset_generation/generate_temporal_qa.py b/test_benchmark/quantitative_evaluation/benchmark_dataset_generation/generate_temporal_qa.py new file mode 100644 index 0000000000000000000000000000000000000000..6e0f2cafc64fbf4be3cf1fba7dd851830eaa40c5 --- /dev/null +++ b/test_benchmark/quantitative_evaluation/benchmark_dataset_generation/generate_temporal_qa.py @@ -0,0 +1,139 @@ +import openai +import os +import argparse +import warnings +import json +import ast +from multiprocessing.pool import Pool + +warnings.filterwarnings('ignore') + + +def parse_args(): + parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3") + parser.add_argument("--gt_caption_folder", required=True, help="The path to captions") + parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.") + parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.") + parser.add_argument("--api_key", required=True, help="OpenAI API key.") + parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.") + args = parser.parse_args() + return args + + +def annotate(gt_file, caption_files, output_dir): + """ + Generate questions and answers for each caption file using GPT-3. + """ + for file in caption_files: + key = file[:-5] # Strip file extension. + caption = gt_file[key] + try: + # Generate GPT-3 response. + completion = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[ + { + "role": "system", + "content": + "You play two roles: a human asking questions related to a video and an intelligent chatbot designed to help people find information from a given video. " + "Your task is to generate a question-answer pair specifically related to temporal understanding from the video content. " + "Your task is to first play the role of a human who asks a question about the temporal sequence or timing of events in the video and then play the role of an AI assistant that provides information based on the video content." + "------" + "##TASK: " + "Users will provide some information about a video, and you will generate a conversation-like question and answers pair specifically focusing on the temporal sequence of events in the video. " + "The question should be designed to extract temporal sequence information directly from the given information, so that the provided information or parts of it can serve as the answer. " + "Generate ONE descriptive and conversational style question and detailed answer based on the given information, specifically related to the temporal understanding in the video." + "------" + "##INSTRUCTIONS:" + "- The question must be like a human conversation and directly related to the temporal sequence of events in the video. " + "- The question should be designed to extract temporal sequence information DIRECTLY from the given information, so that it or parts of it can serve as the answer. 
" + "- The answer must be detailed and descriptive, and should directly reference the information provided with respect to the temporal sequence of events in the video." + }, + { + "role": "user", + "content": + f"The user input is: {caption}. " + "Please generate the response in the form of a Python dictionary string with keys 'Q' for question and 'A' for answer. Each corresponding value should be the question and answer text respectively. " + "For example, your response should look like this: {'Q': 'Your question here...', 'A': 'Your answer here...'}. " + } + ] + ) + # Convert response to a Python dictionary. + response_message = completion["choices"][0]["message"]["content"] + response_dict = ast.literal_eval(response_message) + + # Save the question-answer pairs to a json file. + with open(f"{output_dir}/{key}.json", "w") as f: + json.dump(response_dict, f) + except Exception as e: + print(f"Error processing file '{key}': {e}") + + +def main(): + """ + Main function to control the flow of the program. + """ + # Parse arguments. + args = parse_args() + + # Read ground truth captions. + gt_captions = {} + gt_files = os.listdir(args.gt_caption_folder) + for file in gt_files: + with open(os.path.join(args.gt_caption_folder, file), mode='r', encoding='utf-8-sig') as f: + caption = f.read().replace('\n', '').replace('‘', "'").replace('’', "'") + video_id = file[:-4] + gt_captions[video_id] = caption + + caption_files = [f"{video_id}.json" for video_id in gt_captions.keys()] + output_dir = args.output_dir + # Generate output directory if not exists. + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # Set the OpenAI API key. + openai.api_key = args.api_key + num_tasks = args.num_tasks + + # While loop to ensure that all captions are processed. + while True: + try: + # Files that have not been processed yet. + completed_files = os.listdir(output_dir) + print(f"completed_files: {len(completed_files)}") + + # Files that have not been processed yet. + incomplete_files = [f for f in caption_files if f not in completed_files] + print(f"incomplete_files: {len(incomplete_files)}") + + if len(incomplete_files) == 0: + break + if len(incomplete_files) <= num_tasks: + num_tasks = 1 + + # Split tasks into parts. + part_len = len(incomplete_files) // num_tasks + all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)] + task_args = [(gt_captions, part, args.output_dir) for part in all_parts] + + # Use a pool of workers to process the files in parallel. 
+ with Pool() as pool: + pool.starmap(annotate, task_args) + + except Exception as e: + print(f"Error: {e}") + + # Combine qa pairs into single file when individual qa generation completes + all_data = {} + for filename in os.listdir(output_dir): + if filename.endswith(".json"): + with open(os.path.join(output_dir, filename)) as f: + key = filename[:-5] + all_data[key] = json.load(f) + + with open(args.output_json, 'w') as f: + json.dump(all_data, f, indent=4) + + +if __name__ == "__main__": + main() diff --git a/test_benchmark/quantitative_evaluation/evaluate_activitynet_qa.py b/test_benchmark/quantitative_evaluation/evaluate_activitynet_qa.py new file mode 100644 index 0000000000000000000000000000000000000000..581eb7b6069e636670265e8d6f9478fcb0eecdc2 --- /dev/null +++ b/test_benchmark/quantitative_evaluation/evaluate_activitynet_qa.py @@ -0,0 +1,207 @@ +import openai +import os +import argparse +import json +import ast +from multiprocessing.pool import Pool + + +def parse_args(): + parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3") + parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.") + parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.") + parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.") + parser.add_argument("--api_key", required=True, help="OpenAI API key.") + parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.") + args = parser.parse_args() + return args + + +def annotate(prediction_set, caption_files, output_dir): + """ + Evaluates question and answer pairs using GPT-3 + Returns a score for correctness. + """ + for file in caption_files: + key = file[:-5] # Strip file extension + qa_set = prediction_set[key] + question = qa_set['q'] + answer = qa_set['a'] + pred = qa_set['pred'] + try: + # Compute the correctness score + completion = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[ + { + "role": "system", + "content": + "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. " + "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:" + "------" + "##INSTRUCTIONS: " + "- Focus on the meaningful match between the predicted answer and the correct answer.\n" + "- Consider synonyms or paraphrases as valid matches.\n" + "- Evaluate the correctness of the prediction compared to the answer." + }, + { + "role": "user", + "content": + "Please evaluate the following video-based question-answer pair:\n\n" + f"Question: {question}\n" + f"Correct Answer: {answer}\n" + f"Predicted Answer: {pred}\n\n" + "Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. " + "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING." + "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " + "For example, your response should look like this: {'pred': 'yes', 'score': 4.8}." + } + ] + ) + # Convert response to a Python dictionary. 
+ response_message = completion["choices"][0]["message"]["content"] + response_dict = ast.literal_eval(response_message) + result_qa_pair = [response_dict, qa_set] + + # Save the question-answer pairs to a json file. + with open(f"{output_dir}/{key}.json", "w") as f: + json.dump(result_qa_pair, f) + + except Exception as e: + print(f"Error processing file '{key}': {e}") + + +def main(): + """ + Main function to control the flow of the program. + """ + # Parse arguments. + args = parse_args() + + file = open(args.pred_path) + pred_contents = json.load(file) + + # Dictionary to store the count of occurrences for each video_id + video_id_counts = {} + new_pred_contents = [] + + # Iterate through each sample in pred_contents + for sample in pred_contents: + video_id = sample['video_name'] + if video_id in video_id_counts: + video_id_counts[video_id] += 1 + else: + video_id_counts[video_id] = 0 + + # Create a new sample with the modified key + new_sample = sample + new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}" + new_pred_contents.append(new_sample) + + # Generating list of id's and corresponding files + id_list = [x['video_name'] for x in new_pred_contents] + caption_files = [f"{id}.json" for id in id_list] + + output_dir = args.output_dir + # Generate output directory if not exists. + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # Preparing dictionary of question-answer sets + prediction_set = {} + for sample in new_pred_contents: + id = sample['video_name'] + question = sample['Q'] + answer = sample['A'] + pred = sample['pred'] + qa_set = {"q": question, "a": answer, "pred": pred} + prediction_set[id] = qa_set + + # Set the OpenAI API key. + openai.api_key = args.api_key + num_tasks = args.num_tasks + + # While loop to ensure that all captions are processed. + while True: + try: + # Files that have not been processed yet. + completed_files = os.listdir(output_dir) + print(f"completed_files: {len(completed_files)}") + + # Files that have not been processed yet. + incomplete_files = [f for f in caption_files if f not in completed_files] + print(f"incomplete_files: {len(incomplete_files)}") + + # Break the loop when there are no incomplete files + if len(incomplete_files) == 0: + break + if len(incomplete_files) <= num_tasks: + num_tasks = 1 + + # Split tasks into parts. + part_len = len(incomplete_files) // num_tasks + all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)] + task_args = [(prediction_set, part, args.output_dir) for part in all_parts] + + # Use a pool of workers to process the files in parallel. 
+ with Pool() as pool: + pool.starmap(annotate, task_args) + + except Exception as e: + print(f"Error: {e}") + + # Combine all the processed files into one + combined_contents = {} + json_path = args.output_json + + # Iterate through json files + for file_name in os.listdir(output_dir): + if file_name.endswith(".json"): + file_path = os.path.join(output_dir, file_name) + with open(file_path, "r") as json_file: + content = json.load(json_file) + combined_contents[file_name[:-5]] = content + + # Write combined content to a json file + with open(json_path, "w") as json_file: + json.dump(combined_contents, json_file) + print("All evaluation completed!") + + # Calculate average score and accuracy + score_sum = 0 + count = 0 + yes_count = 0 + no_count = 0 + for key, result in combined_contents.items(): + # Computing score + count += 1 + try : + score_match = result[0]['score'] + score = int(score_match) + score_sum += score + except: + print("Score not found for", key) + continue + + # Computing accuracy + try: + pred = result[0]['pred'] + if "yes" in pred.lower(): + yes_count += 1 + elif "no" in pred.lower(): + no_count += 1 + except: + print("Prediction not found for", key) + continue + + average_score = score_sum / count + accuracy = yes_count / (yes_count + no_count) + print("Yes count:", yes_count) + print("No count:", no_count) + print("Accuracy:", accuracy) + print("Average score:", average_score) + + +if __name__ == "__main__": + main() + diff --git a/test_benchmark/quantitative_evaluation/evaluate_benchmark.sh b/test_benchmark/quantitative_evaluation/evaluate_benchmark.sh new file mode 100644 index 0000000000000000000000000000000000000000..b2b9194aa835aa069bbb9f38e7a0293ce2f759a8 --- /dev/null +++ b/test_benchmark/quantitative_evaluation/evaluate_benchmark.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +# Define common arguments for all scripts + +PRED="pred_path" +OUTPUT_DIR="output_dir" +API_KEY="api_key" +NUM_TASKS=128 + +# Run the "correctness" evaluation script +python evaluate_benchmark_1_correctness.py \ + --pred_path "${PRED_GENERIC}" \ + --output_dir "${OUTPUT_DIR}/correctness_eval" \ + --output_json "${OUTPUT_DIR}/correctness_results.json" \ + --api_key $API_KEY \ + --num_tasks $NUM_TASKS + +# Run the "detailed orientation" evaluation script +python evaluate_benchmark_2_detailed_orientation.py \ + --pred_path "${PRED_GENERIC}" \ + --output_dir "${OUTPUT_DIR}/detailed_eval" \ + --output_json "${OUTPUT_DIR}/detailed_orientation_results.json" \ + --api_key $API_KEY \ + --num_tasks $NUM_TASKS + +# Run the "contextual understanding" evaluation script +python evaluate_benchmark_3_context.py \ + --pred_path "${PRED_GENERIC}" \ + --output_dir "${OUTPUT_DIR}/context_eval" \ + --output_json "${OUTPUT_DIR}/contextual_understanding_results.json" \ + --api_key $API_KEY \ + --num_tasks $NUM_TASKS + +# Run the "temporal understanding" evaluation script +python evaluate_benchmark_4_temporal.py \ + --pred_path "${PRED_TEMPORAL}" \ + --output_dir "${OUTPUT_DIR}/temporal_eval" \ + --output_json "${OUTPUT_DIR}/temporal_understanding_results.json" \ + --api_key $API_KEY \ + --num_tasks $NUM_TASKS + +# Run the "consistency" evaluation script +python evaluate_benchmark_5_consistency.py \ + --pred_path "${PRED_CONSISTENCY}" \ + --output_dir "${OUTPUT_DIR}/consistency_eval" \ + --output_json "${OUTPUT_DIR}/consistency_results.json" \ + --api_key $API_KEY \ + --num_tasks $NUM_TASKS + + +echo "All evaluations completed!" 
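Note that evaluate_benchmark.sh defines only PRED, OUTPUT_DIR, API_KEY and NUM_TASKS, while the five invocations read ${PRED_GENERIC}, ${PRED_TEMPORAL} and ${PRED_CONSISTENCY}, so those three variables have to be set before the script is run. For reference, a rough Python equivalent of the same driver; the prediction paths below are placeholders, not paths shipped with the repo:

    import subprocess

    PRED_GENERIC = "results/generic_qa_pred.json"          # placeholder
    PRED_TEMPORAL = "results/temporal_qa_pred.json"        # placeholder
    PRED_CONSISTENCY = "results/consistency_qa_pred.json"  # placeholder
    OUTPUT_DIR = "results/benchmark_eval"
    API_KEY = "sk-..."                                      # placeholder
    NUM_TASKS = "128"

    jobs = [
        ("evaluate_benchmark_1_correctness.py", PRED_GENERIC,
         "correctness_eval", "correctness_results.json"),
        ("evaluate_benchmark_2_detailed_orientation.py", PRED_GENERIC,
         "detailed_eval", "detailed_orientation_results.json"),
        ("evaluate_benchmark_3_context.py", PRED_GENERIC,
         "context_eval", "contextual_understanding_results.json"),
        ("evaluate_benchmark_4_temporal.py", PRED_TEMPORAL,
         "temporal_eval", "temporal_understanding_results.json"),
        ("evaluate_benchmark_5_consistency.py", PRED_CONSISTENCY,
         "consistency_eval", "consistency_results.json"),
    ]
    for script, pred, out_dir, out_json in jobs:
        subprocess.run([
            "python", script,
            "--pred_path", pred,
            "--output_dir", f"{OUTPUT_DIR}/{out_dir}",
            "--output_json", f"{OUTPUT_DIR}/{out_json}",
            "--api_key", API_KEY,
            "--num_tasks", NUM_TASKS,
        ], check=True)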
diff --git a/test_benchmark/quantitative_evaluation/evaluate_benchmark_1_correctness.py b/test_benchmark/quantitative_evaluation/evaluate_benchmark_1_correctness.py new file mode 100644 index 0000000000000000000000000000000000000000..6ebae9013b6102ec8b9c71495d0b19e2a3ac5ce7 --- /dev/null +++ b/test_benchmark/quantitative_evaluation/evaluate_benchmark_1_correctness.py @@ -0,0 +1,186 @@ +import openai +import os +import argparse +import json +import ast +from multiprocessing.pool import Pool + + +def parse_args(): + parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3") + parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.") + parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.") + parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.") + parser.add_argument("--api_key", required=True, help="OpenAI API key.") + parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.") + args = parser.parse_args() + return args + + +def annotate(prediction_set, caption_files, output_dir): + """ + Evaluates question and answer pairs using GPT-3 + Returns a score for correctness. + """ + for file in caption_files: + key = file[:-5] # Strip file extension + qa_set = prediction_set[key] + question = qa_set['q'] + answer = qa_set['a'] + pred = qa_set['pred'] + try: + # Compute the correctness score + completion = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[ + { + "role": "system", + "content": + "You are an intelligent chatbot designed for evaluating the factual accuracy of generative outputs for video-based question-answer pairs. " + "Your task is to compare the predicted answer with the correct answer and determine if they are factually consistent. Here's how you can accomplish the task:" + "------" + "##INSTRUCTIONS: " + "- Focus on the factual consistency between the predicted answer and the correct answer. The predicted answer should not contain any misinterpretations or misinformation.\n" + "- The predicted answer must be factually accurate and align with the video content.\n" + "- Consider synonyms or paraphrases as valid matches.\n" + "- Evaluate the factual accuracy of the prediction compared to the answer." + }, + { + "role": "user", + "content": + "Please evaluate the following video-based question-answer pair:\n\n" + f"Question: {question}\n" + f"Correct Answer: {answer}\n" + f"Predicted Answer: {pred}\n\n" + "Provide your evaluation only as a factual accuracy score where the factual accuracy score is an integer value between 0 and 5, with 5 indicating the highest level of factual consistency. " + "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the factual accuracy score in INTEGER, not STRING." + "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " + "For example, your response should look like this: {''score': 4.8}." + } + ] + ) + # Convert response to a Python dictionary. + response_message = completion["choices"][0]["message"]["content"] + response_dict = ast.literal_eval(response_message) + result_qa_pair = [response_dict, qa_set] + + # Save the question-answer pairs to a json file. 
+ with open(f"{output_dir}/{key}.json", "w") as f: + json.dump(result_qa_pair, f) + + except Exception as e: + print(f"Error processing file '{key}': {e}") + + +def main(): + """ + Main function to control the flow of the program. + """ + # Parse arguments. + args = parse_args() + + file = open(args.pred_path) + pred_contents = json.load(file) + + # Dictionary to store the count of occurrences for each video_id + video_id_counts = {} + new_pred_contents = [] + + # Iterate through each sample in pred_contents + for sample in pred_contents: + video_id = sample['video_name'] + if video_id in video_id_counts: + video_id_counts[video_id] += 1 + else: + video_id_counts[video_id] = 0 + + # Create a new sample with the modified key + new_sample = sample + new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}" + new_pred_contents.append(new_sample) + + # Generating list of id's and corresponding files + id_list = [x['video_name'] for x in new_pred_contents] + caption_files = [f"{id}.json" for id in id_list] + + output_dir = args.output_dir + # Generate output directory if not exists. + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # Preparing dictionary of question-answer sets + prediction_set = {} + for sample in new_pred_contents: + id = sample['video_name'] + question = sample['Q'] + answer = sample['A'] + pred = sample['pred'] + qa_set = {"q": question, "a": answer, "pred": pred} + prediction_set[id] = qa_set + + # Set the OpenAI API key. + openai.api_key = args.api_key + num_tasks = args.num_tasks + + # While loop to ensure that all captions are processed. + while True: + try: + # Files that have not been processed yet. + completed_files = os.listdir(output_dir) + print(f"completed_files: {len(completed_files)}") + + # Files that have not been processed yet. + incomplete_files = [f for f in caption_files if f not in completed_files] + print(f"incomplete_files: {len(incomplete_files)}") + + # Break the loop when there are no incomplete files + if len(incomplete_files) == 0: + break + if len(incomplete_files) <= num_tasks: + num_tasks = 1 + + # Split tasks into parts. + part_len = len(incomplete_files) // num_tasks + all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)] + task_args = [(prediction_set, part, args.output_dir) for part in all_parts] + + # Use a pool of workers to process the files in parallel. 
+ with Pool() as pool: + pool.starmap(annotate, task_args) + + except Exception as e: + print(f"Error: {e}") + + # Combine all the processed files into one + combined_contents = {} + json_path = args.output_json + + # Iterate through json files + for file_name in os.listdir(output_dir): + if file_name.endswith(".json"): + file_path = os.path.join(output_dir, file_name) + with open(file_path, "r") as json_file: + content = json.load(json_file) + combined_contents[file_name[:-5]] = content + + # Write combined content to a json file + with open(json_path, "w") as json_file: + json.dump(combined_contents, json_file) + print("All evaluation completed!") + + # Calculate average score + score_sum = 0 + count = 0 + for key, result in combined_contents.items(): + count += 1 + score_match = result[0]['score'] + score = int(score_match) + score_sum += score + average_score = score_sum / count + + print("Average score for correctness:", average_score) + + +if __name__ == "__main__": + main() + diff --git a/test_benchmark/quantitative_evaluation/evaluate_benchmark_2_detailed_orientation.py b/test_benchmark/quantitative_evaluation/evaluate_benchmark_2_detailed_orientation.py new file mode 100644 index 0000000000000000000000000000000000000000..634bda06ece01ad2914012d8cebe857b5e79ced2 --- /dev/null +++ b/test_benchmark/quantitative_evaluation/evaluate_benchmark_2_detailed_orientation.py @@ -0,0 +1,186 @@ +import openai +import os +import argparse +import json +import ast +from multiprocessing.pool import Pool + + +def parse_args(): + parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3") + parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.") + parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.") + parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.") + parser.add_argument("--api_key", required=True, help="OpenAI API key.") + parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.") + args = parser.parse_args() + return args + + +def annotate(prediction_set, caption_files, output_dir): + """ + Evaluates question and answer pairs using GPT-3 and + returns a score for detailed orientation. + """ + for file in caption_files: + key = file[:-5] # Strip file extension + qa_set = prediction_set[key] + question = qa_set['q'] + answer = qa_set['a'] + pred = qa_set['pred'] + try: + # Compute the detailed-orientation score + completion = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[ + { + "role": "system", + "content": + "You are an intelligent chatbot designed for evaluating the detail orientation of generative outputs for video-based question-answer pairs. " + "Your task is to compare the predicted answer with the correct answer and determine its level of detail, considering both completeness and specificity. Here's how you can accomplish the task:" + "------" + "##INSTRUCTIONS: " + "- Check if the predicted answer covers all major points from the video. The response should not leave out any key aspects.\n" + "- Evaluate whether the predicted answer includes specific details rather than just generic points. 
It should provide comprehensive information that is tied to specific elements of the video.\n" + "- Consider synonyms or paraphrases as valid matches.\n" + "- Provide a single evaluation score that reflects the level of detail orientation of the prediction, considering both completeness and specificity." + }, + { + "role": "user", + "content": + "Please evaluate the following video-based question-answer pair:\n\n" + f"Question: {question}\n" + f"Correct Answer: {answer}\n" + f"Predicted Answer: {pred}\n\n" + "Provide your evaluation only as a detail orientation score where the detail orientation score is an integer value between 0 and 5, with 5 indicating the highest level of detail orientation. " + "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the detail orientation score in INTEGER, not STRING." + "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " + "For example, your response should look like this: {''score': 4.8}." + } + ] + ) + # Convert response to a Python dictionary. + response_message = completion["choices"][0]["message"]["content"] + response_dict = ast.literal_eval(response_message) + result_qa_pair = [response_dict, qa_set] + + # Save the question-answer pairs to a json file. + with open(f"{output_dir}/{key}.json", "w") as f: + json.dump(result_qa_pair, f) + + except Exception as e: + print(f"Error processing file '{key}': {e}") + + +def main(): + """ + Main function to control the flow of the program. + """ + # Parse arguments. + args = parse_args() + + file = open(args.pred_path) + pred_contents = json.load(file) + + # Dictionary to store the count of occurrences for each video_id + video_id_counts = {} + new_pred_contents = [] + + # Iterate through each sample in pred_contents + for sample in pred_contents: + video_id = sample['video_name'] + if video_id in video_id_counts: + video_id_counts[video_id] += 1 + else: + video_id_counts[video_id] = 0 + + # Create a new sample with the modified key + new_sample = sample + new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}" + new_pred_contents.append(new_sample) + + # Generating list of id's and corresponding files + id_list = [x['video_name'] for x in new_pred_contents] + caption_files = [f"{id}.json" for id in id_list] + + output_dir = args.output_dir + # Generate output directory if not exists. + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # Preparing dictionary of question-answer sets + prediction_set = {} + for sample in new_pred_contents: + id = sample['video_name'] + question = sample['Q'] + answer = sample['A'] + pred = sample['pred'] + qa_set = {"q": question, "a": answer, "pred": pred} + prediction_set[id] = qa_set + + # Set the OpenAI API key. + openai.api_key = args.api_key + num_tasks = args.num_tasks + + # While loop to ensure that all captions are processed. + while True: + try: + # Files that have not been processed yet. + completed_files = os.listdir(output_dir) + print(f"completed_files: {len(completed_files)}") + + # Files that have not been processed yet. + incomplete_files = [f for f in caption_files if f not in completed_files] + print(f"incomplete_files: {len(incomplete_files)}") + + # Break the loop when there are no incomplete files + if len(incomplete_files) == 0: + break + if len(incomplete_files) <= num_tasks: + num_tasks = 1 + + # Split tasks into parts. 
+ part_len = len(incomplete_files) // num_tasks + all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)] + task_args = [(prediction_set, part, args.output_dir) for part in all_parts] + + # Use a pool of workers to process the files in parallel. + with Pool() as pool: + pool.starmap(annotate, task_args) + + except Exception as e: + print(f"Error: {e}") + + # Combine all the processed files into one + combined_contents = {} + json_path = args.output_json + + # Iterate through json files + for file_name in os.listdir(output_dir): + if file_name.endswith(".json"): + file_path = os.path.join(output_dir, file_name) + with open(file_path, "r") as json_file: + content = json.load(json_file) + combined_contents[file_name[:-5]] = content + + # Write combined content to a json file + with open(json_path, "w") as json_file: + json.dump(combined_contents, json_file) + print("All evaluation completed!") + + # Calculate average score + score_sum = 0 + count = 0 + for key, result in combined_contents.items(): + count += 1 + score_match = result[0]['score'] + score = int(score_match) + score_sum += score + average_score = score_sum / count + + print("Average score for detailed orientation:", average_score) + + +if __name__ == "__main__": + main() + diff --git a/test_benchmark/quantitative_evaluation/evaluate_benchmark_3_context.py b/test_benchmark/quantitative_evaluation/evaluate_benchmark_3_context.py new file mode 100644 index 0000000000000000000000000000000000000000..0058f75b51c41af838194603b8c24628671ca286 --- /dev/null +++ b/test_benchmark/quantitative_evaluation/evaluate_benchmark_3_context.py @@ -0,0 +1,186 @@ +import openai +import os +import argparse +import json +import ast +from multiprocessing.pool import Pool + + +def parse_args(): + parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3") + parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.") + parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.") + parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.") + parser.add_argument("--api_key", required=True, help="OpenAI API key.") + parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.") + args = parser.parse_args() + return args + + +def annotate(prediction_set, caption_files, output_dir): + """ + Evaluates question and answer pairs using GPT-3 and + returns a score for contextual understanding. + """ + for file in caption_files: + key = file[:-5] # Strip file extension + qa_set = prediction_set[key] + question = qa_set['q'] + answer = qa_set['a'] + pred = qa_set['pred'] + try: + # Compute the contextual understanding score + completion = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[ + { + "role": "system", + "content": + "You are an intelligent chatbot designed for evaluating the contextual understanding of generative outputs for video-based question-answer pairs. " + "Your task is to compare the predicted answer with the correct answer and determine if the generated response aligns with the overall context of the video content. Here's how you can accomplish the task:" + "------" + "##INSTRUCTIONS: " + "- Evaluate whether the predicted answer aligns with the overall context of the video content. 
It should not provide information that is out of context or misaligned.\n" + "- The predicted answer must capture the main themes and sentiments of the video.\n" + "- Consider synonyms or paraphrases as valid matches.\n" + "- Provide your evaluation of the contextual understanding of the prediction compared to the answer." + }, + { + "role": "user", + "content": + "Please evaluate the following video-based question-answer pair:\n\n" + f"Question: {question}\n" + f"Correct Answer: {answer}\n" + f"Predicted Answer: {pred}\n\n" + "Provide your evaluation only as a contextual understanding score where the contextual understanding score is an integer value between 0 and 5, with 5 indicating the highest level of contextual understanding. " + "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is contextual understanding score in INTEGER, not STRING." + "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " + "For example, your response should look like this: {''score': 4.8}." + } + ] + ) + # Convert response to a Python dictionary. + response_message = completion["choices"][0]["message"]["content"] + response_dict = ast.literal_eval(response_message) + result_qa_pair = [response_dict, qa_set] + + # Save the question-answer pairs to a json file. + with open(f"{output_dir}/{key}.json", "w") as f: + json.dump(result_qa_pair, f) + + except Exception as e: + print(f"Error processing file '{key}': {e}") + + +def main(): + """ + Main function to control the flow of the program. + """ + # Parse arguments. + args = parse_args() + + file = open(args.pred_path) + pred_contents = json.load(file) + + # Dictionary to store the count of occurrences for each video_id + video_id_counts = {} + new_pred_contents = [] + + # Iterate through each sample in pred_contents + for sample in pred_contents: + video_id = sample['video_name'] + if video_id in video_id_counts: + video_id_counts[video_id] += 1 + else: + video_id_counts[video_id] = 0 + + # Create a new sample with the modified key + new_sample = sample + new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}" + new_pred_contents.append(new_sample) + + # Generating list of id's and corresponding files + id_list = [x['video_name'] for x in new_pred_contents] + caption_files = [f"{id}.json" for id in id_list] + + output_dir = args.output_dir + # Generate output directory if not exists. + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # Preparing dictionary of question-answer sets + prediction_set = {} + for sample in new_pred_contents: + id = sample['video_name'] + question = sample['Q'] + answer = sample['A'] + pred = sample['pred'] + qa_set = {"q": question, "a": answer, "pred": pred} + prediction_set[id] = qa_set + + # Set the OpenAI API key. + openai.api_key = args.api_key + num_tasks = args.num_tasks + + # While loop to ensure that all captions are processed. + while True: + try: + # Files that have not been processed yet. + completed_files = os.listdir(output_dir) + print(f"completed_files: {len(completed_files)}") + + # Files that have not been processed yet. + incomplete_files = [f for f in caption_files if f not in completed_files] + print(f"incomplete_files: {len(incomplete_files)}") + + # Break the loop when there are no incomplete files + if len(incomplete_files) == 0: + break + if len(incomplete_files) <= num_tasks: + num_tasks = 1 + + # Split tasks into parts. 
+ part_len = len(incomplete_files) // num_tasks + all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)] + task_args = [(prediction_set, part, args.output_dir) for part in all_parts] + + # Use a pool of workers to process the files in parallel. + with Pool() as pool: + pool.starmap(annotate, task_args) + + except Exception as e: + print(f"Error: {e}") + + # Combine all the processed files into one + combined_contents = {} + json_path = args.output_json + + # Iterate through json files + for file_name in os.listdir(output_dir): + if file_name.endswith(".json"): + file_path = os.path.join(output_dir, file_name) + with open(file_path, "r") as json_file: + content = json.load(json_file) + combined_contents[file_name[:-5]] = content + + # Write combined content to a json file + with open(json_path, "w") as json_file: + json.dump(combined_contents, json_file) + print("All evaluation completed!") + + # Calculate average score + score_sum = 0 + count = 0 + for key, result in combined_contents.items(): + count += 1 + score_match = result[0]['score'] + score = int(score_match) + score_sum += score + average_score = score_sum / count + + print("Average score for contextual understanding:", average_score) + + +if __name__ == "__main__": + main() + diff --git a/test_benchmark/quantitative_evaluation/evaluate_benchmark_4_temporal.py b/test_benchmark/quantitative_evaluation/evaluate_benchmark_4_temporal.py new file mode 100644 index 0000000000000000000000000000000000000000..33e8db079e3317da705be91e72d340d33281d65e --- /dev/null +++ b/test_benchmark/quantitative_evaluation/evaluate_benchmark_4_temporal.py @@ -0,0 +1,185 @@ +import openai +import os +import argparse +import json +import ast +from multiprocessing.pool import Pool + + +def parse_args(): + parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3") + parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.") + parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.") + parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.") + parser.add_argument("--api_key", required=True, help="OpenAI API key.") + parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.") + args = parser.parse_args() + return args + + +def annotate(prediction_set, caption_files, output_dir): + """ + Evaluates question and answer pairs using GPT-3 and + returns a score for temporal understanding. + """ + for file in caption_files: + key = file[:-5] # Strip file extension + qa_set = prediction_set[key] + question = qa_set['q'] + answer = qa_set['a'] + pred = qa_set['pred'] + try: + # Compute the temporal understanding score + completion = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[ + { + "role": "system", + "content": + "You are an intelligent chatbot designed for evaluating the temporal understanding of generative outputs for video-based question-answer pairs. " + "Your task is to compare the predicted answer with the correct answer and determine if they correctly reflect the temporal sequence of events in the video content. Here's how you can accomplish the task:" + "------" + "##INSTRUCTIONS: " + "- Focus on the temporal consistency between the predicted answer and the correct answer. 
The predicted answer should correctly reflect the sequence of events or details as they are presented in the video content.\n" + "- Consider synonyms or paraphrases as valid matches, but only if the temporal order is maintained.\n" + "- Evaluate the temporal accuracy of the prediction compared to the answer." + }, + { + "role": "user", + "content": + "Please evaluate the following video-based question-answer pair:\n\n" + f"Question: {question}\n" + f"Correct Answer: {answer}\n" + f"Predicted Answer: {pred}\n\n" + "Provide your evaluation only as a temporal accuracy score where the temporal accuracy score is an integer value between 0 and 5, with 5 indicating the highest level of temporal consistency. " + "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the temporal accuracy score in INTEGER, not STRING." + "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " + "For example, your response should look like this: {''score': 4.8}." + } + ] + ) + # Convert response to a Python dictionary. + response_message = completion["choices"][0]["message"]["content"] + response_dict = ast.literal_eval(response_message) + result_qa_pair = [response_dict, qa_set] + + # Save the question-answer pairs to a json file. + with open(f"{output_dir}/{key}.json", "w") as f: + json.dump(result_qa_pair, f) + + except Exception as e: + print(f"Error processing file '{key}': {e}") + + +def main(): + """ + Main function to control the flow of the program. + """ + # Parse arguments. + args = parse_args() + + file = open(args.pred_path) + pred_contents = json.load(file) + + # Dictionary to store the count of occurrences for each video_id + video_id_counts = {} + new_pred_contents = [] + + # Iterate through each sample in pred_contents + for sample in pred_contents: + video_id = sample['video_name'] + if video_id in video_id_counts: + video_id_counts[video_id] += 1 + else: + video_id_counts[video_id] = 0 + + # Create a new sample with the modified key + new_sample = sample + new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}" + new_pred_contents.append(new_sample) + + # Generating list of id's and corresponding files + id_list = [x['video_name'] for x in new_pred_contents] + caption_files = [f"{id}.json" for id in id_list] + + output_dir = args.output_dir + # Generate output directory if not exists. + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # Preparing dictionary of question-answer sets + prediction_set = {} + for sample in new_pred_contents: + id = sample['video_name'] + question = sample['Q'] + answer = sample['A'] + pred = sample['pred'] + qa_set = {"q": question, "a": answer, "pred": pred} + prediction_set[id] = qa_set + + # Set the OpenAI API key. + openai.api_key = args.api_key + num_tasks = args.num_tasks + + # While loop to ensure that all captions are processed. + while True: + try: + # Files that have not been processed yet. + completed_files = os.listdir(output_dir) + print(f"completed_files: {len(completed_files)}") + + # Files that have not been processed yet. + incomplete_files = [f for f in caption_files if f not in completed_files] + print(f"incomplete_files: {len(incomplete_files)}") + + # Break the loop when there are no incomplete files + if len(incomplete_files) == 0: + break + if len(incomplete_files) <= num_tasks: + num_tasks = 1 + + # Split tasks into parts. 
+ part_len = len(incomplete_files) // num_tasks + all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)] + task_args = [(prediction_set, part, args.output_dir) for part in all_parts] + + # Use a pool of workers to process the files in parallel. + with Pool() as pool: + pool.starmap(annotate, task_args) + + except Exception as e: + print(f"Error: {e}") + + # Combine all the processed files into one + combined_contents = {} + json_path = args.output_json + + # Iterate through json files + for file_name in os.listdir(output_dir): + if file_name.endswith(".json"): + file_path = os.path.join(output_dir, file_name) + with open(file_path, "r") as json_file: + content = json.load(json_file) + combined_contents[file_name[:-5]] = content + + # Write combined content to a json file + with open(json_path, "w") as json_file: + json.dump(combined_contents, json_file) + print("All evaluation completed!") + + # Calculate average score + score_sum = 0 + count = 0 + for key, result in combined_contents.items(): + count += 1 + score_match = result[0]['score'] + score = int(score_match) + score_sum += score + average_score = score_sum / count + + print("Average score temporal understanding:", average_score) + + +if __name__ == "__main__": + main() + diff --git a/test_benchmark/quantitative_evaluation/evaluate_benchmark_5_consistency.py b/test_benchmark/quantitative_evaluation/evaluate_benchmark_5_consistency.py new file mode 100644 index 0000000000000000000000000000000000000000..3352c4258203efb693c23253fedb8d5c324b1495 --- /dev/null +++ b/test_benchmark/quantitative_evaluation/evaluate_benchmark_5_consistency.py @@ -0,0 +1,193 @@ +import openai +import os +import argparse +import json +import ast +from multiprocessing.pool import Pool + + +def parse_args(): + parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3") + parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.") + parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.") + parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.") + parser.add_argument("--api_key", required=True, help="OpenAI API key.") + parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.") + args = parser.parse_args() + return args + + +def annotate(prediction_set, caption_files, output_dir): + """ + Evaluates question and answer pairs using GPT-3 and + returns a score for consistency. + """ + for file in caption_files: + key = file[:-5] # Strip file extension + qa_set = prediction_set[key] + question1 = qa_set['q1'] + question2 = qa_set['q2'] + answer = qa_set['a'] + pred1 = qa_set['pred1'] + pred2 = qa_set['pred2'] + try: + # Compute the consistency score + completion = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[ + { + "role": "system", + "content": + "You are an intelligent chatbot designed for evaluating the consistency of generative outputs for similar video-based question-answer pairs. " + "You will be given two very similar questions, a common answer common to both the questions and predicted answers for the two questions ." + "Your task is to compare the predicted answers for two very similar question, with a common correct answer and determine if they are consistent. 
Here's how you can accomplish the task:" + "------" + "##INSTRUCTIONS: " + "- Focus on the consistency between the two predicted answers and the correct answer. Both predicted answers should correspond to the correct answer and to each other, and should not contain any contradictions or significant differences in the conveyed information.\n" + "- Both predicted answers must be consistent with each other and the correct answer, in terms of the information they provide about the video content.\n" + "- Consider synonyms or paraphrases as valid matches, but only if they maintain the consistency in the conveyed information.\n" + "- Evaluate the consistency of the two predicted answers compared to the correct answer." + }, + { + "role": "user", + "content": + "Please evaluate the following video-based question-answer pairs:\n\n" + f"Question 1: {question1}\n" + f"Question 2: {question2}\n" + f"Correct Answer: {answer}\n" + f"Predicted Answer to Question 1: {pred1}\n" + f"Predicted Answer to Question 2: {pred2}\n\n" + "Provide your evaluation only as a consistency score where the consistency score is an integer value between 0 and 5, with 5 indicating the highest level of consistency. " + "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the consistency score in INTEGER, not STRING." + "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " + "For example, your response should look like this: {'score': 4}." + } + ] + ) + # Convert response to a Python dictionary. + response_message = completion["choices"][0]["message"]["content"] + response_dict = ast.literal_eval(response_message) + result_qa_pair = [response_dict, qa_set] + + # Save the question-answer pairs to a json file. + with open(f"{output_dir}/{key}.json", "w") as f: + json.dump(result_qa_pair, f) + + except Exception as e: + print(f"Error processing file '{key}': {e}") + + +def main(): + """ + Main function to control the flow of the program. + """ + # Parse arguments. + args = parse_args() + + file = open(args.pred_path) + pred_contents = json.load(file) + + # Dictionary to store the count of occurrences for each video_id + video_id_counts = {} + new_pred_contents = [] + + # Iterate through each sample in pred_contents + for sample in pred_contents: + video_id = sample['video_name'] + if video_id in video_id_counts: + video_id_counts[video_id] += 1 + else: + video_id_counts[video_id] = 0 + + # Create a new sample with the modified key + new_sample = sample + new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}" + new_pred_contents.append(new_sample) + + # Generating list of id's and corresponding files + id_list = [x['video_name'] for x in new_pred_contents] + caption_files = [f"{id}.json" for id in id_list] + + output_dir = args.output_dir + # Generate output directory if not exists. + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # Preparing dictionary of question-answer sets + prediction_set = {} + for sample in new_pred_contents: + id = sample['video_name'] + question1 = sample['Q1'] + question2 = sample['Q2'] + answer = sample['A'] + pred1 = sample['pred1'] + pred2 = sample['pred2'] + qa_set = {"q1": question1, "q2": question2, "a": answer, "pred1": pred1, "pred2": pred2} + prediction_set[id] = qa_set + + # Set the OpenAI API key. + openai.api_key = args.api_key + num_tasks = args.num_tasks + + # While loop to ensure that all captions are processed.
+ while True: + try: + # Files that have not been processed yet. + completed_files = os.listdir(output_dir) + print(f"completed_files: {len(completed_files)}") + + # Files that have not been processed yet. + incomplete_files = [f for f in caption_files if f not in completed_files] + print(f"incomplete_files: {len(incomplete_files)}") + + # Break the loop when there are no incomplete files + if len(incomplete_files) == 0: + break + if len(incomplete_files) <= num_tasks: + num_tasks = 1 + + # Split tasks into parts. + part_len = len(incomplete_files) // num_tasks + all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)] + task_args = [(prediction_set, part, args.output_dir) for part in all_parts] + + # Use a pool of workers to process the files in parallel. + with Pool() as pool: + pool.starmap(annotate, task_args) + + except Exception as e: + print(f"Error: {e}") + + # Combine all the processed files into one + combined_contents = {} + json_path = args.output_json + + # Iterate through json files + for file_name in os.listdir(output_dir): + if file_name.endswith(".json"): + file_path = os.path.join(output_dir, file_name) + with open(file_path, "r") as json_file: + content = json.load(json_file) + combined_contents[file_name[:-5]] = content + + # Write combined content to a json file + with open(json_path, "w") as json_file: + json.dump(combined_contents, json_file) + print("All evaluation completed!") + + # Calculate average score + score_sum = 0 + count = 0 + for key, result in combined_contents.items(): + count += 1 + score_match = result[0]['score'] + score = int(score_match) + score_sum += score + average_score = score_sum / count + + print("Average score for consistency:", average_score) + + +if __name__ == "__main__": + main() + diff --git a/test_benchmark/quantitative_evaluation/evaluate_zeroshot.sh b/test_benchmark/quantitative_evaluation/evaluate_zeroshot.sh new file mode 100644 index 0000000000000000000000000000000000000000..d3a1cd6e92e825d8789cbfd9f8c43d093cd1cb26 --- /dev/null +++ b/test_benchmark/quantitative_evaluation/evaluate_zeroshot.sh @@ -0,0 +1,25 @@ +#!/bin/bash +#SBATCH --partition=batch +#SBATCH --job-name=zeroshot_eval%j +#SBATCH --output=zeroshot_eval%j.out +#SBATCH --error=zeroshot_eval%j.err +#SBATCH --time=0-10:00:00 +#SBATCH --mem=64G +#SBATCH --nodes=1 + +## run the application: + +# PRED="pred_path" +# OUTPUT_DIR="output_dir" +# API_KEY="api_key" +# NUM_TASKS=128 + + +python evaluate_activitynet_qa.py \ + --pred_path ${PRED} \ + --output_dir "${OUTPUT_DIR}/fewshot_accuracy" \ + --output_json "${OUTPUT_DIR}/fewshot_accuracy_results.json"\ + --api_key $API_KEY \ + --num_tasks $NUM_TASKS + +echo pred_path: $PRED \ No newline at end of file diff --git a/test_configs/llama2_test_config.yaml b/test_configs/llama2_test_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..30d2292d7bfc14f89f968e1cde299e595d0ac2d4 --- /dev/null +++ b/test_configs/llama2_test_config.yaml @@ -0,0 +1,34 @@ +model: + arch: mini_gpt4_llama_v2 + model_type: pretrain_vicuna + freeze_vit: True + freeze_qformer: True + max_txt_len: 256 + low_resource: True + image_size: 224 + end_sym: "
" + llama_model: "meta-llama/Llama-2-7b-chat-hf" + ckpt: "checkpoints/video_llama_checkpoint_last.pth" + use_grad_checkpoint: True + chat_template: True + lora_r: 64 + lora_alpha: 16 + length: 50 + use_grad_checkpoint_llm: True + max_context_len: 3600 + + +datasets: + video_chatgpt: #99378 row - 13224 video + batch_size: 4 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + sample_ratio: 200 +run: + seed: 42 + amp: True diff --git a/test_configs/mistral_test_config.yaml b/test_configs/mistral_test_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..93d4a99dbb71396410b9f1877bb4a2a71467afc9 --- /dev/null +++ b/test_configs/mistral_test_config.yaml @@ -0,0 +1,37 @@ +model: + arch: mini_gpt4_llama_v2 + model_type: pretrain_vicuna + freeze_vit: True + freeze_qformer: True + max_txt_len: 512 + low_resource: True + image_size: 224 + end_sym: "" + llama_model: "mistralai/Mistral-7B-Instruct-v0.2" + ckpt: "checkpoints/video_mistral_all_checkpoint_last.pth" + use_grad_checkpoint: True + chat_template: True + lora_r: 64 + lora_alpha: 16 + length: 50 + use_grad_checkpoint_llm: True + max_context_len: 7200 + + +datasets: + video_chatgpt: #99378 row - 13224 video + batch_size: 1 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + sample_ratio: 200 + + +run: + task: image_text_pretrain + seed: 42 + amp: True \ No newline at end of file diff --git a/theme.py b/theme.py new file mode 100644 index 0000000000000000000000000000000000000000..cc2ea66e40ff29c32b020001e60df2530be9032a --- /dev/null +++ b/theme.py @@ -0,0 +1,104 @@ +import gradio as gr +# https://www.gradio.app/docs/themes + +minigptlv_style = gr.themes.Soft( + primary_hue=gr.themes.Color( + c50="#ff339c", + c100="#791aff", + c200="#ff339c", + c300="#ff339c", + c400="#ff339c", + c500="3384FF", + c600="#ff339c", + c700="#ff339c", + c800="#ff339c", + c900="#ff339c", + c950="#ff339c", + name="lighter_blue", + ), + secondary_hue=gr.themes.Color( + c50="#ff339c", + c100="#ff339c", + c200="#ff339c", + c300="#ff339c", + c400="#ff339c", + c500="#ff339c", + c600="#ff339c", + c700="#ff339c", + c800="#ff339c", + c900="#ff339c", + c950="#ff339c", + ), + neutral_hue=gr.themes.Color( + c50="#ff339c", + c100="#FFFFFF", + c200="#3384FF", + c300="#ff339c", + c400="#FFFFFF", + c500="#FFFFFF", + c600="#ff339c", + c700="#192423", + c800="#cccdde", + c900="#ff339c", + c950="#ff339c", + name="dark_scale", + ), + radius_size=gr.themes.sizes.radius_sm, +).set( + button_primary_text_color="#ff339c", + button_primary_background_fill="#ff339c", + button_primary_background_fill_dark="#FFFFFF", + button_primary_border_color_dark="#FFFFFF", + button_primary_text_color_dark="#000000", + button_secondary_background_fill="#ff339c", + button_secondary_background_fill_hover="#40c928", + button_secondary_background_fill_dark="#ff339c", + button_secondary_background_fill_hover_dark="#40c928", + button_secondary_text_color="white", + button_secondary_text_color_dark="#white", + block_title_background_fill_dark="#1a94ff", + block_label_background_fill_dark="#1a94ff", + input_background_fill="#999999", + background_fill_primary="#1e1d1f", + background_fill_primary_dark="#1e1d1f", +) + +# Define custom CSS +custom_css = """ + /* Custom CSS for Gradio interface */ + .input-box { + font-family: Arial, sans-serif; + background-color: #F0F0F0; + border: 1px solid #CCCCCC; + } + + .output-box { + font-family: Arial, 
sans-serif; + background-color: #FFFFFF; + border: 1px solid #CCCCCC; + } + + .checkbox { + color: #464646; + } + + .textbox { + width: 100%; + } + + .output-image { + border: 1px solid #CCCCCC; + } + """ + +text_css = """ +h1 { + text-align: center; + display:block; + font-size: 45px; +} +h5 { + text-align: center; + display:block; +} +""" \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..39053ae90ce8bae2c5d49553819da76bdab7786f --- /dev/null +++ b/train.py @@ -0,0 +1,128 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import argparse +import os +import random + +import numpy as np +import torch +import torch.backends.cudnn as cudnn + +import minigpt4.tasks as tasks +from minigpt4.common.config import Config +from minigpt4.common.dist_utils import get_rank, init_distributed_mode +from minigpt4.common.logger import setup_logger +from minigpt4.common.optims import ( + LinearWarmupCosineLRScheduler, + LinearWarmupStepLRScheduler, +) +from minigpt4.common.registry import registry +from minigpt4.common.utils import now + +# imports modules for registration +from minigpt4.datasets.builders import * +from minigpt4.models import * +from minigpt4.processors import * +from minigpt4.runners import * +from minigpt4.tasks import * +import wandb + + +def parse_args(): + parser = argparse.ArgumentParser(description="Training") + + parser.add_argument("--cfg-path",default="train_configs_llama2/224_v2_llama2_video.yaml", required=False, help="path to configuration file.") + parser.add_argument( + "--options", + nargs="+", + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file (deprecate), " + "change to --cfg-options instead.", + ) + parser.add_argument("--job_name",default="test",type=str) + args = parser.parse_args() + + return args + + +def setup_seeds(config): + seed = config.run_cfg.seed + get_rank() + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + cudnn.benchmark = False + cudnn.deterministic = True + + +def get_runner_class(cfg): + """ + Get runner class from config. Default to epoch-based runner. + """ + runner_cls = registry.get_runner_class(cfg.run_cfg.get("runner", "runner_base")) + + return runner_cls + + +def setup_environ_flags(rank): + """Set environment flags for debugging purposes""" + os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1) + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1) + os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" + if rank == 0: + print(f"--> Running with torch dist debug set to detail") + + +def main(): + # allow auto-dl completes on main process without timeout when using NCCL backend. + # os.environ["NCCL_BLOCKING_WAIT"] = "1" + + # set before init_distributed_mode() to ensure the same job_id shared across all ranks. + setup_environ_flags(get_rank()) + job_id = now() + args = parse_args() + cfg = Config(args) + init_distributed_mode(cfg.run_cfg) + setup_seeds(cfg) + + # set after in + # it_distributed_mode() to only log on master. 
+ setup_logger() + wandb.login() + # print(wandb.run) + cfg.pretty_print() + + task = tasks.setup_task(cfg) + datasets = task.build_datasets(cfg) + model = task.build_model(cfg) + if not hasattr(cfg.run_cfg, 'rank') or cfg.run_cfg.rank == 0: + print("project name", args.job_name) + + wandb.init(project="minigpt4-spatial",name=args.job_name) + + wandb.config = {"learning_rate": 0.0001, "epochs": 100, "batch_size": 8} + wandb.watch(model) + + # print('+++++++++++++++++') + # print(type(model)) + # print('+++++++++++++++++') + # print(model) + # print('+++++++++++++++++') + # print(model.super().device) + # print('+++++++++++++++++') + # print(model.device) + + runner = get_runner_class(cfg)( + cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets + ) + runner.train() + + +if __name__ == "__main__": + main() diff --git a/train_configs/224_minigpt4_llama2_image.yaml b/train_configs/224_minigpt4_llama2_image.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f62a520b172851e94fee4bd02cf6e7e2df9529ed --- /dev/null +++ b/train_configs/224_minigpt4_llama2_image.yaml @@ -0,0 +1,59 @@ +model: + arch: minigpt4 + model_type: mini_gpt4_llama_v2 + llama_model: "meta-llama/Llama-2-7b-chat-hf" + + +datasets: + laion: + batch_size: 64 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + sample_ratio: 115 + cc_sbu: + batch_size: 64 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + sample_ratio: 14 + + +run: + task: image_text_pretrain + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 1e-4 + min_lr: 8e-5 + warmup_lr: 1e-6 + + weight_decay: 0.05 + max_epoch: 4 + num_workers: 4 + warmup_steps: 5000 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "output/minigpt4_stage1_pretrain" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True + + wandb_log: True + job_name: minigpt4_llama2_pretrain diff --git a/train_configs/224_minigpt4_mistral_image.yaml b/train_configs/224_minigpt4_mistral_image.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b7b198a8833c31755c45408ac3bc5e2a53634c7c --- /dev/null +++ b/train_configs/224_minigpt4_mistral_image.yaml @@ -0,0 +1,59 @@ +model: + arch: minigpt4 + model_type: mini_gpt4_llama_v2 + llama_model: "mistralai/Mistral-7B-Instruct-v0.2" + + +datasets: + laion: + batch_size: 64 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + sample_ratio: 115 + cc_sbu: + batch_size: 64 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + sample_ratio: 14 + + +run: + task: image_text_pretrain + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 1e-4 + min_lr: 8e-5 + warmup_lr: 1e-6 + + weight_decay: 0.05 + max_epoch: 4 + num_workers: 4 + warmup_steps: 5000 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "output/minigpt4_stage1_pretrain" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True + + wandb_log: True + job_name: minigpt4_llama2_pretrain diff --git a/train_configs/224_v2_llama2_video_stage_2.yaml b/train_configs/224_v2_llama2_video_stage_2.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..29df1a9e7e66c76ddfd250e3c0581e4fce2e189e --- /dev/null +++ b/train_configs/224_v2_llama2_video_stage_2.yaml @@ -0,0 +1,70 @@ +model: + arch: mini_gpt4_llama_v2 + model_type: pretrain_vicuna + freeze_vit: True + freeze_qformer: True + max_txt_len: 256 + low_resource: False + image_size: 224 + end_sym: "" + llama_model: "meta-llama/Llama-2-7b-chat-hf" + ckpt: "checkpoints/image_llama2_checkpoint.pth" + use_grad_checkpoint: True + chat_template: True + lora_r: 64 + lora_alpha: 16 + length: 50 + use_grad_checkpoint_llm: True + max_context_len: 3600 + token_pooling: True + + +datasets: + cmd_video: # 15938 + batch_size: 4 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + sample_ratio: 100 + webvid: # 42387 + batch_size: 4 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + sample_ratio: 50 + +run: + task: image_text_pretrain + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 1e-4 + min_lr: 8e-5 + warmup_lr: 1e-6 + + weight_decay: 0.05 + max_epoch: 50 + num_workers: 16 + warmup_steps: 1000 + iters_per_epoch: 1000 + + seed: 42 + output_dir: "training_output/cmd_webvid_pretrain" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True diff --git a/train_configs/224_v2_llama2_video_stage_3.yaml b/train_configs/224_v2_llama2_video_stage_3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cdb236ee9df09c96b9adbb76814c47c6112ae7e3 --- /dev/null +++ b/train_configs/224_v2_llama2_video_stage_3.yaml @@ -0,0 +1,58 @@ +model: + arch: mini_gpt4_llama_v2 + model_type: pretrain_vicuna + freeze_vit: True + freeze_qformer: True + max_txt_len: 256 + low_resource: False + image_size: 224 + end_sym: "" + llama_model: "meta-llama/Llama-2-7b-chat-hf" + ckpt: "checkpoints/video_captioning_llama_checkpoint_last.pth" + use_grad_checkpoint: True + chat_template: True + lora_r: 64 + lora_alpha: 16 + length: 50 + use_grad_checkpoint_llm: True + max_context_len: 3600 + + +datasets: + video_chatgpt: #99378 row - 13224 video + batch_size: 4 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + sample_ratio: 200 +run: + task: image_text_pretrain + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 1e-4 + min_lr: 8e-5 + warmup_lr: 1e-6 + + weight_decay: 0.05 + max_epoch: 50 + num_workers: 1 + warmup_steps: 1000 + iters_per_epoch: 1000 + + seed: 42 + output_dir: "training_output/pretrained_video_instruct" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True diff --git a/train_configs/224_v2_mistral_video_stage_2.yaml b/train_configs/224_v2_mistral_video_stage_2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b731ddfc987e3453cce93b8bb69b99ab5a21066e --- /dev/null +++ b/train_configs/224_v2_mistral_video_stage_2.yaml @@ -0,0 +1,70 @@ +model: + arch: mini_gpt4_llama_v2 + model_type: pretrain_vicuna + freeze_vit: True + freeze_qformer: True + max_txt_len: 512 + low_resource: False + image_size: 224 + end_sym: "" + llama_model: "mistralai/Mistral-7B-Instruct-v0.2" + ckpt: "checkpoints/image_mistral_checkpoint.pth" + use_grad_checkpoint: True + chat_template: True + lora_r: 64 + lora_alpha: 16 + length: 
50 + use_grad_checkpoint_llm: True + max_context_len: 7200 + + +datasets: + cmd_video: # 15938 + batch_size: 1 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + sample_ratio: 100 + webvid: # 42387 + batch_size: 1 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + sample_ratio: 50 + + +run: + task: image_text_pretrain + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 1e-4 + min_lr: 8e-5 + warmup_lr: 1e-6 + + weight_decay: 0.05 + max_epoch: 50 + num_workers: 16 + warmup_steps: 875 + iters_per_epoch: 875 + + seed: 42 + output_dir: "training_output/cmd_webvid_pretrain" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True diff --git a/train_configs/224_v2_mistral_video_stage_3.yaml b/train_configs/224_v2_mistral_video_stage_3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f4b77cd4547e226eb372c36e213594b387dc2c47 --- /dev/null +++ b/train_configs/224_v2_mistral_video_stage_3.yaml @@ -0,0 +1,60 @@ +model: + arch: mini_gpt4_llama_v2 + model_type: pretrain_vicuna + freeze_vit: True + freeze_qformer: True + max_txt_len: 512 + low_resource: False + image_size: 224 + end_sym: "" + llama_model: "mistralai/Mistral-7B-Instruct-v0.2" + ckpt: "checkpoints/video_captioning_mistral_checkpoint_last.pth" + use_grad_checkpoint: True + chat_template: True + lora_r: 64 + lora_alpha: 16 + length: 50 + use_grad_checkpoint_llm: True + max_context_len: 7200 + + +datasets: + video_chatgpt: #99378 row - 13224 video + batch_size: 1 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + sample_ratio: 200 + + +run: + task: image_text_pretrain + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 1e-4 + min_lr: 8e-5 + warmup_lr: 1e-6 + + weight_decay: 0.05 + max_epoch: 50 + num_workers: 16 + warmup_steps: 875 + iters_per_epoch: 875 + + seed: 42 + output_dir: "training_output/pretrained_video_instruct" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True diff --git a/train_llamav2.py b/train_llamav2.py new file mode 100644 index 0000000000000000000000000000000000000000..a664cb0c52d910138771aec3fdd770cd13f7ccf0 --- /dev/null +++ b/train_llamav2.py @@ -0,0 +1,120 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. 
+ SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import argparse +import os +import random + +import numpy as np +import torch +import torch.backends.cudnn as cudnn + +import minigpt4.tasks as tasks +from minigpt4.common.config import Config +from minigpt4.common.dist_utils import get_rank, init_distributed_mode +from minigpt4.common.logger import setup_logger +from minigpt4.common.optims import ( + LinearWarmupCosineLRScheduler, + LinearWarmupStepLRScheduler, +) +from minigpt4.common.registry import registry +from minigpt4.common.utils import now + +# imports modules for registration +from minigpt4.datasets.builders import * +from minigpt4.models import * +from minigpt4.processors import * +from minigpt4.runners import * +from minigpt4.tasks import * +import wandb + +def parse_args(): + parser = argparse.ArgumentParser(description="Training") + + parser.add_argument("--cfg-path", required=True, help="path to configuration file.") + parser.add_argument( + "--options", + nargs="+", + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file (deprecate), " + "change to --cfg-options instead.", + ) + parser.add_argument("--job_name",default="minigpt_spatial_coco_control",type=str) + + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + return args + + +def setup_seeds(config): + seed = config.run_cfg.seed + get_rank() + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + cudnn.benchmark = False + cudnn.deterministic = True + + +def get_runner_class(cfg): + """ + Get runner class from config. Default to epoch-based runner. + """ + runner_cls = registry.get_runner_class(cfg.run_cfg.get("runner", "runner_base")) + + return runner_cls + + +def main(): + # allow auto-dl completes on main process without timeout when using NCCL backend. + # os.environ["NCCL_BLOCKING_WAIT"] = "1" + + # set before init_distributed_mode() to ensure the same job_id shared across all ranks. + job_id = now() + args = parse_args() + cfg = Config(args) + init_distributed_mode(cfg.run_cfg) + setup_seeds(cfg) + + # set after in + # it_distributed_mode() to only log on master. + setup_logger() + wandb.login() + # print(wandb.run) + cfg.pretty_print() + + task = tasks.setup_task(cfg) + datasets = task.build_datasets(cfg) + model = task.build_model(cfg) + if not hasattr(cfg.run_cfg, 'rank') or cfg.run_cfg.rank == 0: + print("project name", args.job_name) + + wandb.init(project="minigpt4-spatial",name=args.job_name) + + wandb.config = {"learning_rate": 0.0001, "epochs": 100, "batch_size": 8} + wandb.watch(model) + + # print('+++++++++++++++++') + # print(type(model)) + # print('+++++++++++++++++') + # print(model) + # print('+++++++++++++++++') + # print(model.super().device) + # print('+++++++++++++++++') + # print(model.device) + + runner = get_runner_class(cfg)( + cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets + ) + runner.train() + + +if __name__ == "__main__": + main() diff --git a/train_multinode.py b/train_multinode.py new file mode 100644 index 0000000000000000000000000000000000000000..8bb7fc0f84e1e038c235230a396e0fb63da1fc8a --- /dev/null +++ b/train_multinode.py @@ -0,0 +1,152 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. 
+ SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import argparse +import os +import random + +import numpy as np +import torch +import torch.backends.cudnn as cudnn + +import minigpt4.tasks as tasks +from minigpt4.common.config import Config +from minigpt4.common.dist_utils import get_rank, init_distributed_mode +from minigpt4.common.logger import setup_logger +from minigpt4.common.optims import ( + LinearWarmupCosineLRScheduler, + LinearWarmupStepLRScheduler, +) +from minigpt4.common.registry import registry +from minigpt4.common.utils import now + +# imports modules for registration +from minigpt4.datasets.builders import * +from minigpt4.models import * +from minigpt4.processors import * +from minigpt4.runners import * +from minigpt4.tasks import * +import wandb +import torch.distributed as dist + +def parse_args(): + parser = argparse.ArgumentParser(description="Training",add_help=False) + + parser.add_argument("--cfg-path", required=True, help="path to configuration file.") + parser.add_argument( + "--options", + nargs="+" + ) + parser.add_argument("--job_name",default="minigpt_spatial_coco_control",type=str) + # distributed training parameters + parser.add_argument('--world_size', default=1, type=int, + help='number of distributed processes') + parser.add_argument('--local_rank', default=-1, type=int) + parser.add_argument('--dist_on_itp', action='store_true') + parser.add_argument('--dist_url', default='env://', + help='url used to set up distributed training') + + # args = parser.parse_args() + + + + + return parser + + +def setup_seeds(config): + seed = config.run_cfg.seed + get_rank() + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + cudnn.benchmark = False + cudnn.deterministic = True + + +def get_runner_class(cfg): + """ + Get runner class from config. Default to epoch-based runner. + """ + runner_cls = registry.get_runner_class(cfg.run_cfg.get("runner", "runner_base")) + + return runner_cls + + +def main(): + # allow auto-dl completes on main process without timeout when using NCCL backend. + # os.environ["NCCL_BLOCKING_WAIT"] = "1" + + # set before init_distributed_mode() to ensure the same job_id shared across all ranks. + + print("start!!!") + job_id = now() + args = parse_args().parse_args() + + + print("0000") + cfg = Config(args) + + if 'LOCAL_RANK' not in os.environ: + print("not in the os") + os.environ['LOCAL_RANK'] = str(args.local_rank) + print("111") + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + torch.cuda.set_device(local_rank) + + print("local rank",local_rank) + + dist.init_process_group(backend='nccl', init_method='env://') + + num_nodes = dist.get_world_size() + print(f"Number of nodes: {num_nodes}") + + + init_distributed_mode(cfg.run_cfg) + + setup_seeds(cfg) + + # set after in + # it_distributed_mode() to only log on master. 
+ setup_logger() + + + wandb.login() + # print(wandb.run) + + + cfg.pretty_print() + + task = tasks.setup_task(cfg) + datasets = task.build_datasets(cfg) + model = task.build_model(cfg) + if cfg.run_cfg.rank == 0: + print("project name", args.job_name) + + wandb.init(project="minigpt4-spatial",name=args.job_name) + + wandb.config = {"learning_rate": 0.0001, "epochs": 100, "batch_size": 8} + wandb.watch(model) + + # print('+++++++++++++++++') + # print(type(model)) + # print('+++++++++++++++++') + # print(model) + # print('+++++++++++++++++') + # print(model.super().device) + # print('+++++++++++++++++') + # print(model.device) + + runner = get_runner_class(cfg)( + cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets + ) + runner.train() + + +if __name__ == "__main__": + main()
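Each of the five benchmark scripts earlier in this diff writes a combined results JSON that maps every sample key to a `[response_dict, qa_set]` pair, and `evaluate_benchmark.sh` collects these files under one output directory. The sketch below pulls the per-dimension averages into a single summary; it is not part of the repository, and it assumes the output file names used in `evaluate_benchmark.sh` together with the same `output_dir` placeholder.

```python
import json
import os

# Assumes the layout produced by evaluate_benchmark.sh:
# each *_results.json maps a sample key to [ {'score': ...}, qa_set ].
OUTPUT_DIR = "output_dir"  # placeholder, as in evaluate_benchmark.sh

RESULT_FILES = {
    "correctness": "correctness_results.json",
    "detailed orientation": "detailed_orientation_results.json",
    "contextual understanding": "contextual_understanding_results.json",
    "temporal understanding": "temporal_understanding_results.json",
    "consistency": "consistency_results.json",
}

for dimension, file_name in RESULT_FILES.items():
    path = os.path.join(OUTPUT_DIR, file_name)
    if not os.path.exists(path):
        print(f"{dimension}: results file not found, skipping")
        continue
    with open(path) as f:
        combined = json.load(f)
    # Average the integer scores returned by the GPT judge for this dimension.
    scores = [int(entry[0]["score"]) for entry in combined.values()]
    if scores:
        print(f"{dimension}: average score {sum(scores) / len(scores):.2f} "
              f"over {len(scores)} samples")
```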