Spaces: Running on Zero
Vision-CAIR committed
Commit · a6e7156
Parent(s): a2c1e4d
apply_latest_updates
Browse files
- Custom_training.md +33 -0
- README.md +212 -12
- environment.yml +1 -2
- eval_video.py +1 -1
- jobs_video/train/stage_2.sh +23 -0
- jobs_video/train/stage_3.sh +23 -0
- minigpt4/configs/datasets/cc_sbu/align.yaml +2 -1
- minigpt4/configs/datasets/cmd_video/default.yaml +4 -3
- minigpt4/configs/datasets/template/default.yaml +16 -0
- minigpt4/configs/datasets/video_chatgpt/default.yaml +4 -8
- minigpt4/configs/datasets/webvid/default.yaml +4 -3
- minigpt4/datasets/builders/image_text_pair_builder.py +96 -3
- minigpt4/datasets/datasets/cc_sbu_dataset.py +47 -0
- minigpt4/datasets/datasets/video_datasets.py +211 -174
- minigpt4/models/mini_gpt4_llama_v2.py +11 -85
- minigpt4/runners/runner_base.py +4 -4
- minigpt4_video_demo.py +7 -7
- minigpt4_video_inference.py +170 -84
- train_configs/224_minigpt4_llama2_image.yaml +5 -3
- train_configs/224_minigpt4_llama2_image_align.yaml +53 -0
- train_configs/224_minigpt4_mistral_image.yaml +6 -5
- train_configs/224_minigpt4_mistral_image_align.yaml +53 -0
- train_configs/224_v2_llama2_video_stage_2.yaml +2 -2
- train_configs/224_v2_llama2_video_stage_3.yaml +3 -3
- train_configs/224_v2_mistral_video_stage_2.yaml +2 -2
- train_configs/224_v2_mistral_video_stage_3.yaml +2 -2
- train_configs/alignment.txt +4 -0
Custom_training.md
ADDED
@@ -0,0 +1,33 @@
# Customizing MiniGPT4-video for your own video-text dataset

## Add your own video dataloader
Construct your own dataloader in `minigpt4/datasets/datasets/video_datasets.py`, based on the existing dataloaders.<br>
Copy the `Video_loader_template` class and edit it according to the nature of your data.
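For illustration only, a minimal sketch of such a loader is shown below. The class name and annotation fields are hypothetical; the constructor signature and return keys mirror the `CMDVideoDataset` changes later in this commit. For real training, copy the actual `Video_loader_template` class rather than this sketch.

```python
import random

import torch
from torchvision import transforms

from minigpt4.datasets.datasets.base_dataset import BaseDataset


class MyVideoDataset(BaseDataset):
    """Hypothetical loader; each annotation looks like {"video_id": ..., "caption": ...}."""

    def __init__(self, vis_processor, text_processor, vis_root, ann_paths,
                 subtitles_path, model_name='llama2'):
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)
        # Frame and subtitle budgets depend on the LLM context window,
        # mirroring the existing loaders (90/800 for Mistral, 45/400 for Llama2).
        self.length = 90 if model_name == 'mistral' else 45
        self.max_sub_len = 800 if model_name == 'mistral' else 400
        self.subtitles_path = subtitles_path
        self.transform = transforms.Compose([transforms.ToPILImage()])
        self.instruction_pool = ['Briefly describe this video.']

    def __getitem__(self, index):
        ann = self.annotation[index]
        video_id = ann["video_id"]
        # TODO: decode and sample self.length frames from your video here
        # (see the cv2/moviepy logic in CMDVideoDataset for a working example).
        frames = [torch.zeros(3, 224, 224)] * self.length  # placeholder frames
        images = torch.stack(frames)
        instruction = '<Img><ImageHere>' * self.length + '\n' + random.choice(self.instruction_pool)
        return {
            "image": images,
            "answer": ann["caption"],
            "image_id": video_id,
            # plus any remaining keys the template returns (e.g. the instruction string and self.length)
        }
```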

## Create config file for your dataloader
Create your YAML file at `minigpt4/configs/datasets/dataset_name/default.yaml`; it should contain the paths to your dataset.<br>
Copy the template file `minigpt4/configs/datasets/template/default.yaml` and edit the paths to point to your dataset.


## Register your dataloader
In `minigpt4/datasets/builders/image_text_pair_builder.py`:<br>
Import your dataloader class from `minigpt4/datasets/datasets/video_datasets.py`. <br>
Copy and edit the `VideoTemplateBuilder` class.<br>
Set `train_dataset_cls = YourVideoLoaderClass`, i.e. the class you imported from `minigpt4/datasets/datasets/video_datasets.py`.
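As a concrete example of these three steps, a filled-in copy of `VideoTemplateBuilder` (which this commit adds to `image_text_pair_builder.py`) might look like the sketch below; `MyVideoDataset` and `my_videos` are placeholder names.

```python
from minigpt4.common.registry import registry
from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
from minigpt4.datasets.datasets.video_datasets import MyVideoDataset  # your loader class


@registry.register_builder("my_videos")  # must match the dataset name used in the configs
class MyVideoBuilder(BaseDatasetBuilder):
    train_dataset_cls = MyVideoDataset

    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/my_videos/default.yaml",
    }

    def build_datasets(self):
        self.build_processors()
        build_info = self.config.build_info  # paths coming from your default.yaml
        dataset_cls = self.train_dataset_cls
        return {
            "train": dataset_cls(
                vis_processor=self.vis_processors["train"],
                text_processor=self.text_processors["train"],
                vis_root=build_info.vis_root,
                ann_paths=build_info.ann_paths,
                subtitles_path=build_info.subtitles_path,
                model_name=build_info.model_name,
            )
        }
```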

## Edit training config file
Add your dataset to the `datasets` section of the training yaml file as shown below:
```yaml
datasets:
  dataset_name: # change this to your dataset name
    batch_size: 4  # change this to your desired batch size
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
    sample_ratio: 200 # if you are jointly training with other datasets, set the sample ratio here
```
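For intuition only, the sketch below illustrates how `sample_ratio` is typically interpreted when several datasets are mixed; dataset names and numbers here are placeholders, and the exact mixing logic is implemented by the training runner.

```python
# Rough illustration: batches are drawn from each dataset with probability
# proportional to its sample_ratio.
import random
from collections import Counter

sample_ratios = {"dataset_name": 200, "webvid": 100, "video_chatgpt": 50}

def pick_dataset() -> str:
    names, weights = zip(*sample_ratios.items())
    return random.choices(names, weights=weights, k=1)[0]

counts = Counter(pick_dataset() for _ in range(10_000))
print(counts)  # roughly a 4 : 2 : 1 split
```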
README.md
CHANGED
@@ -1,12 +1,212 @@
# MiniGPT4-Video: Advancing Multimodal LLMs for Video Understanding with Interleaved Visual-Textual Tokens
<!-- technical report link -->
<!-- demo link -->
<a href='https://vision-cair.github.io/MiniGPT4-video/'><img src='https://img.shields.io/badge/Project-Page-Green'></a>
<a href='https://arxiv.org/abs/2404.03413'><img src='https://img.shields.io/badge/Paper-PDF-red'></a>
<a href='https://23e140b581cffa9101.gradio.live'><img src='https://img.shields.io/badge/Project-Demo-violet'></a>
<!-- <a href='https://github.com/Vision-CAIR/MiniGPT4-video'><img src='https://img.shields.io/badge/Github-Code-blue'></a> -->
![demo_1](repo_imgs/sample_1.gif)
![demo_2](repo_imgs/sample_2.gif)
![demo_3](repo_imgs/sample_3.gif)
## Overview
This paper introduces MiniGPT4-Video, a multimodal Large Language Model (LLM) designed specifically for video understanding. The model is capable of processing both temporal visual and textual data, making it adept at understanding the complexities of videos.
Building upon the success of MiniGPT-v2, which excelled at translating visual features into the LLM space for single images and achieved impressive results on various image-text benchmarks, this work extends the model's capabilities to process a sequence of frames, enabling it to comprehend videos.
MiniGPT4-Video considers not only visual content but also textual conversations, allowing the model to effectively answer queries involving both visual and text components. The proposed model outperforms existing state-of-the-art methods, registering gains of 4.22%, 1.13%, 20.82%, and 13.1% on the MSVD, MSRVTT, TGIF, and TVQA benchmarks respectively.
During inference, a speech-to-text model such as Whisper is used to generate subtitles for the video. The video and the subtitles are then fed to the MiniGPT4-Video model together with the instruction, and the model outputs the answer.
![methodology](repo_imgs/MiniGPT4-video_fig.jpg)
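The subtitle step is not tied to one tool; as a rough illustration (not part of this repository's pipeline), the `openai-whisper` package can produce a `.vtt` subtitle file of the kind the dataloaders in this commit consume:

```python
# Illustrative only: generate a .vtt subtitle file for a video with openai-whisper.
import whisper


def seconds_to_vtt(t: float) -> str:
    # Format seconds as HH:MM:SS.mmm for WebVTT timestamps.
    h, rem = divmod(t, 3600)
    m, s = divmod(rem, 60)
    return f"{int(h):02d}:{int(m):02d}:{s:06.3f}"


def video_to_vtt(video_path: str, vtt_path: str, model_size: str = "base") -> None:
    model = whisper.load_model(model_size)
    result = model.transcribe(video_path)
    with open(vtt_path, "w") as f:
        f.write("WEBVTT\n\n")
        for seg in result["segments"]:
            f.write(f"{seconds_to_vtt(seg['start'])} --> {seconds_to_vtt(seg['end'])}\n")
            f.write(seg["text"].strip() + "\n\n")


video_to_vtt("example_video.mp4", "example_video.en.vtt")
```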
## :rocket: Demo
**1. Clone the repository** <br>
```bash
git clone https://github.com/Vision-CAIR/MiniGPT4-video.git
cd MiniGPT4-video
```

**2. Set up the environment** <br>
```bash
conda env create -f environment.yml
```
**3. Download the checkpoints**

| MiniGPT4-Video (Llama2 Chat 7B) | MiniGPT4-Video (Mistral 7B) |
|:------------------------------:|:---------------------------:|
| [Download](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/blob/main/checkpoints/video_llama_checkpoint_last.pth) | [Download](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/blob/main/checkpoints/video_mistral_checkpoint_last.pth) |

**4. Run the demo** <br>

```bash
# Llama2
python minigpt4_video_demo.py --ckpt path_to_video_checkpoint --cfg-path test_configs/llama2_test_config.yaml
# Mistral
python minigpt4_video_demo.py --ckpt path_to_video_checkpoint --cfg-path test_configs/mistral_test_config.yaml
```
### Inference
Follow the previous steps, but replace step 4 with the following:

```bash
# Llama2
python minigpt4_video_inference.py --ckpt path_to_video_checkpoint --cfg-path test_configs/llama2_test_config.yaml --video_path path_to_video --question "Your question here"
# Mistral
python minigpt4_video_inference.py --ckpt path_to_video_checkpoint --cfg-path test_configs/mistral_test_config.yaml --video_path path_to_video --question "Your question here"
```
## :fire: Training
### To customize MiniGPT4-Video for your own video-text dataset
<!-- point to file here Custom_training.md -->
You can find the steps to customize MiniGPT4-Video for your own video-text dataset in [Custom_training.md](Custom_training.md).
### Training datasets
After downloading the datasets below, **go to the dataset configuration folder `minigpt4/configs/datasets` and set the paths for each dataset there.**<br>
Image-text training:<br>
You can find the steps to download these datasets in [MiniGPT4](https://github.com/Vision-CAIR/MiniGPT-4/tree/main/dataset)<br>
+ LAION <br>
+ Conceptual Captions <br>
+ SBU <br>

Video-text training:<br>

+ [CMD](https://www.robots.ox.ac.uk/~vgg/data/condensed-movies/) <br>
+ [Webvid](https://github.com/m-bain/webvid/) <br> <!-- + [Webvid](https://huggingface.co/datasets/TempoFunk/webvid-10M?row=2) <br> -->
+ [Video Instructional Dataset 100K](https://huggingface.co/datasets/MBZUAI/VideoInstruct-100K) <br>

You can find the annotation files for the video-text datasets here: [download](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/tree/main/datasets/training_datasets) <br>


### Model training
You can edit the number of GPUs in each script.sh below.<br>
#### Stage 1 (image-text pretraining)

You can directly download the pretrained MiniGPT4 [checkpoint](https://drive.google.com/file/d/11nAPjEok8eAGGEG1N2vXo3kBLCg0WgUk/view?usp=sharing) aligned with Llama2. <br>

Or train it yourself:

```bash
# pretrain
# Llama2
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/224_minigpt4_llama2_image.yaml
# Mistral
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/224_minigpt4_mistral_image.yaml

# align
# To launch the second-stage alignment, first specify the path to the checkpoint file trained in the pretrain stage.
# Llama2
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/224_minigpt4_llama2_image_align.yaml
# Mistral
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/224_minigpt4_mistral_image_align.yaml
```
You can download our trained weights for this stage here: [Llama2](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/blob/main/checkpoints/image_llama2_checkpoint.pth) [Mistral](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/blob/main/checkpoints/image_mistral_checkpoint.pth)<br>
#### Stage 2 (video captioning pretraining)

For **Llama2**: <br>
set the cfg-path in the script to `train_configs/224_v2_llama2_video_stage_2.yaml` <br>
set the model name in `minigpt4/configs/datasets/cmd_video/default.yaml` and `minigpt4/configs/datasets/webvid/default.yaml` to llama2<br>
For **Mistral**: <br>
set the cfg-path in the script to `train_configs/224_v2_mistral_video_stage_2.yaml` <br>
set the model name in `minigpt4/configs/datasets/cmd_video/default.yaml` and `minigpt4/configs/datasets/webvid/default.yaml` to mistral<br>

```bash
bash jobs_video/train/stage_2.sh
```
You can download our trained weights for this stage here: [Llama2](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/blob/main/checkpoints/video_captioning_llama_checkpoint_last.pth) [Mistral](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/blob/main/checkpoints/video_captioning_mistral_checkpoint_last.pth)<br>

#### Stage 3 (video instruction finetuning)

For **Llama2**: <br>
set the cfg-path in the script to `train_configs/224_v2_llama2_video_stage_3.yaml` <br>
set the model name in `minigpt4/configs/datasets/video_chatgpt/default.yaml` to llama2<br>

For **Mistral**: <br>
set the cfg-path in the script to `train_configs/224_v2_mistral_video_stage_3.yaml` <br>
set the model name in `minigpt4/configs/datasets/video_chatgpt/default.yaml` to mistral<br>

```bash
bash jobs_video/train/stage_3.sh
```
You can download our trained weights for this stage here: [Llama2](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/blob/main/checkpoints/video_llama_checkpoint_last.pth) [Mistral](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/blob/main/checkpoints/video_mistral_checkpoint_last.pth)<br>

## :zap: Evaluation
To reproduce the results, use the best checkpoint for each model: <br>
[Llama2](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/blob/main/checkpoints/video_captioning_llama_checkpoint_best.pth) [Mistral](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/blob/main/checkpoints/video_captioning_mistral_checkpoint_best.pth)<br>
We used the same evaluation protocol as [Video-ChatGPT](https://mbzuai-oryx.github.io/Video-ChatGPT/)<br>
<!-- ![short_results](repo_imgs/short_results.PNG) -->

| Method | Using Subtitles | Information Correctness | Detailed Orientation | Contextual Understanding | Temporal Understanding | Consistency |
|:------:|:---------------:|:-----------------------:|:--------------------:|:------------------------:|:----------------------:|:-----------:|
| LLaMA Adapter | :x: | 2.03 | 2.32 | 2.30 | 1.98 | 2.15 |
| Video LLaMA | :x: | 1.96 | 2.18 | 2.16 | 1.82 | 1.79 |
| Video Chat | :x: | 2.23 | 2.50 | 2.53 | 1.94 | 2.24 |
| Video-ChatGPT | :x: | 2.40 | 2.52 | 2.62 | 1.98 | 2.37 |
| BT-Adapter-7B | :x: | 2.68 | 2.69 | 3.27 | 2.34 | 2.46 |
| LLaMA-VID-7B | :x: | 2.96 | 3.00 | 3.53 | 2.46 | 2.51 |
| **Ours-7B Llama2** | :x: | 2.93 | 2.97 | 3.45 | **2.47** | **2.60** |
| **Ours-7B Llama2** | :white_check_mark: | **3.08** | **3.02** | **3.57** | **2.65** | **2.67** |
| **Ours-7B Mistral** | :x: | 2.83 | 2.52 | 3.01 | 2.32 | 2.40 |
| **Ours-7B Mistral** | :white_check_mark: | 2.91 | 2.57 | 3.11 | 2.33 | 2.39 |


| Method | Using Subtitles | MSVD Acc.↑ | MSVD Score↑ | MSRVTT Acc.↑ | MSRVTT Score↑ | TGIF Acc.↑ | TGIF Score↑ | ActivityNet Acc.↑ | ActivityNet Score↑ | TVQA Acc.↑ |
|:------:|:---------------:|:----------:|:-----------:|:------------:|:-------------:|:----------:|:-----------:|:-----------------:|:------------------:|:----------:|
| FrozenBiLM | :x: | 32.2 | -- | 16.8 | -- | 41 | -- | 24.7 | -- | 29.7 |
| LLaMA Adapter | :x: | 54.9 | 3.1 | 43.8 | 2.7 | -- | -- | 34.2 | 2.7 | -- |
| Video LLaMA | :x: | 51.6 | 2.5 | 29 | 1.8 | -- | -- | 12.4 | 1.1 | -- |
| Video Chat | :x: | 56.3 | 2.8 | 45 | 2.5 | 34.4 | 2.3 | 26.5 | 2.2 | -- |
| Video-ChatGPT | :x: | 64.9 | 3.3 | 49.3 | 2.8 | 51.4 | 3.0 | 35.2 | 2.7 | 23.35 |
| BT-Adapter-7B | :x: | 67.7 | 3.7 | 57 | 3.2 | -- | -- | 45.7 | 3.2 | -- |
| LLaMA-VID-7B | :x: | 69.7 | 3.7 | 57.7 | 3.2 | -- | -- | **47.4** | **3.3** | -- |
| **Ours-7B Llama2** | :x: | 72.93 | 3.84 | 58.83 | 3.29 | 67.9 | 3.71 | 45.85 | 3.23 | 36.45 |
| **Ours-7B Llama2** | :white_check_mark: | 72.93 | 3.84 | **59.73** | **3.3** | 67.9 | 3.71 | 46.3 | 3.4 | 46.94 |
| **Ours-7B Mistral** | :x: | **73.92** | **4.06** | 58.26 | 3.52 | **72.22** | **4.08** | 44.25 | 3.35 | 33.90 |
| **Ours-7B Mistral** | :white_check_mark: | **73.92** | **4.06** | 58.68 | 3.53 | **72.22** | **4.08** | 44.38 | 3.36 | **54.21** |

### Download datasets for evaluation
+ [MSVD](https://www.cs.utexas.edu/users/ml/clamp/videoDescription/) <br>
+ [MSRVTT](https://cove.thecvf.com/datasets/839) <br>
+ [TGIF](https://github.com/YunseokJANG/tgif-qa/blob/master/dataset/README.md) <br>
+ [ActivityNet](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/hanoona_bangalath_mbzuai_ac_ae/ESa302OCJMNHsMk7wuBbQc8BZH5CqlcdCWiSpXynQZDfAQ?e=CrOPbm) <br>
+ [TVQA](https://tvqa.cs.unc.edu/) <br>
+ [Video-ChatGPT benchmark](https://mbzuai-oryx.github.io/Video-ChatGPT/) <br>

You can find the evaluation dataset annotation files here: [download](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/tree/main/datasets/evaluation_datasets) <br>

### Run evaluation script
Set each evaluation script's parameters: the path to the checkpoint, the dataset name, and whether or not to use subtitles. <br>

```bash
# Llama2
bash jobs_video/eval/llama2_evaluation.sh
# Mistral
bash jobs_video/eval/mistral_evalualtion.sh
```
Then use GPT-3.5 Turbo to compare the predictions with the ground truth and generate the accuracy and scores. <br>
Set these variables in both evaluate_benchmark.sh and evaluate_zeroshot.sh: <br>
```bash
PRED="path_to_predictions"
OUTPUT_DIR="path_to_output_dir"
API_KEY="openAI_key"
NUM_TASKS=128
```
Then, to evaluate the Video-ChatGPT benchmark, run the following script: <br>
```bash
bash test_benchmark/quantitative_evaluation/evaluate_benchmark.sh
```
To evaluate open-ended questions, run the following script: <br>
```bash
bash test_benchmark/quantitative_evaluation/evaluate_zeroshot.sh
```
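For reference only, the core of that GPT-3.5 comparison step is roughly the sketch below, a simplified illustration assuming the OpenAI Python client; the repository's actual prompts, parsing, and parallelization live behind evaluate_benchmark.sh and evaluate_zeroshot.sh.

```python
# Simplified illustration of the GPT-3.5-turbo judging step (not the repo's exact script).
from openai import OpenAI

client = OpenAI(api_key="openAI_key")  # same key as API_KEY above


def judge(question: str, answer: str, prediction: str) -> str:
    """Ask GPT-3.5 whether the prediction matches the ground truth and to score it from 0 to 5."""
    prompt = (
        "You are evaluating a video question-answering prediction.\n"
        f"Question: {question}\n"
        f"Correct Answer: {answer}\n"
        f"Predicted Answer: {prediction}\n"
        "Reply with 'yes' or 'no' for correctness, followed by an integer score from 0 to 5."
    )
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": prompt}],
    )
    return response.choices[0].message.content


print(judge("What is the man doing?", "He is cooking.", "The man is preparing food."))
```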

If you're using MiniGPT4-Video in your research or applications, please cite using this BibTeX:
```
@article{ataallah2024minigpt4,
  title={MiniGPT4-Video: Advancing Multimodal LLMs for Video Understanding with Interleaved Visual-Textual Tokens},
  author={Ataallah, Kirolos and Shen, Xiaoqian and Abdelrahman, Eslam and Sleiman, Essam and Zhu, Deyao and Ding, Jian and Elhoseiny, Mohamed},
  journal={arXiv preprint arXiv:2404.03413},
  year={2024}
}
```

## Acknowledgements
[MiniGPT4](https://github.com/Vision-CAIR/MiniGPT-4) <br>
[Video-ChatGPT](https://mbzuai-oryx.github.io/Video-ChatGPT)

## License
This repository is under the [BSD 3-Clause License](LICENSE.md).
Much of the code is based on [MiniGPT4](https://github.com/Vision-CAIR/MiniGPT-4).
environment.yml
CHANGED
@@ -1,4 +1,4 @@
-name:
+name: minigpt4_video
 channels:
 - conda-forge
 dependencies:
@@ -143,7 +143,6 @@ dependencies:
   - ffmpeg-python==0.2.0
   - ffmpy==0.3.1
   - filelock==3.13.1
-  - flash-attn==2.5.4
   - flask==3.0.2
   - flatbuffers==23.5.26
   - fonttools==4.47.0
eval_video.py
CHANGED
@@ -5,7 +5,7 @@ from torch.utils.data import DataLoader
 from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser
 from minigpt4.conversation.conversation import CONV_VISION
 from minigpt4.processors.blip_processors import Blip2ImageTrainProcessor,BlipCaptionProcessor
-from minigpt4.datasets.datasets.video_datasets import VideoChatGPTEvalDataset,VideoChatGPTEval_consistancy,Video_validation_Dataset,TVQAEVAL
+from minigpt4.datasets.datasets.video_datasets import VideoChatGPTEvalDataset,VideoChatGPTEval_consistancy,Video_validation_Dataset,TVQAEVAL

 parser = eval_parser()
 parser.add_argument("--dataset", type=str, default='msvd', help="dataset to evaluate")
jobs_video/train/stage_2.sh
ADDED
@@ -0,0 +1,23 @@
#!/bin/bash
#SBATCH --partition=batch
#SBATCH --job-name=test
#SBATCH --output=test.out
#SBATCH --error=test.err
#SBATCH --time=23:00:00
#SBATCH --mem=110G
#SBATCH --gres=gpu:a100:4
#SBATCH --cpus-per-task=16
## run the application:
job_name=test # Name of the experiment
cfg_path="train_configs/224_v2_llama2_video_stage_2.yaml" # path to the config file
number_of_gpus=1 # number of gpus
# cd ../../

read LOWERPORT UPPERPORT < /proc/sys/net/ipv4/ip_local_port_range
while :
do
    PORT="`shuf -i $LOWERPORT-$UPPERPORT -n 1`"
    ss -lpn | grep -q ":$PORT " || break
done
echo "Port is $PORT"
torchrun --master-port ${PORT} --nproc-per-node $number_of_gpus train.py --job_name ${job_name} --cfg-path ${cfg_path}
jobs_video/train/stage_3.sh
ADDED
@@ -0,0 +1,23 @@
#!/bin/bash
#SBATCH --partition=batch
#SBATCH --job-name=test
#SBATCH --output=test.out
#SBATCH --error=test.err
#SBATCH --time=23:00:00
#SBATCH --mem=110G
#SBATCH --gres=gpu:a100:4
#SBATCH --cpus-per-task=16
## run the application:
job_name="test" # Name of the experiment
cfg_path="train_configs/224_v2_llama2_video_stage_3.yaml" # path to the config file
number_of_gpus=1 # number of gpus
# cd ../../

read LOWERPORT UPPERPORT < /proc/sys/net/ipv4/ip_local_port_range
while :
do
    PORT="`shuf -i $LOWERPORT-$UPPERPORT -n 1`"
    ss -lpn | grep -q ":$PORT " || break
done
echo "Port is $PORT"
torchrun --master-port ${PORT} --nproc-per-node $number_of_gpus train.py --job_name ${job_name} --cfg-path ${cfg_path}
minigpt4/configs/datasets/cc_sbu/align.yaml
CHANGED
@@ -2,4 +2,5 @@ datasets:
   cc_sbu_align:
     data_type: images
     build_info:
-      storage: "/ibex/project/
+      # storage: "/ibex/project/c2090/datasets/cc_sbu_align"
+      storage: "path/to/cc_sbu_align/dataset"
minigpt4/configs/datasets/cmd_video/default.yaml
CHANGED
@@ -10,6 +10,7 @@ datasets:
 
   build_info:
     # Be careful not to append minus sign (-) before split to avoid itemizing
-    vis_root: /
-    ann_paths: [
-
+    vis_root: path/to/videos/
+    ann_paths: [path/to/annotations.json]
+    subtitles_path: path/to/subtitles_folder # folder that contains subtitles of .vtt format
+    model_name: 'llama2' # Language Model Name (available: llama2, mistral)
minigpt4/configs/datasets/template/default.yaml
ADDED
@@ -0,0 +1,16 @@
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

datasets:
  dataset_name: # same as the name of the train_config yaml file
    # data_dir: ${env.data_dir}/datasets
    data_type: images # let it be images for now even if it is videos

    build_info: # this is the information needed to build the dataset
      # Be careful not to append minus sign (-) before split to avoid itemizing
      ann_paths: [path/to/annotations_json] # list of paths to annotation files
      vis_root: path/to/videos_folder
      subtitles_path: path/to/subtitles_folder
      model_name: 'llama2' # Language Model Name (available: llama2, mistral)
minigpt4/configs/datasets/video_chatgpt/default.yaml
CHANGED
@@ -10,11 +10,7 @@ datasets:
 
   build_info:
     # Be careful not to append minus sign (-) before split to avoid itemizing
-    ann_paths: [
-    vis_root: /
-
-
-    videos_path: "/ibex/project/c2090/datasets/VideoInstruct100K/test_videos/all_datasets_samples_val"
-    subtitles_path: "inference_subtitles"
-    annotations_keys: ['question','answer','video_id']
-    add_subtitles: True
+    ann_paths: [path/to/annotations_json] # list of paths to annotation files
+    vis_root: path/to/videos_folder
+    subtitles_path: path/to/subtitles_folder # folder that contains subtitles of .vtt format
+    model_name: 'llama2' # Language Model Name (available: llama2, mistral)
minigpt4/configs/datasets/webvid/default.yaml
CHANGED
@@ -10,6 +10,7 @@ datasets:
 
   build_info:
     # Be careful not to append minus sign (-) before split to avoid itemizing
-    ann_paths: [
-    vis_root: /
-    subtitles_path: /
+    ann_paths: [path/to/annotations.json]
+    vis_root: path/to/videos/
+    subtitles_path: path/to/subtitles_folder/ # folder that contains subtitles of .vtt format
+    model_name: 'llama2' # Language Model Name (available: llama2, mistral)
minigpt4/datasets/builders/image_text_pair_builder.py
CHANGED
@@ -5,6 +5,7 @@ import warnings
 from minigpt4.common.registry import registry
 from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
 from minigpt4.datasets.datasets.laion_dataset import LaionDataset
+from minigpt4.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset
 from minigpt4.datasets.datasets.vg_dataset import ReferVisualGenomeDataset
 from minigpt4.datasets.datasets.open_images import OpenImageDataset,OpenBboxToObjectDataset
 from minigpt4.datasets.datasets.locna_dataset import LocNaCOCODataset
@@ -16,7 +17,7 @@ from minigpt4.datasets.datasets.coyo_dataset import COYOCaptionWDSDataset,COYOBo
 # , COYOBBoxPhraseDataset
 from minigpt4.datasets.datasets.grounded_detailed_image_caption_dataset import GroundedDetailDataset
 from minigpt4.datasets.datasets.reasoning_dataset import ReasoningDataset
-from minigpt4.datasets.datasets.video_datasets import CMDVideoDataset, WebVidDataset,VideoChatGPTDataset
+from minigpt4.datasets.datasets.video_datasets import CMDVideoDataset, WebVidDataset,VideoChatGPTDataset
 from minigpt4.datasets.datasets.cot import CoTDataset
 from minigpt4.datasets.datasets.unnatural_instruction import UnnaturalDataset
 from minigpt4.datasets.datasets.caption_reasoning import CaptionReasonDataset
@@ -441,9 +442,68 @@ class CoyoBboxPhraseBuilder(BaseDatasetBuilder):
         return datasets
 
 
+@registry.register_builder("cc_sbu_align")
+class CCSBUAlignBuilder(BaseDatasetBuilder):
+    train_dataset_cls = CCSBUAlignDataset
+
+    DATASET_CONFIG_DICT = {
+        "default": "configs/datasets/cc_sbu/align.yaml",
+    }
+
+    def build_datasets(self):
+        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
+        logging.info("Building datasets...")
+        self.build_processors()
+
+        build_info = self.config.build_info
+        storage_path = build_info.storage
+
+        datasets = dict()
+
+        if not os.path.exists(storage_path):
+            warnings.warn("storage path {} does not exist.".format(storage_path))
+
+        # create datasets
+        dataset_cls = self.train_dataset_cls
+        datasets['train'] = dataset_cls(
+            vis_processor=self.vis_processors["train"],
+            text_processor=self.text_processors["train"],
+            ann_paths=[os.path.join(storage_path, 'filter_cap.json')],
+            vis_root=os.path.join(storage_path, 'image'),
+        )
+
+        return datasets
+
+@registry.register_builder("cc_sbu")
+class CCSBUBuilder(BaseDatasetBuilder):
+    train_dataset_cls = CCSBUDataset
+
+    DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"}
+
+    def _download_ann(self):
+        pass
+
+    def _download_vis(self):
+        pass
+
+    def build(self):
+        self.build_processors()
+
+        build_info = self.config.build_info
 
+        datasets = dict()
+        split = "train"
 
+        # create datasets
+        # [NOTE] return inner_datasets (wds.DataPipeline)
+        dataset_cls = self.train_dataset_cls
+        datasets[split] = dataset_cls(
+            vis_processor=self.vis_processors[split],
+            text_processor=self.text_processors[split],
+            location=build_info.storage,
+        ).inner_dataset
+
+        return datasets
 
 
 @registry.register_builder("textcaps_ocr")
@@ -739,7 +799,8 @@ class CMDVideoBuilder(BaseDatasetBuilder):
             text_processor=self.text_processors["train"],
             vis_root=build_info.vis_root,
             ann_paths=build_info.ann_paths,
+            subtitles_path=build_info.subtitles_path,
+            model_name= build_info.model_name,
         )
 
         return datasets
@@ -770,6 +831,7 @@ class WebVidBuilder(BaseDatasetBuilder):
             vis_root=build_info.vis_root,
             ann_paths=build_info.ann_paths,
             subtitles_path=build_info.subtitles_path,
+            model_name= build_info.model_name,
         )
 
         return datasets
@@ -778,7 +840,6 @@ class WebVidBuilder(BaseDatasetBuilder):
 @registry.register_builder("video_chatgpt")
 class VideoChatGPTBuilder(BaseDatasetBuilder):
     train_dataset_cls = VideoChatGPTDataset
-    eval_dataset_cls=Video_validation_Dataset
 
     DATASET_CONFIG_DICT = {
         "default": "configs/datasets/video_chatgpt/default.yaml",
@@ -800,6 +861,38 @@ class VideoChatGPTBuilder(BaseDatasetBuilder):
             text_processor=self.text_processors["train"],
             vis_root=build_info.vis_root,
             ann_paths=build_info.ann_paths,
+            subtitles_path=build_info.subtitles_path,
+            model_name=build_info.model_name
+        )
+
+        return datasets
+
+@registry.register_builder("Name of the builder as in the config file")
+class VideoTemplateBuilder(BaseDatasetBuilder):
+    train_dataset_cls = ... # Add the dataset class here
+
+    DATASET_CONFIG_DICT = {
+        "default": "path to the config file",
+    }
+    print(DATASET_CONFIG_DICT)
+
+    def build_datasets(self):
+        # download, split, etc...
+        # only called on 1 GPU/TPU in distributed
+        self.build_processors()
+
+        build_info = self.config.build_info # information from the config file
+        datasets = dict()
+
+        # create datasets
+        dataset_cls = self.train_dataset_cls
+        datasets['train'] = dataset_cls(
+            vis_processor=self.vis_processors["train"], # Add the vis_processor here
+            text_processor=self.text_processors["train"], # Add the text_processor here
+            vis_root=build_info.vis_root, # Add videos path here
+            ann_paths=build_info.ann_paths, # Add annotations path here
+            subtitles_path=build_info.subtitles_path, # Add subtitles path here
+            model_name='llama2' # Add model name here (llama2 or mistral)
         )
 
         return datasets
minigpt4/datasets/datasets/cc_sbu_dataset.py
ADDED
@@ -0,0 +1,47 @@
import os
from PIL import Image
import webdataset as wds
from minigpt4.datasets.datasets.base_dataset import BaseDataset
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset


class CCSBUDataset(BaseDataset):
    def __init__(self, vis_processor, text_processor, location):
        super().__init__(vis_processor=vis_processor, text_processor=text_processor)

        self.inner_dataset = wds.DataPipeline(
            wds.ResampledShards(location),
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.shuffle(1000, handler=wds.warn_and_continue),
            wds.decode("pilrgb", handler=wds.warn_and_continue),
            wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
            wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
            wds.map(self.to_dict, handler=wds.warn_and_continue),
        )

    def to_dict(self, sample):
        return {
            "image": sample[0],
            "answer": self.text_processor(sample[1]["caption"]),
        }


class CCSBUAlignDataset(CaptionDataset):

    def __getitem__(self, index):

        # TODO this assumes image input, not general enough
        ann = self.annotation[index]

        img_file = '{}.jpg'.format(ann["image_id"])
        image_path = os.path.join(self.vis_root, img_file)
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)
        caption = ann["caption"]

        return {
            "image": image,
            "answer": caption,
            "image_id": self.img_ids[ann["image_id"]],
        }
minigpt4/datasets/datasets/video_datasets.py
CHANGED
@@ -7,7 +7,8 @@
|
|
7 |
|
8 |
import os
|
9 |
from collections import OrderedDict
|
10 |
-
import sys
|
|
|
11 |
from minigpt4.datasets.datasets.base_dataset import BaseDataset
|
12 |
from PIL import Image
|
13 |
import random
|
@@ -98,7 +99,7 @@ class __DisplMixin:
|
|
98 |
|
99 |
|
100 |
class CMDVideoDataset(BaseDataset, __DisplMixin):
|
101 |
-
def __init__(self, vis_processor, text_processor, vis_root, ann_paths,
|
102 |
"""
|
103 |
vis_root (string): Root directory of images (e.g. coco/images/)
|
104 |
ann_root (string): directory to store the annotation file
|
@@ -119,51 +120,89 @@ class CMDVideoDataset(BaseDataset, __DisplMixin):
|
|
119 |
'Please provide a depiction of the video.',
|
120 |
'Illustrate what is happening in the video.',
|
121 |
]
|
122 |
-
|
123 |
-
|
124 |
-
self.
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
|
131 |
-
self.
|
132 |
-
self.
|
133 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
134 |
|
135 |
def __getitem__(self, index):
|
136 |
ann = self.annotation[index]
|
137 |
video_id = ann["image_id"]
|
138 |
-
|
139 |
-
answer = self.text_processor(ann["caption"])
|
140 |
instruction = random.choice(self.instruction_pool)
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
if sampling_interval == 0:
|
146 |
sampling_interval = 1
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
if len(images) >= self.length:
|
158 |
break
|
159 |
-
|
160 |
-
if len(images)
|
|
|
|
|
|
|
161 |
last_item = images[-1]
|
162 |
while len(images) < self.length:
|
163 |
images.append(last_item)
|
|
|
164 |
images = torch.stack(images)
|
165 |
-
instruction =
|
166 |
-
return
|
167 |
"image": images,
|
168 |
"answer": answer,
|
169 |
"image_id": video_id,
|
@@ -172,10 +211,8 @@ class CMDVideoDataset(BaseDataset, __DisplMixin):
|
|
172 |
}
|
173 |
|
174 |
|
175 |
-
|
176 |
-
|
177 |
class WebVidDataset(BaseDataset, __DisplMixin):
|
178 |
-
def __init__(self, vis_processor, text_processor, vis_root, ann_paths,subtitles_path,add_subtitles=False):
|
179 |
"""
|
180 |
vis_root (string): Root directory of images (e.g. coco/images/)
|
181 |
ann_root (string): directory to store the annotation file
|
@@ -196,10 +233,13 @@ class WebVidDataset(BaseDataset, __DisplMixin):
|
|
196 |
'Please provide a depiction of the video.',
|
197 |
'Illustrate what is happening in the video.',
|
198 |
]
|
199 |
-
self.
|
200 |
-
|
201 |
-
|
202 |
-
|
|
|
|
|
|
|
203 |
self.add_subtitles = add_subtitles
|
204 |
self.videos_has_subtitles = {}
|
205 |
if self.add_subtitles:
|
@@ -207,23 +247,16 @@ class WebVidDataset(BaseDataset, __DisplMixin):
|
|
207 |
for sub in os.listdir(self.subtitle_folder):
|
208 |
video_id = sub.split('.')[0]
|
209 |
self.videos_has_subtitles[video_id] = True
|
210 |
-
for ann in self.annotation:
|
211 |
-
img_id = ann["videoid"]
|
212 |
-
if img_id not in self.img_ids.keys():
|
213 |
-
self.img_ids[img_id] = n
|
214 |
-
n += 1
|
215 |
self.transform = transforms.Compose([
|
216 |
transforms.ToPILImage(),
|
217 |
])
|
218 |
|
219 |
def __getitem__(self, index):
|
220 |
ann = self.annotation[index]
|
221 |
-
|
222 |
video_id = ann["videoid"]
|
223 |
images = []
|
224 |
caption = ann["name"].split('-')[-1].split(':')[-1]
|
225 |
# caption = self.text_processor(caption)
|
226 |
-
|
227 |
video_path = os.path.join(self.vis_root, ann['page_dir'], f'{video_id}.mp4')
|
228 |
has_subtitles = self.videos_has_subtitles.get(video_id, False)
|
229 |
if self.add_subtitles and has_subtitles:
|
@@ -245,20 +278,22 @@ class WebVidDataset(BaseDataset, __DisplMixin):
|
|
245 |
subtitle_text_in_interval = ""
|
246 |
history_subtitles = {}
|
247 |
number_of_sub_words=0
|
|
|
248 |
while cap.isOpened():
|
249 |
ret, frame = cap.read()
|
250 |
if not ret:
|
251 |
break
|
252 |
-
# Find the corresponding subtitle for the frame and combine the interval subtitles into one subtitle
|
253 |
-
|
254 |
if self.add_subtitles and has_subtitles:
|
255 |
for subtitle in vtt_file:
|
256 |
sub=subtitle.text.replace('\n',' ')
|
257 |
-
if (subtitle.start_in_seconds <= (frame_count / int(clip.fps)) <= subtitle.end_in_seconds)
|
258 |
if not history_subtitles.get(sub,False):
|
259 |
-
|
|
|
|
|
260 |
history_subtitles[sub]=True
|
261 |
-
break
|
262 |
if frame_count % sampling_interval == 0:
|
263 |
frame = self.transform(frame[:,:,::-1])
|
264 |
frame = self.vis_processor(frame)
|
@@ -267,6 +302,7 @@ class WebVidDataset(BaseDataset, __DisplMixin):
|
|
267 |
if self.add_subtitles and has_subtitles and subtitle_text_in_interval != "" and number_of_sub_words<self.max_sub_len:
|
268 |
img_placeholder+=f'<Cap>{subtitle_text_in_interval}'
|
269 |
number_of_sub_words+=len(subtitle_text_in_interval.split(' '))
|
|
|
270 |
subtitle_text_in_interval = ""
|
271 |
frame_count += 1
|
272 |
if len(images) >= self.length:
|
@@ -291,7 +327,7 @@ class WebVidDataset(BaseDataset, __DisplMixin):
|
|
291 |
}
|
292 |
|
293 |
class VideoChatGPTDataset(BaseDataset, __DisplMixin):
|
294 |
-
def __init__(self, vis_processor, text_processor, vis_root, ann_paths,add_subtitles=True
|
295 |
"""
|
296 |
vis_root (string): Root directory of images (e.g. coco/images/)
|
297 |
ann_root (string): directory to store the annotation file
|
@@ -299,12 +335,17 @@ class VideoChatGPTDataset(BaseDataset, __DisplMixin):
|
|
299 |
super().__init__(vis_processor, text_processor, vis_root, ann_paths)
|
300 |
self.img_ids = {}
|
301 |
n=0
|
302 |
-
self.
|
303 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
304 |
self.add_subtitles = add_subtitles
|
305 |
self.videos_has_subtitles = {}
|
306 |
if self.add_subtitles:
|
307 |
-
self.subtitle_folder =
|
308 |
for sub in os.listdir(self.subtitle_folder):
|
309 |
video_id = sub.split('.')[0]
|
310 |
self.videos_has_subtitles[video_id] = True
|
@@ -315,7 +356,7 @@ class VideoChatGPTDataset(BaseDataset, __DisplMixin):
|
|
315 |
n+= 1
|
316 |
|
317 |
self.videos_extension={}
|
318 |
-
for video in os.listdir(
|
319 |
self.videos_extension[video.split('.')[0]]=video.split('.')[1]
|
320 |
|
321 |
self.transform = transforms.Compose([
|
@@ -336,7 +377,7 @@ class VideoChatGPTDataset(BaseDataset, __DisplMixin):
|
|
336 |
# Load the VTT subtitle file
|
337 |
vtt_file = webvtt.read(subtitle_path)
|
338 |
|
339 |
-
video_path = os.path.join(self.vis_root,
|
340 |
clip = VideoFileClip(video_path)
|
341 |
total_num_frames = int(clip.duration * clip.fps)
|
342 |
clip.close()
|
@@ -349,20 +390,22 @@ class VideoChatGPTDataset(BaseDataset, __DisplMixin):
|
|
349 |
subtitle_text_in_interval = ""
|
350 |
history_subtitles = {}
|
351 |
number_of_sub_words=0
|
|
|
352 |
while cap.isOpened():
|
353 |
ret, frame = cap.read()
|
354 |
if not ret:
|
355 |
break
|
356 |
-
# Find the corresponding subtitle for the frame and combine the interval subtitles into one subtitle
|
357 |
-
|
358 |
if self.add_subtitles and has_subtitles:
|
359 |
for subtitle in vtt_file:
|
360 |
sub=subtitle.text.replace('\n',' ')
|
361 |
-
if (subtitle.start_in_seconds <= (frame_count / int(clip.fps)) <= subtitle.end_in_seconds)
|
362 |
if not history_subtitles.get(sub,False):
|
363 |
-
|
|
|
|
|
364 |
history_subtitles[sub]=True
|
365 |
-
break
|
366 |
if frame_count % sampling_interval == 0:
|
367 |
frame = self.transform(frame[:,:,::-1])# BGR to RGB
|
368 |
frame = self.vis_processor(frame)
|
@@ -372,6 +415,7 @@ class VideoChatGPTDataset(BaseDataset, __DisplMixin):
|
|
372 |
if subtitle_text_in_interval != "":
|
373 |
img_placeholder+=f'<Cap>{subtitle_text_in_interval}'
|
374 |
number_of_sub_words+=len(subtitle_text_in_interval.split(' '))
|
|
|
375 |
subtitle_text_in_interval = ""
|
376 |
frame_count += 1
|
377 |
if len(images) >= self.length:
|
@@ -513,8 +557,8 @@ class WebVidEvalDataset(torch.utils.data.Dataset):
|
|
513 |
ret, frame = cap.read()
|
514 |
if not ret:
|
515 |
break
|
516 |
-
# Find the corresponding subtitle for the frame and combine the interval subtitles into one subtitle
|
517 |
-
|
518 |
if self.add_subtitles and has_subtitles:
|
519 |
for subtitle in vtt_file:
|
520 |
sub=subtitle.text.replace('\n',' ')
|
@@ -616,8 +660,8 @@ class VideoChatGPTEvalDataset(torch.utils.data.Dataset):
|
|
616 |
ret, frame = cap.read()
|
617 |
if not ret:
|
618 |
break
|
619 |
-
# Find the corresponding subtitle for the frame and combine the interval subtitles into one subtitle
|
620 |
-
|
621 |
if self.add_subtitles and subtitle_path is not None:
|
622 |
for subtitle in vtt_file:
|
623 |
sub=subtitle.text.replace('\n',' ')
|
@@ -711,8 +755,8 @@ class Video_validation_Dataset(torch.utils.data.Dataset):
|
|
711 |
ret, frame = cap.read()
|
712 |
if not ret:
|
713 |
break
|
714 |
-
# Find the corresponding subtitle for the frame and combine the interval subtitles into one subtitle
|
715 |
-
|
716 |
if self.add_subtitles and subtitle_path is not None:
|
717 |
for subtitle in vtt_file:
|
718 |
sub=subtitle.text.replace('\n',' ')
|
@@ -808,8 +852,8 @@ class VideoChatGPTEval_consistancy(torch.utils.data.Dataset):
|
|
808 |
ret, frame = cap.read()
|
809 |
if not ret:
|
810 |
break
|
811 |
-
# Find the corresponding subtitle for the frame and combine the interval subtitles into one subtitle
|
812 |
-
|
813 |
if self.add_subtitles and subtitle_path is not None:
|
814 |
for subtitle in vtt_file:
|
815 |
sub=subtitle.text.replace('\n',' ')
|
@@ -900,8 +944,8 @@ class TVQAEVAL (torch.utils.data.Dataset):
|
|
900 |
history_subtitles = {}
|
901 |
number_of_sub_words=0
|
902 |
for i,frame in enumerate(sorted(os.listdir(video_frames_path))):
|
903 |
-
# Find the corresponding subtitle for the frame and combine the interval subtitles into one subtitle
|
904 |
-
|
905 |
if self.add_subtitles:
|
906 |
for subtitle in self.subtitles[video_id]:
|
907 |
if (subtitle['start'] <= (i / self.fps) <= subtitle['end']) and subtitle['text'] not in subtitle_text_in_interval:
|
@@ -934,118 +978,111 @@ class TVQAEVAL (torch.utils.data.Dataset):
|
|
934 |
return images,instruction,answer,self.length,video_id
|
935 |
|
936 |
|
937 |
-
|
938 |
-
|
939 |
-
|
940 |
-
|
941 |
-
|
942 |
-
|
943 |
-
|
944 |
-
|
|
|
|
|
|
|
|
|
|
|
945 |
self.length = 90
|
946 |
self.max_sub_len = 800
|
|
|
|
|
|
|
947 |
self.add_subtitles = add_subtitles
|
948 |
-
self.
|
949 |
-
self.
|
950 |
-
|
951 |
-
|
952 |
-
|
|
|
|
|
|
|
|
|
953 |
self.transform = transforms.Compose([
|
954 |
transforms.ToPILImage(),
|
955 |
])
|
956 |
-
self.videos_features_path=videos_features_path
|
957 |
-
self.processed_videos={}
|
958 |
-
self.save_pkl="subtitles" if self.add_subtitles else "no_subtitles"
|
959 |
-
for video_pkl in os.listdir(videos_features_path):
|
960 |
-
video_id_sub=video_pkl.split('.')[0]
|
961 |
-
self.processed_videos[video_id_sub]=True
|
962 |
-
def extract_season_episode(self,video_name):
|
963 |
-
# Define a regex pattern to match season and episode numbers
|
964 |
-
pattern = r's(\d+)e(\d+)'
|
965 |
-
|
966 |
-
# Use re.search to find the pattern in the video name
|
967 |
-
match = re.search(pattern, video_name, re.IGNORECASE)
|
968 |
-
|
969 |
-
if match:
|
970 |
-
# Extract season and episode numbers from the matched groups
|
971 |
-
season_number = int(match.group(1))
|
972 |
-
episode_number = int(match.group(2))
|
973 |
-
return f"season_{season_number}", f"episode_{episode_number}"
|
974 |
-
else:
|
975 |
-
# Return None if the pattern is not found
|
976 |
-
return None, None
|
977 |
-
|
978 |
def __len__(self):
|
979 |
return len(self.annotation)
|
980 |
def __getitem__(self, index):
|
981 |
ann = self.annotation[index]
|
982 |
-
|
983 |
-
|
984 |
-
|
985 |
-
video_id = f"{folder_name}_{season_number}_{episode_number}"
|
986 |
-
answer=str(ann['answer_idx'])
|
987 |
-
instruction=ann["q"]+" \n\n As you watched in this video Choose ONE suitable answer from these mutiple choices \n\n"
|
988 |
-
for i in range(5):
|
989 |
-
ans=ann[f"a{i}"]
|
990 |
-
instruction+=f"option {i}: {ans} \n\n"
|
991 |
-
# instruction+="\n Your output should be THE NUMBER OF THE CORRECT ANSWER FROM THE CHOICES FROM 0 TO 4 INCLUSIVE"
|
992 |
-
instruction+=f"option 5: Can't answer based on the provided information \n\n"
|
993 |
-
instruction+="\n Your output should be THE NUMBER OF THE CORRECT ANSWER FROM THE CHOICES FROM 0 TO 5 INCLUSIVE"
|
994 |
images=[]
|
995 |
img_placeholder = ""
|
996 |
-
|
997 |
-
|
998 |
-
|
999 |
-
|
1000 |
-
|
1001 |
-
else:
|
1002 |
-
video_frames_path = os.path.join(self.videos_path,folder_name,season_number,episode_number)
|
1003 |
-
video_subtitle_path=os.path.join(self.subtitles_path,folder_name,season_number,episode_number+".srt")
|
1004 |
-
video_subtitles=read_subtitles(video_subtitle_path)
|
1005 |
-
total_num_frames=len(os.listdir(video_frames_path))
|
1006 |
-
sampling_interval = round(total_num_frames / self.length)
|
1007 |
-
if sampling_interval == 0:
|
1008 |
-
sampling_interval = 1
|
1009 |
-
subtitle_text_in_interval = ""
|
1010 |
-
history_subtitles = {}
|
1011 |
-
number_of_sub_words=0
|
1012 |
-
number_of_interval_words=0
|
1013 |
-
max_number_of_interval_words=10
|
1014 |
-
for i,frame in enumerate(sorted(os.listdir(video_frames_path))):
|
1015 |
-
# Find the corresponding subtitle for the frame and combine the interval subtitles into one subtitle
|
1016 |
-
# we choose 1 frame for every 2 seconds,so we need to combine the subtitles in the interval of 2 seconds
|
1017 |
-
if self.add_subtitles:
|
1018 |
-
for subtitle in video_subtitles:
|
1019 |
-
if (srt_time_to_seconds(subtitle.start) <= (i / self.fps) <= srt_time_to_seconds(subtitle.end)) and subtitle.text not in subtitle_text_in_interval:
|
1020 |
-
if not history_subtitles.get(subtitle.text,False) and number_of_interval_words<max_number_of_interval_words:
|
1021 |
-
subtitle_text_in_interval+=subtitle.text+" "
|
1022 |
-
number_of_interval_words+=len(subtitle.text.split(' '))
|
1023 |
-
history_subtitles[subtitle.text]=True
|
1024 |
-
break
|
1025 |
-
if i % sampling_interval == 0:
|
1026 |
-
frame = Image.open(os.path.join(video_frames_path,frame)).convert("RGB")
|
1027 |
-
frame = self.vis_processor(frame)
|
1028 |
-
images.append(frame)
|
1029 |
-
img_placeholder += '<Img><ImageHere>'
|
1030 |
-
if self.add_subtitles and number_of_sub_words<self.max_sub_len:
|
1031 |
-
if subtitle_text_in_interval != "":
|
1032 |
-
img_placeholder+=f'<Cap>{subtitle_text_in_interval}'
|
1033 |
-
number_of_sub_words+=len(subtitle_text_in_interval.split(' '))
|
1034 |
-
subtitle_text_in_interval = ""
|
1035 |
-
if len(images) >= self.length:
|
1036 |
-
break
|
1037 |
-
if len(images) ==0:
|
1038 |
-
print("Video not found",video_frames_path)
|
1039 |
|
1040 |
-
|
1041 |
-
|
1042 |
-
|
1043 |
-
|
1044 |
-
|
1045 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1046 |
|
1047 |
-
|
1048 |
-
|
1049 |
-
self.
|
1050 |
-
|
1051 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
import os
|
9 |
from collections import OrderedDict
|
10 |
+
import sys
|
11 |
+
sys.path.append('/ibex/project/c2090/kirolos/MiniGPT4-video-llama3')
|
12 |
from minigpt4.datasets.datasets.base_dataset import BaseDataset
|
13 |
from PIL import Image
|
14 |
import random
|
|
|
99 |
|
100 |
|
101 |
class CMDVideoDataset(BaseDataset, __DisplMixin):
|
102 |
+
def __init__(self, vis_processor, text_processor, vis_root, ann_paths, subtitles_path,model_name='llama2'):
|
103 |
"""
|
104 |
vis_root (string): Root directory of images (e.g. coco/images/)
|
105 |
ann_root (string): directory to store the annotation file
|
|
|
120 |
'Please provide a depiction of the video.',
|
121 |
'Illustrate what is happening in the video.',
|
122 |
]
|
123 |
+
|
124 |
+
self.model_name=model_name
|
125 |
+
if self.model_name =='mistral':
|
126 |
+
self.length = 90
|
127 |
+
self.max_sub_len = 800
|
128 |
+
else:
|
129 |
+
self.length = 45
|
130 |
+
self.max_sub_len = 400
|
131 |
|
132 |
+
self.subtitle_folder = subtitles_path
|
133 |
+
self.videos_has_subtitles={}
|
134 |
+
for sub in os.listdir(self.subtitle_folder):
|
135 |
+
video_id = sub.split('.')[0]
|
136 |
+
self.videos_has_subtitles[video_id] = True
|
137 |
+
self.transform = transforms.Compose([
|
138 |
+
transforms.ToPILImage(),
|
139 |
+
])
|
140 |
|
141 |
def __getitem__(self, index):
|
142 |
ann = self.annotation[index]
|
143 |
video_id = ann["image_id"]
|
144 |
+
answer =ann['caption']
|
|
|
145 |
instruction = random.choice(self.instruction_pool)
|
146 |
+
has_subtitles = self.videos_has_subtitles.get(video_id, False)
|
147 |
+
if has_subtitles:
|
148 |
+
subtitle_path = os.path.join(self.subtitle_folder, f'{video_id}.en.vtt')
|
149 |
+
# Load the VTT subtitle file
|
150 |
+
vtt_file = webvtt.read(subtitle_path)
|
151 |
+
video_path = os.path.join(self.vis_root, f'{video_id}.mp4')
|
152 |
+
clip = VideoFileClip(video_path)
|
153 |
+
total_num_frames = int(clip.duration * clip.fps)
|
154 |
+
clip.close()
|
155 |
+
cap = cv2.VideoCapture(video_path)
|
156 |
+
frame_count = 0
|
157 |
+
sampling_interval = int(total_num_frames / self.length)
|
158 |
if sampling_interval == 0:
|
159 |
sampling_interval = 1
|
160 |
+
img_placeholder = ""
|
161 |
+
subtitle_text_in_interval = ""
|
162 |
+
number_of_sub_words=0
|
163 |
+
images=[]
|
164 |
+
history_subtitles = {}
|
165 |
+
previous_sub = ""
|
166 |
+
while cap.isOpened():
|
167 |
+
ret, frame = cap.read()
|
168 |
+
if not ret:
|
169 |
+
break
|
170 |
+
# Find the corresponding subtitle for the each frame and combine the interval subtitles into one subtitle
|
171 |
+
if has_subtitles:
|
172 |
+
for subtitle in vtt_file:
|
173 |
+
sub=subtitle.text.replace('\n',' ')
|
174 |
+
if (subtitle.start_in_seconds <= (frame_count / int(clip.fps)) <= subtitle.end_in_seconds):
|
175 |
+
if not history_subtitles.get(sub,False):
|
176 |
+
for word in sub.split(' '):
|
177 |
+
if word not in subtitle_text_in_interval and word not in previous_sub:
|
178 |
+
subtitle_text_in_interval+=word+" "
|
179 |
+
history_subtitles[sub]=True
|
180 |
+
if frame_count % sampling_interval == 0:
|
181 |
+
frame = self.transform(frame[:,:,::-1])# BGR to RGB
|
182 |
+
frame = self.vis_processor(frame)
|
183 |
+
images.append(frame)
|
184 |
+
img_placeholder += '<Img><ImageHere>'
|
185 |
+
if has_subtitles and number_of_sub_words<self.max_sub_len:
|
186 |
+
if subtitle_text_in_interval != "":
|
187 |
+
img_placeholder+=f'<Cap>{subtitle_text_in_interval}'
|
188 |
+
number_of_sub_words+=len(subtitle_text_in_interval.split(' '))
|
189 |
+
previous_sub = subtitle_text_in_interval
|
190 |
+
subtitle_text_in_interval = ""
|
191 |
+
frame_count += 1
|
192 |
if len(images) >= self.length:
|
193 |
break
|
194 |
+
cap.release()
|
195 |
+
if len(images) ==0:
|
196 |
+
print("Video not found",video_path)
|
197 |
+
|
198 |
+
if 0 <len(images) < self.length:
|
199 |
last_item = images[-1]
|
200 |
while len(images) < self.length:
|
201 |
images.append(last_item)
|
202 |
+
img_placeholder += '<Img><ImageHere>'
|
203 |
images = torch.stack(images)
|
204 |
+
instruction = img_placeholder + '\n' + instruction
|
205 |
+
return{
|
206 |
"image": images,
|
207 |
"answer": answer,
|
208 |
"image_id": video_id,
|
|
|
211 |
}
|
212 |
|
213 |
|
|
|
|
|
class WebVidDataset(BaseDataset, __DisplMixin):
    def __init__(self, vis_processor, text_processor, vis_root, ann_paths, subtitles_path, model_name, add_subtitles=False):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        ...
            'Please provide a depiction of the video.',
            'Illustrate what is happening in the video.',
        ]
        self.model_name = model_name
        if self.model_name == 'mistral':
            self.length = 90
            self.max_sub_len = 800
        else:
            self.length = 45
            self.max_sub_len = 400
        self.add_subtitles = add_subtitles
        self.videos_has_subtitles = {}
        if self.add_subtitles:
            ...
            for sub in os.listdir(self.subtitle_folder):
                video_id = sub.split('.')[0]
                self.videos_has_subtitles[video_id] = True
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
        ])

    def __getitem__(self, index):
        ann = self.annotation[index]
        video_id = ann["videoid"]
        images = []
        caption = ann["name"].split('-')[-1].split(':')[-1]
        # caption = self.text_processor(caption)
        ...
        video_path = os.path.join(self.vis_root, ann['page_dir'], f'{video_id}.mp4')
        has_subtitles = self.videos_has_subtitles.get(video_id, False)
        if self.add_subtitles and has_subtitles:
            ...
        subtitle_text_in_interval = ""
        history_subtitles = {}
        number_of_sub_words = 0
        previous_sub = ""
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # Find the corresponding subtitle for each frame and combine the interval subtitles into one subtitle
            if self.add_subtitles and has_subtitles:
                for subtitle in vtt_file:
                    sub = subtitle.text.replace('\n', ' ')
                    if subtitle.start_in_seconds <= (frame_count / int(clip.fps)) <= subtitle.end_in_seconds:
                        if not history_subtitles.get(sub, False):
                            for word in sub.split(' '):
                                if word not in subtitle_text_in_interval and word not in previous_sub:
                                    subtitle_text_in_interval += word + " "
                        history_subtitles[sub] = True
            if frame_count % sampling_interval == 0:
                frame = self.transform(frame[:, :, ::-1])
                frame = self.vis_processor(frame)
                ...
                if self.add_subtitles and has_subtitles and subtitle_text_in_interval != "" and number_of_sub_words < self.max_sub_len:
                    img_placeholder += f'<Cap>{subtitle_text_in_interval}'
                    number_of_sub_words += len(subtitle_text_in_interval.split(' '))
                    previous_sub = subtitle_text_in_interval
                    subtitle_text_in_interval = ""
            frame_count += 1
            if len(images) >= self.length:
                ...
        }
class VideoChatGPTDataset(BaseDataset, __DisplMixin):
    def __init__(self, vis_processor, text_processor, vis_root, ann_paths, subtitles_path, model_name='llama2', add_subtitles=True):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        ...
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)
        self.img_ids = {}
        n = 0
        self.model_name = model_name
        if self.model_name == 'mistral':
            self.length = 90
            self.max_sub_len = 800
        else:
            self.length = 45
            self.max_sub_len = 400
        self.add_subtitles = add_subtitles
        self.videos_has_subtitles = {}
        if self.add_subtitles:
            self.subtitle_folder = subtitles_path
            for sub in os.listdir(self.subtitle_folder):
                video_id = sub.split('.')[0]
                self.videos_has_subtitles[video_id] = True
        ...
            n += 1

        self.videos_extension = {}
        for video in os.listdir(self.vis_root):
            self.videos_extension[video.split('.')[0]] = video.split('.')[1]

        self.transform = transforms.Compose([
        ...
        # Load the VTT subtitle file
        vtt_file = webvtt.read(subtitle_path)

        video_path = os.path.join(self.vis_root, f'{video_id}.{self.videos_extension[video_id]}')
        clip = VideoFileClip(video_path)
        total_num_frames = int(clip.duration * clip.fps)
        clip.close()
        ...
        subtitle_text_in_interval = ""
        history_subtitles = {}
        number_of_sub_words = 0
        previous_sub = ""
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # Find the corresponding subtitle for each frame and combine the interval subtitles into one subtitle
            if self.add_subtitles and has_subtitles:
                for subtitle in vtt_file:
                    sub = subtitle.text.replace('\n', ' ')
                    if subtitle.start_in_seconds <= (frame_count / int(clip.fps)) <= subtitle.end_in_seconds:
                        if not history_subtitles.get(sub, False):
                            for word in sub.split(' '):
                                if word not in subtitle_text_in_interval and word not in previous_sub:
                                    subtitle_text_in_interval += word + " "
                        history_subtitles[sub] = True
            if frame_count % sampling_interval == 0:
                frame = self.transform(frame[:, :, ::-1])  # BGR to RGB
                frame = self.vis_processor(frame)
                ...
                if subtitle_text_in_interval != "":
                    img_placeholder += f'<Cap>{subtitle_text_in_interval}'
                    number_of_sub_words += len(subtitle_text_in_interval.split(' '))
                    previous_sub = subtitle_text_in_interval
                    subtitle_text_in_interval = ""
            frame_count += 1
            if len(images) >= self.length:
                ...
            ret, frame = cap.read()
            if not ret:
                break
            # Find the corresponding subtitle for each frame and combine the interval subtitles into one subtitle
            if self.add_subtitles and has_subtitles:
                for subtitle in vtt_file:
                    sub = subtitle.text.replace('\n', ' ')
                    ...
            ret, frame = cap.read()
            if not ret:
                break
            # Find the corresponding subtitle for each frame and combine the interval subtitles into one subtitle
            if self.add_subtitles and subtitle_path is not None:
                for subtitle in vtt_file:
                    sub = subtitle.text.replace('\n', ' ')
                    ...
            ret, frame = cap.read()
            if not ret:
                break
            # Find the corresponding subtitle for each frame and combine the interval subtitles into one subtitle
            if self.add_subtitles and subtitle_path is not None:
                for subtitle in vtt_file:
                    sub = subtitle.text.replace('\n', ' ')
                    ...
            ret, frame = cap.read()
            if not ret:
                break
            # Find the corresponding subtitle for each frame and combine the interval subtitles into one subtitle
            if self.add_subtitles and subtitle_path is not None:
                for subtitle in vtt_file:
                    sub = subtitle.text.replace('\n', ' ')
                    ...
        history_subtitles = {}
        number_of_sub_words = 0
        for i, frame in enumerate(sorted(os.listdir(video_frames_path))):
            # Find the corresponding subtitle for each frame and combine the interval subtitles into one subtitle
            if self.add_subtitles:
                for subtitle in self.subtitles[video_id]:
                    if (subtitle['start'] <= (i / self.fps) <= subtitle['end']) and subtitle['text'] not in subtitle_text_in_interval:
                        ...
        return images, instruction, answer, self.length, video_id

class Video_loader_template(BaseDataset, __DisplMixin):
    def __init__(self, vis_processor, text_processor, vis_root, ann_paths, subtitles_path, model_name='llama2', add_subtitles=True):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)
        self.model_name = model_name
        if self.model_name == 'mistral':
            self.length = 90
            self.max_sub_len = 800
        else:
            self.length = 45
            self.max_sub_len = 400
        self.add_subtitles = add_subtitles
        self.videos_has_subtitles = {}
        if self.add_subtitles:
            self.subtitle_folder = subtitles_path
            for sub in os.listdir(self.subtitle_folder):
                video_id = sub.split('.')[0]
                self.videos_has_subtitles[video_id] = True
        self.videos_extension = {}
        for video in os.listdir(os.path.join(self.vis_root, 'videos')):
            self.videos_extension[video.split('.')[0]] = video.split('.')[1]
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
        ])

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        ann = self.annotation[index]
        video_id = ann["video_id"]  # video_id
        answer = ann["a"]           # answer (ground truth)
        instruction = ann["q"]      # question (instruction)
        images = []
        img_placeholder = ""
        has_subtitles = self.videos_has_subtitles.get(video_id, False)
        if self.add_subtitles and has_subtitles:
            subtitle_path = os.path.join(self.subtitle_folder, f'{video_id}.vtt')
            # Load the VTT subtitle file
            vtt_file = webvtt.read(subtitle_path)

        video_path = os.path.join(self.vis_root, 'videos', f'{video_id}.{self.videos_extension[video_id]}')
        clip = VideoFileClip(video_path)
        total_num_frames = int(clip.duration * clip.fps)
        clip.close()
        cap = cv2.VideoCapture(video_path)
        frame_count = 0
        # Choose the sampling interval based on the total number of frames and the desired video length
        sampling_interval = int(total_num_frames / self.length)
        if sampling_interval == 0:
            sampling_interval = 1
        img_placeholder = ""
        subtitle_text_in_interval = ""
        history_subtitles = {}
        number_of_sub_words = 0
        # Iterate through the video, sample frames at the sampling interval, and add the subtitles if needed
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # Find the corresponding subtitle for each frame and combine the interval subtitles into one subtitle
            if self.add_subtitles and has_subtitles:
                for subtitle in vtt_file:
                    sub = subtitle.text.replace('\n', ' ')
                    if (subtitle.start_in_seconds <= (frame_count / int(clip.fps)) <= subtitle.end_in_seconds) and sub not in subtitle_text_in_interval:
                        if not history_subtitles.get(sub, False):
                            subtitle_text_in_interval += sub + " "
                        history_subtitles[sub] = True
                        break
            if frame_count % sampling_interval == 0:
                frame = self.transform(frame[:, :, ::-1])  # BGR to RGB
                frame = self.vis_processor(frame)
                images.append(frame)
                img_placeholder += '<Img><ImageHere>'
                if self.add_subtitles and has_subtitles and number_of_sub_words < self.max_sub_len:
                    if subtitle_text_in_interval != "":
                        img_placeholder += f'<Cap>{subtitle_text_in_interval}'
                        number_of_sub_words += len(subtitle_text_in_interval.split(' '))
                        subtitle_text_in_interval = ""
            frame_count += 1
            if len(images) >= self.length:
                break
        cap.release()
        if len(images) == 0:
            print("Video not found", video_path)

        if 0 < len(images) < self.length:
            last_item = images[-1]
            while len(images) < self.length:
                images.append(last_item)
                img_placeholder += '<Img><ImageHere>'
        images = torch.stack(images)
        # Combine the image placeholders and the instruction
        instruction = img_placeholder + '\n' + instruction
        # Return the images, instruction, answer, video_id, and the length of the video
        return {
            "image": images,
            "answer": answer,
            "image_id": video_id,
            "instruction_input": instruction,
            "length": self.length,
        }
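The template loader above returns a fixed-length stack of frames plus an interleaved `<Img><ImageHere>`/`<Cap>...` instruction string. Below is a minimal sketch of exercising a loader built from it; every path is a placeholder, and the processor imports and constructor arguments are assumptions based on the `blip2_image_train` / `blip_caption` processor names used in the training configs, so adapt them to your setup.

```python
# Sanity-check sketch for a loader derived from Video_loader_template.
# Paths are placeholders; the processor imports/constructors are assumptions.
from minigpt4.datasets.datasets.video_datasets import Video_loader_template
from minigpt4.processors.blip_processors import Blip2ImageTrainProcessor, BlipCaptionProcessor

dataset = Video_loader_template(
    vis_processor=Blip2ImageTrainProcessor(image_size=224),
    text_processor=BlipCaptionProcessor(),
    vis_root="path/to/dataset_root",         # must contain a 'videos/' subfolder
    ann_paths=["path/to/annotations.json"],  # entries with "video_id", "q", "a"
    subtitles_path="path/to/subtitles",      # <video_id>.vtt files
    model_name="llama2",                     # 45 frames / 400 subtitle words
    add_subtitles=True,
)

sample = dataset[0]
print(sample["image"].shape)             # (length, 3, 224, 224)
print(sample["instruction_input"][:120]) # '<Img><ImageHere>...' followed by the question
```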
minigpt4/models/mini_gpt4_llama_v2.py CHANGED
@@ -87,11 +87,6 @@ class MiniGPT4_llama_v2(Blip2Base):
         self.max_context_len = max_context_len
         self.chat_template = chat_template

-        # print('Loading VIT')
-        # self.visual_encoder, self.ln_vision = self.init_vision_encoder(
-        #     vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
-        # )
-
         if freeze_vit:
             # vit_precision="fp32"
             print("vit precision", vit_precision)
@@ -147,18 +142,6 @@ class MiniGPT4_llama_v2(Blip2Base):
                 # device_map={'':0}
             )
-            # bnb_config = BitsAndBytesConfig(
-            #     load_in_4bit=True,
-            #     bnb_4bit_use_double_quant=True,
-            #     bnb_4bit_quant_type="nf4",
-            #     bnb_4bit_compute_dtype=torch.bfloat16,
-            # )
-            # self.llama_model = llm_model.from_pretrained(
-            #     llama_model,
-            #     torch_dtype=torch.bfloat16,
-            #     device_map={'':torch.cuda.current_device()},
-            #     quantization_config=bnb_config,
-            # )
         else:
             self.llama_model = llm_model.from_pretrained(
                 llama_model,
@@ -182,24 +165,10 @@ class MiniGPT4_llama_v2(Blip2Base):
             )
             self.llama_model = get_peft_model(self.llama_model, loraconfig)

-            # if ckpt_path:
-            #     print('load the llm under lora')
-            #     ckpt = torch.load(ckpt_path)
-            #     set_peft_model_state_dict(self.llama_model,ckpt)
-
-
-
             self.llama_model.print_trainable_parameters()

             if self.use_grad_checkpoint_llm:
                 self.llama_model.gradient_checkpointing_enable()
-
-        # if not self.low_resource:
-        #     for name, param in self.llama_model.named_parameters():
-        #         if "embed_token" in name:
-        #             param.data = param.data.float()
-        #             param.requires_grad = True
-

         print('Loading LLAMA Done')
@@ -256,15 +225,6 @@ class MiniGPT4_llama_v2(Blip2Base):
         mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]

         mixed_embs = torch.cat(mixed_embs, dim=1)
-        # # truncate the length of tokens to the max context window
-        # mixed_embs_without_instruction = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair]
-        # mixed_embs_without_instruction=torch.cat(mixed_embs_without_instruction, dim=1)
-        # # check if the number of token in the second dimention is more than the max context window then truncate it
-        # context_window=self.max_context_len-seg_embs[-1].shape[1]
-        # if mixed_embs_without_instruction.shape[1] > context_window :
-        #     mixed_embs_without_instruction = mixed_embs_without_instruction[:, 0:context_window]
-        # mixed_embs=torch.cat([mixed_embs_without_instruction,seg_embs[-1]], dim=1)
-        # print("mixed_embs",mixed_embs.shape)

         return mixed_embs
@@ -288,7 +248,8 @@ class MiniGPT4_llama_v2(Blip2Base):
         else:
             # return the multi-modal embedding in right padding
             emb_lists = []
-
+            if type(prompts) == str:
+                prompts = [prompts] * len(img_embeds)
             for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)):
                 pn = each_img_embed.shape[-2]
                 if lengths is not None:
@@ -299,12 +260,8 @@ class MiniGPT4_llama_v2(Blip2Base):
                 interleave_emb = []
                 for idx, seg in enumerate(p_segs[:-1]):
                     p_tokens = self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
-                    # print("p_embed device",p_tokens.input_ids.device)
-                    # print("p_tokens",img_embeds.device)
-                    # print("emb layer", list(self.llama_model.base_model.model.model.embed_tokens.parameters())[0].device)
                     p_embed = self.embed_tokens(p_tokens.input_ids)
-                    # print("model device",self.llama_model.get_device())
                     interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx*pn:(idx+1)*pn]], dim=1))

                 wrapped_emb = torch.cat(interleave_emb, dim=1)
@@ -356,17 +313,6 @@ class MiniGPT4_llama_v2(Blip2Base):
                     input_atts[i][input_len:]
                 ])
             )
-            # print('===================================')
-            # print('check input emb: ', input_embs[i][this_input_ones-2:this_input_ones])
-            # print('check pad emb: ', input_embs[i][this_input_ones:this_input_ones+2])
-            # print('check out emb: ', output_embs[i][:2])
-            # print('check out pad emb: ', output_embs[i][-2:])
-            # print('+++++++++++++++++++++++++++++++++++')
-            #
-            # print('check attn before: ', input_atts[i][:this_input_ones])
-            # print('check attn after: ', input_atts[i][this_input_ones:])
-            # print('check attn gt before: ', output_atts[i][:3])
-            # print('check attn gt after: ', output_atts[i][-3:])

         cat_embs = torch.stack(cat_embs)
         cat_atts = torch.stack(cat_atts)
@@ -433,7 +379,6 @@ class MiniGPT4_llama_v2(Blip2Base):
         ### prepare input tokens
         if 'image' in samples:
             img_embeds, img_atts = self.encode_img(samples["image"])
-            # print("img_embeds shape",img_embeds.shape)
         else:
             img_embeds = img_atts = None
@@ -453,12 +398,15 @@ class MiniGPT4_llama_v2(Blip2Base):
             cond_embeds, cond_atts = regress_embeds[:, :0], regress_atts[:, :0]

         else:
+            if "instruction_input" in samples:
+                instruction = samples["instruction_input"]
+            elif len(self.prompt_list) > 1:
+                instruction = random.choice(self.prompt_list)
+            else:
+                instruction = None

-            # print("instruction before", instruction)
             if self.remove_template:
                 instruction = remove_special_tokens(instruction)
-            # print("instruction after", instruction)

             if self.chat_template:
                 instruction = ["[INST] " + instruct + "[/INST]" for instruct in instruction]
@@ -502,9 +450,6 @@ class MiniGPT4_llama_v2(Blip2Base):
             # concat the embedding to condition and the embedding to regress
             inputs_embeds, attention_mask, input_lens = \
                 self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts)
-            print("inputs_embeds shape",inputs_embeds.shape)
-            print("cond_embeds shape",cond_embeds.shape)
-            print("regress_embeds shape",regress_embeds.shape)
             # get bos token embedding
             bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id
             bos_embeds = self.embed_tokens(bos)
@@ -513,16 +458,12 @@ class MiniGPT4_llama_v2(Blip2Base):
         # add bos token at the begining
         inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
         attention_mask = torch.cat([bos_atts, attention_mask], dim=1)
-
-        # for i in range (len(samples["instruction_input"])):
-        #     print("instruction_input length",len(samples["instruction_input"][i].split(" ")))
-        #     print("answer length",len(samples["answer"][i].split(" ")))
-        # ensemble the final targets
+
         targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
                              dtype=torch.long).to(self.device).fill_(-100)
         for i, target in enumerate(part_targets):
             targets[i, input_lens[i]+1:input_lens[i]+len(target)+1] = target  # plus 1 for bos
-
+
         with self.maybe_autocast():
             outputs = self.llama_model(
                 inputs_embeds=inputs_embeds,
@@ -569,7 +510,6 @@ class MiniGPT4_llama_v2(Blip2Base):
         img_embeds = self.llama_proj(img_embeds) # project to llama input size (200,64,5632) -> (200,64,4096)
         atts_img = torch.ones(img_embeds.size()[:-1], dtype=torch.long).to(self.device)

-        print("img_embeds shape",img_embeds.shape)
         if lengths is not None:
             image_lists = []
             img_embeds = img_embeds.reshape(len(lengths), -1, img_embeds.shape[-2], img_embeds.shape[-1])
@@ -592,8 +532,6 @@ class MiniGPT4_llama_v2(Blip2Base):
             emb_len = emb.shape[1]
             embs[i, -emb_len:] = emb[0]
             attn_mask[i, -emb_len:] = 1
-            # print("inputs_embeds shape",embs.shape)
-            # print("attention_mask shape",attn_mask.shape)
         # check if the input embedding tokens are in the range of the model cotext window (4096) and if it is not, then truncate it to the max context window
         if self.model_type == "Llama":
             context_window = 3700
@@ -602,8 +540,6 @@ class MiniGPT4_llama_v2(Blip2Base):
         if embs.shape[1] > context_window:
             embs = embs[:, -context_window:]
             attn_mask = attn_mask[:, -context_window:]
-        print("inputs_embeds shape",embs.shape)
-        print("attention_mask shape",attn_mask.shape)
         with self.maybe_autocast():
             if return_video_temporal_features:
                 last_hidden_state = self.llama_model(
@@ -665,15 +601,8 @@ class MiniGPT4_llama_v2(Blip2Base):
         stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
             stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])

-        # seg_tokens=[]
-        # for i, text in enumerate(texts):
-        #     seg_tokens.append(self.llama_tokenizer(text, return_tensors="pt", add_special_tokens=True).to(self.device).input_ids)
-
         batch_embs = [torch.cat([self.embed_tokens(seg_t)]) for seg_t in seg_tokens]

-        # seg_embs = torch.cat(seg_embs, dim=1)
-        # print("seg_embs shape",seg_embs.shape)
-        # batch_embs=[seg_embs]
         batch_size = len(batch_embs)
         max_len = max([emb.shape[1] for emb in batch_embs])
         emb_dim = batch_embs[0].shape[2]
@@ -687,9 +616,6 @@ class MiniGPT4_llama_v2(Blip2Base):
             embs[i, -emb_len:] = emb[0]
             attn_mask[i, -emb_len:] = 1

-
-        print("inputs_embeds shape",embs.shape)
-        print("attention_mask shape",attn_mask.shape)
         with self.maybe_autocast():
             outputs = self.llama_model.generate(
                 inputs_embeds=embs,
@@ -892,4 +818,4 @@ def assign_imgs(batched_instruct_list, batched_img_embeds):
             n_assigned.append(None)
         batched_assigned.append(assigned_img)

-    return batched_assigned
+    return batched_assigned
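Most of the edits above strip old debugging prints; the functional changes are the prompt broadcast in the right-padding branch and the hard context-window cut before generation. The snippet below is a standalone illustration of those two behaviours with dummy tensors, not a call into the model class itself.

```python
# Standalone sketch of the two functional changes above, using dummy tensors.
import torch

def broadcast_prompts(prompts, batch_size):
    # mirrors: if type(prompts) == str: prompts = [prompts] * len(img_embeds)
    if isinstance(prompts, str):
        prompts = [prompts] * batch_size
    return prompts

def truncate_to_context(embs, attn_mask, context_window=3700):
    # mirrors the generate() path: keep only the last `context_window` positions
    if embs.shape[1] > context_window:
        embs = embs[:, -context_window:]
        attn_mask = attn_mask[:, -context_window:]
    return embs, attn_mask

embs = torch.zeros(2, 4000, 8)                # (batch, tokens, hidden) dummy values
attn = torch.ones(2, 4000, dtype=torch.long)
embs, attn = truncate_to_context(embs, attn)
print(broadcast_prompts("<Img><ImageHere> Describe the video", 2))
print(embs.shape, attn.shape)                 # torch.Size([2, 3700, 8]) torch.Size([2, 3700])
```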
minigpt4/runners/runner_base.py CHANGED
@@ -428,10 +428,10 @@ class RunnerBase:
                     # wandb.log({"epoch": cur_epoch, "GPT4_Accuracy": val_log['agg_metrics']})
                     # print("Validation finished")

-                else:
+                # else:
+                # if no validation split is provided, we just save the checkpoint at the end of each epoch.
+                if not self.evaluate_only:
+                    self._save_checkpoint(cur_epoch, is_best=False)

                 if self.evaluate_only:
                     break
minigpt4_video_demo.py CHANGED
@@ -155,7 +155,7 @@ def run (video_path,instruction,model,vis_processor,gen_subtitles=True):
        subtitle_path=None
    prepared_images,prepared_instruction=prepare_input(vis_processor,video_path,subtitle_path,instruction)
    if prepared_images is None:
-        return "
+        return "Please re-upload the video while changing the instructions."
    length=len(prepared_images)
    prepared_images=prepared_images.unsqueeze(0)
    conv = CONV_VISION.copy()
@@ -166,10 +166,10 @@ def run (video_path,instruction,model,vis_processor,gen_subtitles=True):
    prompt = [conv.get_prompt()]
    answers = model.generate(prepared_images, prompt, max_new_tokens=args.max_new_tokens, do_sample=True, lengths=[length],num_beams=2)
    # remove the subtitle file and the video file
-    if subtitle_path:
-    #if video_path.split('.')[-1] == 'mp4' or video_path.split('.')[-1] == 'mkv' or video_path.split('.')[-1] == 'avi':
-    #
+    # if subtitle_path:
+    #     os.system(f"rm {subtitle_path}")
+    # if video_path.split('.')[-1] == 'mp4' or video_path.split('.')[-1] == 'mkv' or video_path.split('.')[-1] == 'avi':
+    #     os.system(f"rm {video_path}")
    return answers[0]

def run_single_image (image_path,instruction,model,vis_processor):
@@ -268,7 +268,7 @@ description = """<h5>This is the demo of MiniGPT4-video Model.</h5>"""
project_page = """<p><a href='https://vision-cair.github.io/MiniGPT4-video/'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p>"""
code_link="""<p><a href='https://github.com/Vision-CAIR/MiniGPT4-video'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p>"""
paper_link="""<p><a href=''><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>"""
-
+video_path=""
with gr.Blocks(title="MiniGPT4-video 🎞️🍿",css=text_css ) as demo :
    # with gr.Row():
    #     with gr.Column(scale=2):
@@ -330,7 +330,7 @@ with gr.Blocks(title="MiniGPT4-video 🎞️🍿",css=text_css ) as demo :
            # )
    with gr.Row():
        with gr.Column():
-            youtube_link = gr.Textbox(label="Enter the youtube link", placeholder="Paste YouTube URL
+            youtube_link = gr.Textbox(label="Enter the youtube link", placeholder="Paste YouTube URL with this format 'https://www.youtube.com/watch?v=video_id'")
            video_player = gr.Video(autoplay=False)
            download_finish = gr.State(value=False)
            youtube_link.change(
minigpt4_video_inference.py CHANGED
@@ -1,94 +1,180 @@
-    # Get the video's available captions (subtitles).
-    captions = yt.captions.all()
-            end_time = int(p.get("t"))
-            # Format and append the VTT entry to the list
-            vtt_subtitle.append(f"{ms_to_vtt_time(start_time)} --> {ms_to_vtt_time(end_time)}\n{subtitle_text}\n")
-            toggle = not toggle
-    # Join the VTT entries into a single string
-    vtt_content = "WEBVTT\n\n" + "\n".join(vtt_subtitle)
-for video_path in tqdm(data,desc='Downloading videos') :
-    video_id=video_path.split('/')[-1].split('.')[0]
-    if existed_video_id.get(video_id,False):
-        continue
-    video_downloaded,caption_downloaded=download_video_with_subtitles(video_id)
-    if caption_downloaded:
-        # convert xml to vtt
-        xml_file_path=f'subtitles_xml/{video_id} (a.en).xml'
-        convert_xml_vtt(xml_file_path,f'subtitles_vtt/{video_id}.vtt')
+import torch
+import webvtt
+import os
+import cv2
+from minigpt4.common.eval_utils import prepare_texts, init_model
+from minigpt4.conversation.conversation import CONV_VISION
+from torchvision import transforms
+import json
 from tqdm import tqdm
+import soundfile as sf
+import argparse
+import moviepy.editor as mp
+import gradio as gr
 from pytubefix import YouTube
+import shutil
+from PIL import Image
+from moviepy.editor import VideoFileClip
+import torch
+import random
+import numpy as np
+import torch.backends.cudnn as cudnn
+
+def prepare_input(vis_processor,video_path,subtitle_path,instruction):
+    cap = cv2.VideoCapture(video_path)
+    if subtitle_path is not None:
+        # Load the VTT subtitle file
+        vtt_file = webvtt.read(subtitle_path)
+        print("subtitle loaded successfully")
+        clip = VideoFileClip(video_path)
+        total_num_frames = int(clip.duration * clip.fps)
+        # print("Video duration = ",clip.duration)
+        clip.close()
+    else :
+        # calculate the total number of frames in the video using opencv
+        total_num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    max_images_length = 45
+    max_sub_len = 400
+    images = []
+    frame_count = 0
+    sampling_interval = int(total_num_frames / max_images_length)
+    if sampling_interval == 0:
+        sampling_interval = 1
+    img_placeholder = ""
+    subtitle_text_in_interval = ""
+    history_subtitles = {}
+    raw_frames=[]
+    number_of_words=0
+    transform=transforms.Compose([
+        transforms.ToPILImage(),
+    ])
+    while cap.isOpened():
+        ret, frame = cap.read()
+        if not ret:
+            break
+        # Find the corresponding subtitle for the frame and combine the interval subtitles into one subtitle
+        # we choose 1 frame for every 2 seconds, so we need to combine the subtitles in the interval of 2 seconds
+        if subtitle_path is not None:
+            for subtitle in vtt_file:
+                sub=subtitle.text.replace('\n',' ')
+                if (subtitle.start_in_seconds <= (frame_count / int(clip.fps)) <= subtitle.end_in_seconds) and sub not in subtitle_text_in_interval:
+                    if not history_subtitles.get(sub,False):
+                        subtitle_text_in_interval+=sub+" "
+                    history_subtitles[sub]=True
+                    break
+        if frame_count % sampling_interval == 0:
+            raw_frames.append(Image.fromarray(cv2.cvtColor(frame.copy(), cv2.COLOR_BGR2RGB)))
+            frame = transform(frame[:,:,::-1]) # convert to RGB
+            frame = vis_processor(frame)
+            images.append(frame)
+            img_placeholder += '<Img><ImageHere>'
+            if subtitle_path is not None and subtitle_text_in_interval != "" and number_of_words< max_sub_len:
+                img_placeholder+=f'<Cap>{subtitle_text_in_interval}'
+                number_of_words+=len(subtitle_text_in_interval.split(' '))
+                subtitle_text_in_interval = ""
+        frame_count += 1
+
+        if len(images) >= max_images_length:
+            break
+    cap.release()
+    cv2.destroyAllWindows()
+    if len(images) == 0:
+        # skip the video if no frame is extracted
+        return None,None
+    images = torch.stack(images)
+    instruction = img_placeholder + '\n' + instruction
+    return images,instruction
+
+def extract_audio(video_path, audio_path):
+    video_clip = mp.VideoFileClip(video_path)
+    audio_clip = video_clip.audio
+    audio_clip.write_audiofile(audio_path, codec="libmp3lame", bitrate="320k")
+
+def generate_subtitles(video_path):
+    video_id=video_path.split('/')[-1].split('.')[0]
+    audio_path = f"workspace/inference_subtitles/mp3/{video_id}"+'.mp3'
+    os.makedirs("workspace/inference_subtitles/mp3",exist_ok=True)
+    if existed_subtitles.get(video_id,False):
+        return f"workspace/inference_subtitles/{video_id}"+'.vtt'
+    try:
+        extract_audio(video_path,audio_path)
+        print("successfully extracted")
+        os.system(f"whisper {audio_path} --language English --model large --output_format vtt --output_dir workspace/inference_subtitles")
+        # remove the audio file
+        os.system(f"rm {audio_path}")
+        print("subtitle successfully generated")
+        return f"workspace/inference_subtitles/{video_id}"+'.vtt'
     except Exception as e:
+        print("error",e)
+        print("error",video_path)
+        return None
+
+def run (video_path,instruction,model,vis_processor,gen_subtitles=True):
+    if gen_subtitles:
+        subtitle_path=generate_subtitles(video_path)
+    else :
+        subtitle_path=None
+    prepared_images,prepared_instruction=prepare_input(vis_processor,video_path,subtitle_path,instruction)
+    if prepared_images is None:
+        return "Video can't be opened, check the video path again"
+    length=len(prepared_images)
+    prepared_images=prepared_images.unsqueeze(0)
+    conv = CONV_VISION.copy()
+    conv.system = ""
+    # for a multi-turn conversation, comment out the 2 lines above and make conv a global variable
+    conv.append_message(conv.roles[0], prepared_instruction)
+    conv.append_message(conv.roles[1], None)
+    prompt = [conv.get_prompt()]
+    answers = model.generate(prepared_images, prompt, max_new_tokens=args.max_new_tokens, do_sample=True, lengths=[length],num_beams=1)
+    return answers[0]

+
+def get_arguments():
+    parser = argparse.ArgumentParser(description="Inference parameters")
+    parser.add_argument("--cfg-path", help="path to configuration file.",default="test_configs/llama2_test_config.yaml")
+    parser.add_argument("--ckpt", type=str,default='checkpoints/video_llama_checkpoint_last.pth', help="path to checkpoint")
+    parser.add_argument("--add_subtitles",action= 'store_true',help="whether to add subtitles")
+    parser.add_argument("--question", type=str, help="question to ask")
+    parser.add_argument("--video_path", type=str, help="Path to the video file")
+    parser.add_argument("--max_new_tokens", type=int, default=512, help="max number of generated tokens")
+    parser.add_argument("--lora_r", type=int, default=64, help="lora rank of the model")
+    parser.add_argument("--lora_alpha", type=int, default=16, help="lora alpha")
+    parser.add_argument(
+        "--options",
+        nargs="+",
+        help="override some settings in the used config, the key-value pair "
+        "in xxx=yyy format will be merged into config file (deprecate), "
+        "change to --cfg-options instead.",
+    )
+    return parser.parse_args()
+args=get_arguments()
+def setup_seeds(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    cudnn.benchmark = False
+    cudnn.deterministic = True

+import yaml
+with open('test_configs/llama2_test_config.yaml') as file:
+    config = yaml.load(file, Loader=yaml.FullLoader)
+seed=config['run']['seed']
+print("seed",seed)

+model, vis_processor = init_model(args)
+conv = CONV_VISION.copy()
+conv.system = ""
+inference_subtitles_folder="inference_subtitles"
+os.makedirs(inference_subtitles_folder,exist_ok=True)
+existed_subtitles={}
+for sub in os.listdir(inference_subtitles_folder):
+    existed_subtitles[sub.split('.')[0]]=True

+if __name__ == "__main__":
+    video_path=args.video_path
+    instruction=args.question
+    add_subtitles=args.add_subtitles
+    # setup_seeds(seed)
+    pred=run(video_path,instruction,model,vis_processor,gen_subtitles=add_subtitles)
+    print(pred)
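The rewritten script is now a command-line entry point: it parses its arguments, builds the model once at import time, and answers a single question about one video, optionally generating Whisper subtitles first. A typical invocation, using only the flags defined above, could look like the following sketch; the config, checkpoint, and video paths are placeholders.

```python
# Example invocation of the rewritten inference script; paths are placeholders.
import subprocess

subprocess.run(
    [
        "python", "minigpt4_video_inference.py",
        "--cfg-path", "test_configs/llama2_test_config.yaml",
        "--ckpt", "checkpoints/video_llama_checkpoint_last.pth",
        "--video_path", "path/to/video.mp4",
        "--question", "What is happening in this video?",
        "--add_subtitles",  # optional: extract audio and run Whisper first
    ],
    check=True,
)
```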
train_configs/224_minigpt4_llama2_image.yaml CHANGED
@@ -1,7 +1,9 @@
 model:
-  arch:
-  model_type:
+  arch: mini_gpt4_llama_v2
+  model_type: pretrain_vicuna
   llama_model: "meta-llama/Llama-2-7b-chat-hf"
+  max_txt_len: 160
+  max_context_len: 512


 datasets:
@@ -42,7 +44,7 @@ run:
   iters_per_epoch: 5000

   seed: 42
-  output_dir: "output/
+  output_dir: "output/minigpt4_stage1_pretrain_llama2"

   amp: True
   resume_ckpt_path: null
train_configs/224_minigpt4_llama2_image_align.yaml ADDED
@@ -0,0 +1,53 @@
model:
  arch: mini_gpt4_llama_v2
  model_type: pretrain_vicuna
  llama_model: "meta-llama/Llama-2-7b-chat-hf"

  max_txt_len: 160
  max_context_len: 512
  end_sym: "</s>"
  prompt_path: "train_configs/alignment.txt"
  prompt_template: '[INST] {} [/INST] '
  ckpt: put your pretrained ckpt here

datasets:
  cc_sbu_align:
    batch_size: 12
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"

run:
  task: image_text_pretrain
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  init_lr: 3e-5
  min_lr: 1e-5
  warmup_lr: 1e-6

  weight_decay: 0.05
  max_epoch: 5
  iters_per_epoch: 200
  num_workers: 4
  warmup_steps: 200

  seed: 42
  output_dir: "output/minigpt4_stage2_finetune"

  amp: True
  resume_ckpt_path: null

  evaluate: False
  train_splits: ["train"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: True

  wandb_log: True
  job_name: minigpt4_finetune
train_configs/224_minigpt4_mistral_image.yaml CHANGED
@@ -1,8 +1,9 @@
 model:
-  arch:
-  model_type:
+  arch: mini_gpt4_llama_v2
+  model_type: pretrain_vicuna
   llama_model: "mistralai/Mistral-7B-Instruct-v0.2"
-
+  max_txt_len: 160
+  max_context_len: 512

 datasets:
   laion:
@@ -42,7 +43,7 @@ run:
   iters_per_epoch: 5000

   seed: 42
-  output_dir: "output/
+  output_dir: "output/minigpt4_stage1_pretrain_mistral"

   amp: True
   resume_ckpt_path: null
@@ -56,4 +57,4 @@ run:
   distributed: True

   wandb_log: True
-  job_name:
+  job_name: minigpt4_mistral_pretrain
train_configs/224_minigpt4_mistral_image_align.yaml ADDED
@@ -0,0 +1,53 @@
model:
  arch: mini_gpt4_llama_v2
  model_type: pretrain_vicuna
  llama_model: "mistralai/Mistral-7B-Instruct-v0.2"

  max_txt_len: 160
  max_context_len: 512
  end_sym: "</s>"
  prompt_path: "train_configs/alignment.txt"
  prompt_template: '[INST] {} [/INST] '
  ckpt: put your pretrained ckpt here

datasets:
  cc_sbu_align:
    batch_size: 12
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"

run:
  task: image_text_pretrain
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  init_lr: 3e-5
  min_lr: 1e-5
  warmup_lr: 1e-6

  weight_decay: 0.05
  max_epoch: 5
  iters_per_epoch: 200
  num_workers: 4
  warmup_steps: 200

  seed: 42
  output_dir: "output/minigpt4_stage2_finetune"

  amp: True
  resume_ckpt_path: null

  evaluate: False
  train_splits: ["train"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: True

  wandb_log: True
  job_name: minigpt4_finetune
train_configs/224_v2_llama2_video_stage_2.yaml CHANGED
@@ -8,7 +8,7 @@ model:
   image_size: 224
   end_sym: "</s>"
   llama_model: "meta-llama/Llama-2-7b-chat-hf"
-  ckpt: "checkpoints/image_llama2_checkpoint.pth"
+  ckpt: "checkpoints/image_llama2_checkpoint.pth" # set the checkpoint to start the training from
   use_grad_checkpoint: True
   chat_template: True
   lora_r: 64
@@ -56,7 +56,7 @@ run:
   iters_per_epoch: 1000

   seed: 42
-  output_dir: "training_output/cmd_webvid_pretrain"
+  output_dir: "training_output/cmd_webvid_pretrain/llama2"

   amp: True
   resume_ckpt_path: null
train_configs/224_v2_llama2_video_stage_3.yaml CHANGED
@@ -7,8 +7,8 @@ model:
   low_resource: False
   image_size: 224
   end_sym: "</s>"
-  llama_model: "meta-llama/Llama-
-  ckpt: "checkpoints/video_captioning_llama_checkpoint_last.pth"
+  llama_model: "meta-llama/Meta-Llama-3-8B-Instruct"
+  # ckpt: "checkpoints/video_captioning_llama_checkpoint_last.pth" # set the checkpoint to start the training from
   use_grad_checkpoint: True
   chat_template: True
   lora_r: 64
@@ -44,7 +44,7 @@ run:
   iters_per_epoch: 1000

   seed: 42
-  output_dir: "training_output/pretrained_video_instruct"
+  output_dir: "training_output/pretrained_video_instruct/llama2"

   amp: True
   resume_ckpt_path: null
train_configs/224_v2_mistral_video_stage_2.yaml CHANGED
@@ -8,7 +8,7 @@ model:
   image_size: 224
   end_sym: "</s>"
   llama_model: "mistralai/Mistral-7B-Instruct-v0.2"
-  ckpt: "checkpoints/image_mistral_checkpoint.pth"
+  ckpt: "checkpoints/image_mistral_checkpoint.pth" # set the checkpoint to start the training from
   use_grad_checkpoint: True
   chat_template: True
   lora_r: 64
@@ -56,7 +56,7 @@ run:
   iters_per_epoch: 875

   seed: 42
-  output_dir: "training_output/cmd_webvid_pretrain"
+  output_dir: "training_output/cmd_webvid_pretrain/mistral"

   amp: True
   resume_ckpt_path: null
train_configs/224_v2_mistral_video_stage_3.yaml CHANGED
@@ -8,7 +8,7 @@ model:
   image_size: 224
   end_sym: "</s>"
   llama_model: "mistralai/Mistral-7B-Instruct-v0.2"
-  ckpt: "checkpoints/video_captioning_mistral_checkpoint_last.pth"
+  ckpt: "checkpoints/video_captioning_mistral_checkpoint_last.pth" # set the checkpoint to start the training from
   use_grad_checkpoint: True
   chat_template: True
   lora_r: 64
@@ -46,7 +46,7 @@ run:
   iters_per_epoch: 875

   seed: 42
-  output_dir: "training_output/pretrained_video_instruct"
+  output_dir: "training_output/pretrained_video_instruct/mistral"

   amp: True
   resume_ckpt_path: null
train_configs/alignment.txt ADDED
@@ -0,0 +1,4 @@
<Img><ImageHere></Img> Describe this image in detail.
<Img><ImageHere></Img> Take a look at this image and describe what you notice.
<Img><ImageHere></Img> Please provide a detailed description of the picture.
<Img><ImageHere></Img> Could you describe the contents of this image for me?
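These four prompts are what `prompt_path: "train_configs/alignment.txt"` in the two new alignment configs points at, and `prompt_template: '[INST] {} [/INST] '` wraps whichever one is sampled. The sketch below shows what the resulting instruction string would look like; the use of `str.format` here is an assumption about how the template is applied, for illustration only.

```python
# Illustrative only: how a sampled alignment prompt and the config's
# prompt_template would combine. The .format call is an assumption.
import random

prompt_template = '[INST] {} [/INST] '
alignment_prompts = [
    "<Img><ImageHere></Img> Describe this image in detail.",
    "<Img><ImageHere></Img> Take a look at this image and describe what you notice.",
    "<Img><ImageHere></Img> Please provide a detailed description of the picture.",
    "<Img><ImageHere></Img> Could you describe the contents of this image for me?",
]

print(prompt_template.format(random.choice(alignment_prompts)))
# e.g. "[INST] <Img><ImageHere></Img> Describe this image in detail. [/INST] "
```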