cathyxl commited on
Commit
f239efc
·
1 Parent(s): 8fb958b
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +5 -0
  2. .gitignore +66 -0
  3. DATA.md +124 -0
  4. README.md +376 -5
  5. app.py +18 -0
  6. assert/data.png +3 -0
  7. assert/logo.png +3 -0
  8. assert/module.png +3 -0
  9. assert/performance.png +3 -0
  10. assert/teaser.jpg +3 -0
  11. assert/zeroshot.png +3 -0
  12. dataset/__init__.py +158 -0
  13. dataset/base_dataset.py +108 -0
  14. dataset/it_dataset.py +206 -0
  15. dataset/utils.py +41 -0
  16. dataset/video_utils.py +214 -0
  17. docs/PoolLLaVA_Report.pdf +3 -0
  18. example/1917.mp4 +3 -0
  19. example/bear.jpg +3 -0
  20. example/cooking.mp4 +3 -0
  21. example/dog.png +3 -0
  22. example/jesse_dance.mp4 +3 -0
  23. example/working.mp4 +3 -0
  24. example/yoga.mp4 +3 -0
  25. models/__init__.py +0 -0
  26. models/pllava/__init__.py +55 -0
  27. models/pllava/configuration_pllava.py +149 -0
  28. models/pllava/convert_pllava_weights_to_hf.py +1 -0
  29. models/pllava/modeling_pllava.py +626 -0
  30. models/pllava/processing_pllava.py +292 -0
  31. python_scripts/hf.py +80 -0
  32. requirements.no_torch.txt +244 -0
  33. requirements.torch.txt +4 -0
  34. requirements.txt +246 -0
  35. scripts/accel_config_deepspeed_zero2.yaml +21 -0
  36. scripts/accel_config_deepspeed_zero3_offload.yaml +22 -0
  37. scripts/accel_config_deepspeed_zero3_offload_multinode.yaml +25 -0
  38. scripts/accel_config_deepspeed_zero3_offload_multinode_1.yaml +25 -0
  39. scripts/accel_config_deepspeed_zero3_offload_multinode_2.yaml +25 -0
  40. scripts/accel_config_deepspeed_zero3_offload_singlegpu.yaml +23 -0
  41. scripts/accel_config_multigpu.yaml +16 -0
  42. scripts/accel_config_multinode.yaml +18 -0
  43. scripts/accel_config_singlegpu.yaml +16 -0
  44. scripts/demo.sh +32 -0
  45. scripts/eval.sh +104 -0
  46. scripts/eval_yiprompt.sh +53 -0
  47. scripts/gallery.sh +11 -0
  48. scripts/train_pllava.sh +34 -0
  49. scripts/train_pllava_13b.sh +50 -0
  50. scripts/train_pllava_34b.sh +50 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ *.mov filter=lfs diff=lfs merge=lfs -text
38
+ *.png filter=lfs diff=lfs merge=lfs -text
39
+ *.jpg filter=lfs diff=lfs merge=lfs -text
40
+ *.pdf filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # local #
2
+ tmp*/
3
+ cache/*
4
+ */cache*/
5
+ tmp*.py
6
+ tmp*
7
+ *pickle
8
+ data/
9
+
10
+ # Zip Files/Packages #
11
+ *.7z
12
+ *.dmg
13
+ *.gz
14
+ *.iso
15
+ *.jar
16
+ *.rar
17
+ *.tar
18
+ *.zip
19
+
20
+ # Logs and databases #
21
+ *.log
22
+ *.sql
23
+ *.sqlite
24
+ .ipynb_checkpoints/
25
+ *.swp
26
+ *.vscode/
27
+ *.idea/
28
+ *.pyc
29
+ __pycache__
30
+ slurm*out
31
+
32
+ # OS files #
33
+ .DS_Store
34
+ .DS_Store?
35
+ ._*
36
+ .Spotlight-V100
37
+ .Trashes
38
+ ehthumbs.db
39
+ Thumbs.db
40
+
41
+
42
+ .vim-arsync
43
+ scratch.norg
44
+ sync_to_red.sh
45
+
46
+ anno/
47
+ wandb/
48
+ logs/
49
+ accelerate_config/
50
+ *.pth
51
+ hf_*
52
+
53
+ # local folders
54
+ MODELS
55
+ DATAS
56
+ SAVED
57
+ EXPERIMENTS
58
+ REMOTE_HF
59
+ TEST
60
+
61
+ test_results
62
+ test_training
63
+ test_hdfs.py
64
+ magic_video_outputs/llava*
65
+ magic_video_outputs
66
+ pllava_video_outputs/
DATA.md ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Data
2
+ ## Instruction Training Data
3
+ <!-- > *originated from [Videochat2](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2)* -->
4
+
5
+
6
+ For training, we leveraged the video instruction tuning data from [Videochat2](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2).
7
+
8
+ #### 1. Download json annotation files from huggingface.
9
+ [![Dataset meta](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-VideoChat2%20IT-blue)](https://huggingface.co/datasets/OpenGVLab/VideoChat2-IT)
10
+
11
+ <!-- > ![images](./assert/data.png) -->
12
+
13
+ #### 2. Download the raw videos from the following links.
14
+ The video directories can be found in tasks/train/instruction_data.py. You can also change them to your own saved paths.
15
+
16
+ - [VideoChat](https://github.com/OpenGVLab/InternVideo/tree/main/Data/instruction_data): Based on [InternVid](https://github.com/OpenGVLab/InternVideo/tree/main/Data/InternVid), download the processed version directly [here](https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/videochat2/data/videochat2_conversation_videos.zip)
17
+ - [VideoChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT/tree/main/data)
18
+ - [Kinetics-710](https://github.com/OpenGVLab/UniFormerV2/blob/main/DATASET.md), download Kinetics 400/600/700 [here](https://openxlab.org.cn/datasets?keywords=kinetics).
19
+ - [SthSthV2](https://developer.qualcomm.com/software/ai-datasets/something-something): Option candidates were generated from [UMT](https://github.com/OpenGVLab/unmasked_teacher) top-20 predictions.
20
+ - [NExTQA](https://github.com/doc-doc/NExT-QA)
21
+ - [CLEVRER](https://clevrer.csail.mit.edu/)
22
+ - [WebVid](https://maxbain.com/webvid-dataset/)
23
+ - [YouCook2](https://youcook2.eecs.umich.edu/), download the processed version [here](https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/videochat2/data/youcook_split_videos.zip).
24
+ - [TextVR](https://github.com/callsys/textvr)
25
+ - [TGIF](https://github.com/YunseokJANG/tgif-qa)
26
+ - [EgoQA](https://ego4d-data.org/), download the processed version [here](https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/videochat2/data/egoqa_split_videos.zip).
27
+
28
+ #### 3. We also provide our processed json annotation files here.
29
+
30
+ [![Dataset meta](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-magic%5Fjsons-blue)](https://huggingface.co/datasets/cathyxl/magic_jsons)
31
+
32
+
33
+ <!-- We leveraged the training data from [Videochat2](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2). We only used the video part for video instruct tuning. -->
34
+
35
+ ## Evaluation Data & Others
36
+ Follow this section to obtain the evaluation open resources.
37
+
38
+ ### VCGBench
39
+
40
+ We refer to the VideoChatGPT video question answering evaluation as VCGBench in this repo. We followed the original [repo](https://github.com/mbzuai-oryx/Video-ChatGPT/tree/main) to prepare the evaluation data.
41
+
42
+ ### MVBench
43
+ We follow the original [Videochat2 repo](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2) in setting up the MVBench Evaluation. You can also find helpful resources at their [huggingface repo](https://huggingface.co/datasets/OpenGVLab/MVBench)
44
+
45
+
46
+ ### Videoqabench
47
+ We refer to all other video question answering benchmarks as videoqabench in this repo. They are mainly prepared folloing the original repos. Each listed:
48
+ 1. [MSVD](https://www.cs.utexas.edu/users/ml/clamp/videoDescription/) & [MSRVTT](https://github.com/xudejing/video-question-answering)
49
+
50
+ 3. [Activity Net](https://github.com/MILVLG/activitynet-qa/tree/master)
51
+ 4. [TGIF](https://github.com/raingo/TGIF-Release/tree/master)
52
+
53
+ Also other fantastic repo intergrating these benchmarks are helpful in the process of setting up the evaluation data:
54
+ - [VideoChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT/tree/main)
55
+ - [VideoLlava](https://github.com/PKU-YuanGroup/Video-LLaVA/tree/main/videollava)
56
+ - [IG-VLM](https://github.com/imagegridworth/IG-VLM/tree/main)
57
+
58
+
59
+
60
+ ### Recaptioning
61
+ #### Inter4k
62
+
63
+ This is a dataset with 1000 samples of high resolution videos. We prepare the data folloing the instructions from their [official website](https://alexandrosstergiou.github.io/datasets/Inter4K/index.html)
64
+
65
+ #### Extending Reacptioning
66
+ The recaptioning part is designed to be extendable.
67
+
68
+ inference script [tasks/eval/recaption/pllava_recaption.py](tasks/eval/recaption/pllava_recaption.py) would use a dataset class [RecaptionDataset](tasks/eval/recaption/__init__.py#L197). The detailed information is kept in the data_list_info attribute as:
69
+ ```
70
+ data_list_info = OrderedDict({
71
+ # "Panda70M": OrderedDict(
72
+ # json_relpath="Panda70M/annotations.json",
73
+ # prefix="DATAS/Recaption/Panda70M/videos",
74
+ # data_type="video",
75
+ # bound=False,
76
+ # key_rename_map={
77
+ # # 'caption': 'hint',
78
+ # },
79
+ # name_key='video_name',
80
+ # postfix=('mp4', 'mkv', 'webm'),
81
+ # recaption_type=RecaptionSample,
82
+ # ), # don't has start & end
83
+ "Inter4K": OrderedDict(
84
+ json_relpath="Inter4K/annotations.json",
85
+ prefix="DATAS/Recaption/Inter4K/60fps/UHD",
86
+ data_type="video",
87
+ bound=False,
88
+ key_rename_map={
89
+ # 'caption': 'hint',
90
+ },
91
+ name_key='video_name',
92
+ postfix=('mp4', 'mkv', 'webm'),
93
+ recaption_type=CaptionSample,
94
+ ), # don't has start & end
95
+ })
96
+ ```
97
+ It contains the path to a annotation json file where there is a list and each item of the list is a sample waiting for captioning. For example, the Inter4K/annotations.json is like:
98
+ ```json
99
+ [
100
+ {
101
+ "video_name": "973"
102
+ },
103
+ ...
104
+ ]
105
+ ```
106
+ and the directory DATAS/Recaption/Inter4K/60fps/UHD would look like:
107
+ ```
108
+ $ ls DATAS/Recaption/Inter4K/60fps/UHD
109
+ 1.mp4 134.mp4 170.mp4 ....
110
+ ```
111
+
112
+ Naively, only the video is needed when captioning directly, therefore the annotation file only needs to contain the names of each video under the "prefix" directory.
113
+
114
+ Extending a dataset for captioning would consist of the folloing steps:
115
+ 1. have all the videos downloaded
116
+ 2. construct a annotation.json file with sepecific format.
117
+ 3. configure the recaption dataset [here](tasks/eval/recaption/__init__.py#L197), where you would need to determine:
118
+ - json_relpath: the annotation relative path
119
+ - prefix: root directory for videos
120
+ - postfix: a list containing all the file extensions for these videos
121
+
122
+ The other options are experimental, so stick with the default setting as in Inter4k. The recommended length of video is around 5-20 seconds.
123
+
124
+ p.s. "bound" is to make sure the video pass to the model doesn't have scene transition or so. This part wasn't tested, so set the bound to false and make sure the original videos files are single clip of a video. But always feel free to discover and contribute to PLLaVA!
README.md CHANGED
@@ -1,12 +1,383 @@
1
  ---
2
- title: Pllava 7b Demo
3
- emoji: 🌖
4
  colorFrom: blue
5
- colorTo: blue
6
  sdk: gradio
7
- sdk_version: 4.28.3
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Plava 7b Demo
3
+ emoji: 👁
4
  colorFrom: blue
5
+ colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 4.27.0
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
+ <div align="center">
13
+
14
+ <h2><a href="https://pllava.github.io/">PLLaVA: Parameter-free LLaVA Extension from Images to Videos for Video Dense Captioning</a></h2>
15
+
16
+ [Lin Xu](https://scholar.google.com/citations?user=_Gu69coAAAAJ), [Yilin Zhao](https://ermu2001.github.io/me.io/), [Daquan Zhou](https://scholar.google.com/citations?user=DdCAbWwAAAAJ), [Zhijie Lin](https://scholar.google.com/citations?user=xXMj6_EAAAAJ), [See-Kiong Ng](https://scholar.google.com/citations?user=_wsommYAAAAJ), [Jiashi Feng](https://scholar.google.com.sg/citations?user=Q8iay0gAAAAJ&hl=en)
17
+
18
+ </div>
19
+
20
+ <!-- [![Paper](https://img.shields.io/badge/cs.CV-2311.17005-b31b1b?logo=arxiv&logoColor=red)](https://arxiv.org/abs/2311.17005) -->
21
+
22
+ **Project Page: [PLLaVA](https://pllava.github.io/)**
23
+
24
+ [![arXiv](https://img.shields.io/badge/arXiv-2404.16994-b31b1b.svg)](https://arxiv.org/abs/2404.16994)
25
+ [![YouTube Video](https://img.shields.io/badge/YouTube-Video-red)](https://www.youtube.com/watch?v=nAEje8tu18U)
26
+ [![Model on HF](https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-sm-dark.svg)](https://huggingface.co/ermu2001/pllava-34b)
27
+
28
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/zeroshot-video-question-answer-on-activitynet)](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-activitynet?p=pllava-parameter-free-llava-extension-from-1)
29
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/zeroshot-video-question-answer-on-msrvtt-qa)](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-msrvtt-qa?p=pllava-parameter-free-llava-extension-from-1)
30
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/zeroshot-video-question-answer-on-msvd-qa)](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-msvd-qa?p=pllava-parameter-free-llava-extension-from-1)
31
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/video-question-answering-on-mvbench)](https://paperswithcode.com/sota/video-question-answering-on-mvbench?p=pllava-parameter-free-llava-extension-from-1)
32
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/zeroshot-video-question-answer-on-tgif-qa)](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-tgif-qa?p=pllava-parameter-free-llava-extension-from-1)
33
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/video-based-generative-performance-4)](https://paperswithcode.com/sota/video-based-generative-performance-4?p=pllava-parameter-free-llava-extension-from-1)
34
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/video-based-generative-performance-3)](https://paperswithcode.com/sota/video-based-generative-performance-3?p=pllava-parameter-free-llava-extension-from-1)
35
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/video-based-generative-performance)](https://paperswithcode.com/sota/video-based-generative-performance?p=pllava-parameter-free-llava-extension-from-1)
36
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/video-based-generative-performance-2)](https://paperswithcode.com/sota/video-based-generative-performance-2?p=pllava-parameter-free-llava-extension-from-1)
37
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/video-based-generative-performance-1)](https://paperswithcode.com/sota/video-based-generative-performance-1?p=pllava-parameter-free-llava-extension-from-1)
38
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/video-based-generative-performance-5)](https://paperswithcode.com/sota/video-based-generative-performance-5?p=pllava-parameter-free-llava-extension-from-1)
39
+
40
+
41
+
42
+
43
+
44
+
45
+
46
+ ![]()
47
+ <div align="center">
48
+ <a href="https://pllava.github.io">
49
+ <img src="assert/logo.png">
50
+ </a>
51
+ </div>
52
+
53
+ <div align="center">
54
+ <video src="https://github.com/magic-research/PLLaVA/assets/55656210/a6619702-12d3-489d-bfcc-0ef7105544b2" width="100%">
55
+ </div>
56
+
57
+
58
+
59
+
60
+
61
+
62
+ ## Overview
63
+
64
+ Welcome to PLLAVA!
65
+
66
+ The primary purpose of this repository is to support research and the development of prototype models. It is designed to facilitate ease of experimentation and enable a clear overview of results. Please note that this repo is currently undergoing development and reconstruction.
67
+
68
+ It's important to mention that we have not optimized the response speed of the application or the frontend logic. Our goal is to maintain simplicity, clarity, and ease of development, making it accessible for both researchers and students. If you have suggestions or want to enhance the application's performance, please feel free to contact us or contribute to the project.
69
+
70
+
71
+ We've briefly introduce our work in section [PLLAVA](#%EF%B8%8F-pllava). For more details, feel free to read our paper. Check out section [Usage](#hammer-usage) to start using this repo. If you felt our works interesting, please star us, your support is all we want. If you find our work helpful, feel free to [cite](#page_facing_up-citation) us directly.
72
+
73
+ ## :fire: Updates
74
+
75
+ - **2024/4/24**: Release:
76
+ - We are releasing our code/models/datasets.
77
+
78
+ ## 🏖️ PLLAVA
79
+ <div align="center">
80
+ <a href="https://www.youtube.com/embed/nAEje8tu18U?si=GXxjgP93j77FzDbw">
81
+ <img src="assert/teaser.jpg">
82
+ </a>
83
+ </div>
84
+
85
+
86
+ ### Abstract
87
+
88
+ Vision-language pre-training (VLP) has significantly elevated performance across a range of vision-language applications. Yet, the pre-training process for video-related tasks demands an exceptionally high degree of computational and data resources. This paper investigates a straightforward, highly efficient, and resource-light approach to adapting an existing image-language pre-training model for video data. Our preliminary experiments reveal that directly fine-tuning pre-trained image-language models with multiple frames on video datasets leads to performance saturation or even a drop in caption-related tasks. Besides, it is also vulnerable to prompts and tends to provide short descriptions. We conducted a deep analysis and observed that the performance saturation and the vulnerability might be related to the dominant patches that exist in some single video patches. We then propose a simple pooling strategy to smooth the feature distribution along the temporal dimension and thus reduce the dominant impacts from some extreme tokens. The new model is termed Pooling LLaVA, or PLLaVA in short. With the proposed pooling strategy, we achieve new state-of-the-art performance on all evaluated datasets. Notably, on the recent popular Video ChatGPT benchmark, PLLaVA achieves a score of 3.48 out of 5 on average of five evaluated dimensions, which is the new state-of-the-art score on the leaderboard and is 0.31 higher than the previous SOTA results from GPT4V (IG-VLM). On the latest multi-choice benchmark MVBench, PLLaVA achieves 58.1% accuracy on average across 20 sub-tasks, which is the new state-of-the-art result and is 14.5% higher than GPT4V (IG-VLM).
89
+
90
+ <div align="center"><img src="assert/module.png"></div>
91
+
92
+
93
+ ### SEARCHING FOR OPTIMAL POOLING STRATEGY
94
+ There are two dimensions for the pooling strategy: the spatial dimension and the temporal dimension. We empirically found that reducing the spatial dimension with a larger temporal dimension could lead to better model performance, compared to reducing the temporal dimension directly.
95
+
96
+ <div align="center"><img src="assert/zeroshot.png"></div>
97
+
98
+
99
+ ### STATE-OF-THE-ART PERFORMANCE
100
+ We compare the performance of PLLAVA with recent popular methods over both question-answer and captioning datasets. The results are shown below.
101
+
102
+ <div align="center"><img src="assert/performance.png"></div>
103
+
104
+ ## :hammer: Usage
105
+
106
+ This section provides guidance on how to run, train, and evaluate our models.
107
+
108
+ ### Install
109
+ First, you will need to set up the environment and download some pre-trained weights.
110
+
111
+ This repo is built up using [transformers](https://github.com/huggingface/transformers) for model construction along with [accelerate](https://github.com/huggingface/accelerate) for distributed training. Follow the instructions to install the needed environment.
112
+
113
+ 1. Above all, the following environment set up is for python 3.10. If you choose to use conda for environment setup, we recommend creating the virtual environment with:
114
+ ```bash
115
+ conda create -n pllava python=3.10
116
+ ```
117
+
118
+ 1. Firstly, install [pytorch](https://pytorch.org/) from the official website. The code runs on torch 2.2.1, cu118 or cu122. Select the version that suits your drive version.
119
+
120
+ ```
121
+ torch 2.2.1+cu118
122
+ torchaudio 2.2.1+cu118
123
+ torchvision 0.17.1+cu118
124
+ ```
125
+
126
+ If your driver version is higher than cu121, you could probably try installing with the following scripts:
127
+ ```bash
128
+ pip install -r requirements.txt
129
+ ```
130
+
131
+ Otherwise, you would need to install a torch for your server first, then install the other packages:
132
+ ```bash
133
+ pip install -r requirements.torch.txt # decide your own requirements, (this is for cu11), or install torch directly following the official website.
134
+ pip install -r requirements.no_torch.txt # install the following
135
+ ```
136
+
137
+ 1. Prepare the model.
138
+ We prefer to have huggingface models explicitly downloaded to a MODELS directory. However, if you are familiar with huggingface-hub usage, feel free to organize the model yourself.
139
+ ```
140
+ python python_scripts/hf.py
141
+ ```
142
+
143
+ Here are some detailed information of the obtained models:
144
+
145
+
146
+ | Model | Link | Initialized From |
147
+ | ------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------- |
148
+ | pllava-7b | [![Model on HF](https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-sm-dark.svg)](https://huggingface.co/ermu2001/pllava-7b) | [llava-hf/llava-v1.6-vicuna-7b-hf · Hugging Face](https://huggingface.co/llava-hf/llava-v1.6-vicuna-7b-hf) |
149
+ | pllava-13b | [![Model on HF](https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-sm-dark.svg)](https://huggingface.co/ermu2001/pllava-13b) | [llava-hf/llava-v1.6-vicuna-13b-hf · Hugging Face](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) |
150
+ | pllava-34b | [![Model on HF](https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-sm-dark.svg)](https://huggingface.co/ermu2001/pllava-34b) | [llava-hf/llava-v1.6-34b-hf · Hugging Face](https://huggingface.co/llava-hf/llava-v1.6-34b-hf) |
151
+
152
+ The model directory should look like this, where you would only need the corresponding model's weights and directory.
153
+
154
+ ```
155
+ $ tree MODELS
156
+ MODELS
157
+ |-- pllava-13b
158
+ | |-- added_tokens.json
159
+ | |-- config.json
160
+ | |-- generation_config.json
161
+ | |-- model-00001-of-00006.safetensors
162
+ | |-- model-00002-of-00006.safetensors
163
+ | |-- model-00003-of-00006.safetensors
164
+ | |-- model-00004-of-00006.safetensors
165
+ | |-- model-00005-of-00006.safetensors
166
+ | |-- model-00006-of-00006.safetensors
167
+ | |-- model.safetensors.index.json
168
+ | |-- preprocessor_config.json
169
+ | |-- processor_config.json
170
+ | |-- special_tokens_map.json
171
+ | |-- tokenizer.json
172
+ | |-- tokenizer.model
173
+ | `-- tokenizer_config.json
174
+ |-- pllava-34b
175
+ | |-- added_tokens.json
176
+ | |-- config.json
177
+ | |-- generation_config.json
178
+ | |-- model-00001-of-00015.safetensors
179
+ | |-- model-00002-of-00015.safetensors
180
+ | |-- model-00003-of-00015.safetensors
181
+ | |-- model-00004-of-00015.safetensors
182
+ | |-- model-00005-of-00015.safetensors
183
+ | |-- model-00006-of-00015.safetensors
184
+ | |-- model-00007-of-00015.safetensors
185
+ | |-- model-00008-of-00015.safetensors
186
+ | |-- model-00009-of-00015.safetensors
187
+ | |-- model-00010-of-00015.safetensors
188
+ | |-- model-00011-of-00015.safetensors
189
+ | |-- model-00012-of-00015.safetensors
190
+ | |-- model-00013-of-00015.safetensors
191
+ | |-- model-00014-of-00015.safetensors
192
+ | |-- model-00015-of-00015.safetensors
193
+ | |-- model.safetensors-deprecated
194
+ | |-- model.safetensors.index.json
195
+ | |-- preprocessor_config.json
196
+ | |-- processor_config.json
197
+ | |-- special_tokens_map.json
198
+ | |-- tokenizer.json
199
+ | |-- tokenizer.model
200
+ | `-- tokenizer_config.json
201
+ |-- pllava-7b
202
+ |-- added_tokens.json
203
+ |-- config.json
204
+ |-- generation_config.json
205
+ |-- model-00001-of-00003.safetensors
206
+ |-- model-00002-of-00003.safetensors
207
+ |-- model-00003-of-00003.safetensors
208
+ |-- model.safetensors.index.json
209
+ |-- preprocessor_config.json
210
+ |-- processor_config.json
211
+ |-- special_tokens_map.json
212
+ |-- tokenizer.json
213
+ |-- tokenizer.model
214
+ `-- tokenizer_config.json
215
+ ```
216
+
217
+ With the above steps, you should be able to proceed on with the following usages.
218
+
219
+ ### Run Application
220
+
221
+ To run our models, make sure you have downloaded a model pretrained weights from the huggingface spaces. Then, run the following scripts with the corresponding path input. Since we are only training with lora and the projector, the model to be run are determined with:
222
+
223
+ - **model_dir**: model directory, one with config.json as compatible with transformers. This refers to the base model's directory, for example "llava-hf/llava-v1.6-vicuna-7b-hf"/"ermu2001/pllava-7b"/"MODELS/pllava-7b". (default to: MODELS/plave-7b)
224
+ - **weights_dir**: your weights directory. could be the same as model_dir, but if you have a weights directory for the lora weights, you should set this weights_dir to that directory to load the lora weights. This directory should be local. Also, it would need to contain a config.json file within. (default to: ${model_dir}).
225
+
226
+ ```bash
227
+ model_dir="model directory"
228
+ weights_dir="weights directory"
229
+ bash scripts/demo.sh ${model_dir} ${weights_dir}
230
+ ```
231
+
232
+ Now check out the application demo and try play with PLLAVA!
233
+
234
+ ### Train
235
+
236
+ Follow the following steps to reproduce our results or train your own variant:
237
+
238
+ #### 1. Data Preparation
239
+
240
+ To train our model from a starting Image-aligned Vision LLM, you would need to download the data first. Our data set up is mainly based on the original Videochat2's training data. Check out [Instruction Data](./DATA.md) to prepare the instruction training data. Ideally, setting up a root data directory and alter the code [here](./tasks/train/instruction_data.py#L6) would accomodate the data for training most smoothly.
241
+
242
+ #### 2. Start Training
243
+
244
+ Now you're only a few step away from starting the training. Follow the instructions:
245
+
246
+ ##### Setup Accelerator
247
+
248
+ Customize a accelerate training config. For example, a simple config using multiple gpus with no distribution strategy (only torch DDP) would look like:
249
+
250
+ ```yaml
251
+ compute_environment: LOCAL_MACHINE
252
+ debug: false
253
+ distributed_type: MULTI_GPU
254
+ downcast_bf16: 'no'
255
+ gpu_ids: all
256
+ machine_rank: 0
257
+ main_training_function: main
258
+ mixed_precision: bf16
259
+ num_machines: 1
260
+ num_processes: 8
261
+ rdzv_backend: static
262
+ same_network: true
263
+ tpu_env: []
264
+ tpu_use_cluster: false
265
+ tpu_use_sudo: false
266
+ use_cpu: false
267
+ ```
268
+
269
+ Check out out the [Accelerate](https://huggingface.co/docs/accelerate/index) documents for more details.
270
+
271
+ ##### Overwatch the training configuration
272
+
273
+ Next, you should go over a basic training configuration of the training process in [here](tasks/train/config_pllava_nframe.py). Then passing this file as the first arg to the [training script](tasks/train/train_pllava_nframe_accel.py) would utilize every arguments in the file. You can customize some of the hyper parameters for your own training process by passing them in the format of "key" "value" pair in the following arguments. A example training scripts could be find [here](scripts/train_pllava.sh).
274
+
275
+ We recommand customize a [configuration](tasks/train/config_pllava_nframe.py) to set up a customized training!
276
+
277
+ With the above steps, you would be able to start the training process. The output would be well organized in the output directory, each a qualified model directory to pass in to demo as weights_dir, since we are only saveing the lora weights and projector weights to avoide redundancy.
278
+
279
+ ### Evaluation
280
+
281
+ This section mainly introduce how to reproduce the evaluation or evaluate your own model.
282
+
283
+ #### Set up Evaluation Data
284
+
285
+ Make sure you set up the "DATAS" directory as in [DATA.md](DATA.md), then you would be able to run the inference with fortune! The evaluation data directory of DATAS would look like:
286
+
287
+ ```
288
+ DATAS/:
289
+ DATAS/VideoQA:
290
+ DATAS/VideoQA/TGIF_QA:
291
+ test_a.json
292
+ test_q.json
293
+ DATAS/VideoQA/TGIF_QA/videos:
294
+ tumblr_m4387mGrlc1r6m5e8o1_250.gif
295
+ ...
296
+ DATAS/VideoQA/TGIF_QA/videos_mp4:
297
+ tumblr_m4387mGrlc1r6m5e8o1_250.mp4
298
+ ...
299
+ DATAS/VideoQA/TGIF_QA/video_gif:
300
+ tumblr_m4387mGrlc1r6m5e8o1_250.gif
301
+ ...
302
+ DATAS/VideoQA/MSVD_Zero_Shot_QA:
303
+ test_a.json
304
+ test_q.json
305
+ DATAS/VideoQA/MSVD_Zero_Shot_QA/videos:
306
+ -4wsuPCjDBc_5_15.avi
307
+ DATAS/VideoQA/MSVD_Zero_Shot_QA/msvd_qa:
308
+ DATAS/VideoQA/ActivityNet:
309
+ test_a.json
310
+ test_q.json
311
+ DATAS/VideoQA/ActivityNet/all_test:
312
+ v_--tFD65KaK4.mp4
313
+ ...
314
+ DATAS/VideoQA/MSRVTT_Zero_Shot_QA:
315
+ test_a.json
316
+ test_q.json
317
+ DATAS/VideoQA/MSRVTT_Zero_Shot_QA/videos:
318
+ DATAS/VideoQA/MSRVTT_Zero_Shot_QA/videos/all:
319
+ video0.mp4
320
+ ...
321
+
322
+ DATAS/MVBench:
323
+ ...
324
+
325
+ DATAS/Recaption/Inter4K:
326
+ annotations.json
327
+ DATAS/Recaption/Inter4K/60fps:
328
+ DATAS/Recaption/Inter4K/60fps/UHD:
329
+ 1.mp4
330
+ ...
331
+
332
+ ```
333
+
334
+ #### Start Evaluate
335
+
336
+ Once you have construted the evaluation data, you can start the evaluation as in [here](scripts/eval.sh). This script is for evaluating 7B/13B models. As pllava-34b model uses a slightly different prompting, it is evaluated with this [script](scripts/eval_yiprompt.sh).
337
+
338
+ ```
339
+ bash scripts/eval.sh
340
+ ```
341
+
342
+ Same as running the demo, you would need to determine the model_dir and weights_dir to evaluate the model. Feel free to comment out some commands and produce partial evaluation.
343
+
344
+ #### Overwatch the Results
345
+
346
+ The evaluation results would be shown to you with our results gallery demo:
347
+
348
+ ```bash
349
+ bash scripts/gallery.sh
350
+ ```
351
+
352
+ Feel free to use the compare version to compare differnt models' results or use the single gallery version to check out one model's results. They are basically the same. Check out this [script](scripts/gallery.sh) for more details
353
+
354
+ #### For Captioning and Recaptioning
355
+ Follow instructions at [DATA.md](DATA.md#extending-reacptioning) and you can extend the recaptioning data with a few steps.
356
+
357
+ Feel free to point out high quality dataset of videos, we would proceed on doing captioning on those datasets.
358
+
359
+
360
+ # :page_facing_up: Citation
361
+
362
+ If you find this project useful in your research, please consider cite:
363
+
364
+ ```BibTeX
365
+ @misc{xu2024pllava,
366
+ title={PLLaVA : Parameter-free LLaVA Extension from Images to Videos for Video Dense Captioning},
367
+ author={Lin Xu and Yilin Zhao and Daquan Zhou and Zhijie Lin and See Kiong Ng and Jiashi Feng},
368
+ year={2024},
369
+ eprint={2404.16994},
370
+ archivePrefix={arXiv},
371
+ primaryClass={cs.CV}
372
+ }
373
+ ```
374
+
375
+ # :dizzy: Acknowledgement
376
+
377
+ This code base is mainly built upon [Videochat2](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2). SALUTE.
378
+
379
+ We would also like to recognize and commend the following open source projects, thank you for your great contribution to the open source community:
380
+
381
+ - [LLaVA](https://github.com/haotian-liu/LLaVA): Fantastic Open Source Image LLM Model.
382
+ - [VideoChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT/tree/main): Great Evaluation Benchmarking Framework.
383
+ - [VideoLlava](https://github.com/PKU-YuanGroup/Video-LLaVA/tree/main/videollava):Video LLM repo with helpful resources.
app.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from huggingface_hub import snapshot_download
3
+ snapshot_download(
4
+ 'ermu2001/pllava-7b',
5
+ local_dir='MODELS/pllava-7b',
6
+ repo_type='model',
7
+ local_dir_use_symlinks=True,
8
+ )
9
+
10
+ sys.argv.extend([
11
+ "--pretrained_model_name_or_path", "MODELS/pllava-7b",
12
+ "--num_frames", "16",
13
+ "--use_lora",
14
+ "--weight_dir", "MODELS/pllava-7b",
15
+ "--lora_alpha", "4",
16
+ "--conv_mode", "plain",
17
+ ])
18
+ import tasks.eval.demo.pllava_demo
assert/data.png ADDED

Git LFS Details

  • SHA256: 72bd5fa48454bfcb6ee1c5b26c3baffd2397502a27bb666860069f0a5755a51b
  • Pointer size: 131 Bytes
  • Size of remote file: 224 kB
assert/logo.png ADDED

Git LFS Details

  • SHA256: df1ae4a260b20b749eaaef02d9bad7057cbba958fff92e23e28d1d3b91224668
  • Pointer size: 132 Bytes
  • Size of remote file: 1.32 MB
assert/module.png ADDED

Git LFS Details

  • SHA256: 7933116caeb3552590bc80c543f37456261dcb9984d75a6f81555f4d38ccfa65
  • Pointer size: 131 Bytes
  • Size of remote file: 226 kB
assert/performance.png ADDED

Git LFS Details

  • SHA256: 9bced5f433da0a6424d8bd1bd776f6cb16407ae94d5cf2fbc09ba09e407c37ac
  • Pointer size: 131 Bytes
  • Size of remote file: 106 kB
assert/teaser.jpg ADDED

Git LFS Details

  • SHA256: f204476020f3995d37a5f7c5b341f8eb739cbb0b5e1e529a8c4e722e5976de54
  • Pointer size: 131 Bytes
  • Size of remote file: 372 kB
assert/zeroshot.png ADDED

Git LFS Details

  • SHA256: d6ee8e95e824759b2f93d63db9c4c57f81775576c8b2932b875dd4176b702dab
  • Pointer size: 131 Bytes
  • Size of remote file: 147 kB
dataset/__init__.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import ConcatDataset, DataLoader
3
+ from torchvision import transforms
4
+ from torchvision.transforms import InterpolationMode
5
+ from dataset.it_dataset import ITImgTrainDataset, ITVidTrainDataset
6
+
7
+
8
+ def get_media_type(dataset_config):
9
+ if len(dataset_config) == 3 and dataset_config[2] == "video":
10
+ return "video"
11
+ elif dataset_config[-1] == "only_video":
12
+ return "only_video"
13
+ else:
14
+ return "image"
15
+
16
+
17
+ def create_dataset(dataset_type, config):
18
+ if "clip" in config.model.get("vit_model", 'vit'):
19
+ mean = (0.485, 0.456, 0.406)
20
+ std = (0.229, 0.224, 0.225)
21
+ else:
22
+ vision_enc_name = config.model.vision_encoder.name
23
+ if "swin" in vision_enc_name or "vit" in vision_enc_name:
24
+ mean = (0.485, 0.456, 0.406)
25
+ std = (0.229, 0.224, 0.225)
26
+ elif "beit" in vision_enc_name:
27
+ mean = (0.5, 0.5, 0.5) # for all beit model except IN1K finetuning
28
+ std = (0.5, 0.5, 0.5)
29
+ elif "clip" in vision_enc_name:
30
+ mean = (0.48145466, 0.4578275, 0.40821073)
31
+ std = (0.26862954, 0.26130258, 0.27577711)
32
+ else:
33
+ raise ValueError
34
+
35
+ normalize = transforms.Normalize(mean, std)
36
+
37
+ # loaded images and videos are torch.Tensor of torch.uint8 format,
38
+ # ordered as (T, 1 or 3, H, W) where T=1 for image
39
+ type_transform = transforms.Lambda(lambda x: x.float().div(255.0))
40
+
41
+ if config.inputs.video_input.random_aug:
42
+ aug_transform = transforms.RandAugment()
43
+ else:
44
+ aug_transform = transforms.Lambda(lambda x: x)
45
+
46
+ train_transform = transforms.Compose(
47
+ [
48
+ aug_transform,
49
+ transforms.RandomResizedCrop(
50
+ config.inputs.image_res,
51
+ scale=(0.5, 1.0),
52
+ interpolation=InterpolationMode.BICUBIC,
53
+ ),
54
+ transforms.RandomHorizontalFlip(),
55
+ type_transform,
56
+ normalize,
57
+ ]
58
+ )
59
+ test_transform = transforms.Compose(
60
+ [
61
+ transforms.Resize(
62
+ (config.inputs.image_res, config.inputs.image_res),
63
+ interpolation=InterpolationMode.BICUBIC,
64
+ ),
65
+ type_transform,
66
+ normalize,
67
+ ]
68
+ )
69
+
70
+ video_reader_type = config.inputs.video_input.get("video_reader_type", "decord")
71
+ video_only_dataset_kwargs_train = dict(
72
+ video_reader_type=video_reader_type,
73
+ sample_type=config.inputs.video_input.sample_type,
74
+ num_frames=config.inputs.video_input.num_frames,
75
+ num_tries=3, # false tolerance
76
+ )
77
+
78
+ if dataset_type == "pt_train":
79
+ raise ValueError("NOT PRETRAINING YET")
80
+ elif dataset_type in ["it_train"]:
81
+ # convert to list of lists
82
+ train_files = (
83
+ [config.train_file] if isinstance(config.train_file[0], str) else config.train_file
84
+ )
85
+ train_media_types = sorted(list({get_media_type(e) for e in train_files}))
86
+
87
+ train_datasets = []
88
+ for m in train_media_types:
89
+ dataset_cls = ITImgTrainDataset if m == "image" else ITVidTrainDataset
90
+ # dataset of the same media_type will be mixed in a single Dataset object
91
+ _train_files = [e for e in train_files if get_media_type(e) == m]
92
+
93
+ datasets = []
94
+ for train_file in _train_files:
95
+ dataset_kwargs = dict(
96
+ ann_file=train_file,
97
+ transform=train_transform,
98
+ mm_alone=config.preprocess.get("mm_alone", True),
99
+ add_second_msg=config.preprocess.get("add_second_msg", True),
100
+ skip_short_sample=config.preprocess.get("skip_short_sample", False),
101
+ clip_transform=config.preprocess.get("clip_transform", False),
102
+ random_shuffle=config.preprocess.get("random_shuffle", True),
103
+ system=config.preprocess.get("system", ""),
104
+ role=config.preprocess.get('roles', ("Human", "Assistant")),
105
+ end_signal=config.preprocess.get('end_signal', "###"),
106
+ begin_signal=config.preprocess.get('begin_signal', ""),
107
+ )
108
+ if m == "video":
109
+ video_only_dataset_kwargs_train.update({
110
+ "start_token": config.model.get("start_token", "<Video>"),
111
+ "end_token": config.model.get("end_token", "</Video>"),
112
+ })
113
+ dataset_kwargs.update(video_only_dataset_kwargs_train)
114
+ if "tgif" in train_file[1]:
115
+ video_only_dataset_kwargs_train.update({
116
+ "video_reader_type": "gif"
117
+ })
118
+ dataset_kwargs.update(video_only_dataset_kwargs_train)
119
+ elif "webvid" in train_file[1]:
120
+ video_only_dataset_kwargs_train.update({
121
+ "video_reader_type": "hdfs"
122
+ })
123
+ else:
124
+ video_only_dataset_kwargs_train.update({
125
+ "video_reader_type": "decord"
126
+ })
127
+ dataset_kwargs.update(video_only_dataset_kwargs_train)
128
+ datasets.append(dataset_cls(**dataset_kwargs))
129
+ dataset = ConcatDataset(datasets)
130
+ train_datasets.append(dataset)
131
+ return train_datasets
132
+
133
+
134
+ def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
135
+ loaders = []
136
+ for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(
137
+ datasets, samplers, batch_size, num_workers, is_trains, collate_fns
138
+ ):
139
+ if is_train:
140
+ shuffle = sampler is None
141
+ drop_last = True
142
+ else:
143
+ shuffle = False
144
+ drop_last = False
145
+ loader = DataLoader(
146
+ dataset,
147
+ batch_size=bs,
148
+ num_workers=n_worker,
149
+ pin_memory=False,
150
+ sampler=sampler,
151
+ shuffle=shuffle,
152
+ collate_fn=collate_fn,
153
+ drop_last=drop_last,
154
+ persistent_workers=True if n_worker > 0 else False,
155
+ )
156
+ loaders.append(loader)
157
+ return loaders
158
+
dataset/base_dataset.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import json
4
+ import random
5
+ from torch.utils.data import Dataset
6
+ import time
7
+ from dataset.utils import load_image_from_path
8
+
9
+ try:
10
+ from petrel_client.client import Client
11
+ has_client = True
12
+ except ImportError:
13
+ has_client = False
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class ImageVideoBaseDataset(Dataset):
19
+ """Base class that implements the image and video loading methods"""
20
+
21
+ media_type = "video"
22
+
23
+ def __init__(self):
24
+ assert self.media_type in ["image", "video", "only_video"]
25
+ self.data_root = None
26
+ self.anno_list = (
27
+ None # list(dict), each dict contains {"image": str, # image or video path}
28
+ )
29
+ self.transform = None
30
+ self.video_reader = None
31
+ self.num_tries = None
32
+
33
+ self.client = None
34
+ if has_client:
35
+ self.client = Client('~/petreloss.conf')
36
+
37
+ def __getitem__(self, index):
38
+ raise NotImplementedError
39
+
40
+ def __len__(self):
41
+ raise NotImplementedError
42
+
43
+ def get_anno(self, index):
44
+ """obtain the annotation for one media (video or image)
45
+
46
+ Args:
47
+ index (int): The media index.
48
+
49
+ Returns: dict.
50
+ - "image": the filename, video also use "image".
51
+ - "caption": The caption for this file.
52
+
53
+ """
54
+ anno = self.anno_list[index]
55
+ if self.data_root is not None:
56
+ anno["image"] = os.path.join(self.data_root, anno["image"])
57
+ return anno
58
+
59
+ def load_and_transform_media_data(self, index, data_path):
60
+ if self.media_type == "image":
61
+ return self.load_and_transform_media_data_image(index, data_path, clip_transform=self.clip_transform)
62
+ else:
63
+ return self.load_and_transform_media_data_video(index, data_path, clip_transform=self.clip_transform)
64
+
65
+ def load_and_transform_media_data_image(self, index, data_path, clip_transform=False):
66
+ image = load_image_from_path(data_path, client=self.client)
67
+ if not clip_transform:
68
+ image = self.transform(image)
69
+ return image, index
70
+
71
+ def load_and_transform_media_data_video(self, index, data_path, return_fps=False, clip=None, clip_transform=False):
72
+ for _ in range(self.num_tries):
73
+ try:
74
+ max_num_frames = self.max_num_frames if hasattr(self, "max_num_frames") else -1
75
+ if "webvid" in data_path:
76
+ hdfs_dir="hdfs://harunava/home/byte_ailab_us_cvg/user/weimin.wang/videogen_data/webvid_data/10M_full_train"
77
+ video_name = os.path.basename(data_path)
78
+ video_id, extension = os.path.splitext(video_name)
79
+ ind_file = os.path.join(hdfs_dir, self.keys_indexfile[video_id])
80
+ frames, frame_indices, fps = self.video_reader(ind_file, video_id, self.num_frames, self.sample_type,
81
+ max_num_frames=max_num_frames, client=self.client, clip=clip)
82
+ else:
83
+ frames, frame_indices, fps = self.video_reader(
84
+ data_path, self.num_frames, self.sample_type,
85
+ max_num_frames=max_num_frames, client=self.client, clip=clip
86
+ )
87
+ except Exception as e:
88
+ logger.warning(
89
+ f"Caught exception {e} when loading video {data_path}, "
90
+ f"randomly sample a new video as replacement"
91
+ )
92
+ index = random.randint(0, len(self) - 1)
93
+ ann = self.get_anno(index)
94
+ data_path = ann["image"]
95
+ continue
96
+ # shared aug for video frames
97
+ if not clip_transform:
98
+ frames = self.transform(frames)
99
+ if return_fps:
100
+ sec = [str(round(f / fps, 1)) for f in frame_indices]
101
+ return frames, index, sec
102
+ else:
103
+ return frames, index
104
+ else:
105
+ raise RuntimeError(
106
+ f"Failed to fetch video after {self.num_tries} tries. "
107
+ f"This might indicate that you have many corrupted videos."
108
+ )
dataset/it_dataset.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import json
4
+ import sqlite3
5
+ import random
6
+ from os.path import basename
7
+
8
+ import numpy as np
9
+ import datetime
10
+
11
+ from dataset.base_dataset import ImageVideoBaseDataset
12
+ from dataset.video_utils import VIDEO_READER_FUNCS
13
+
14
+ logger = logging.getLogger(__name__)
15
+ IMAGE_TOKEN="<image>"
16
+
17
+ class ITImgTrainDataset(ImageVideoBaseDataset):
18
+ media_type = "image"
19
+
20
+ def __init__(
21
+ self, ann_file, transform,
22
+ system="", role=("Human", "Assistant"),
23
+ mm_alone=True,
24
+ add_second_msg=True,
25
+ start_token="<Image>", end_token="</Image>",
26
+ random_shuffle=True, # if True, shuffle the QA list ##xl:????? why need random shuffle
27
+ begin_signal=None,
28
+ end_signal=None,
29
+ clip_transform=False,
30
+ skip_short_sample=False,
31
+ ):
32
+ super().__init__()
33
+ self.mm_alone = mm_alone
34
+ self.clip_transform = clip_transform
35
+ if len(ann_file) == 3 and ann_file[2] == "video":
36
+ self.media_type = "video"
37
+ else:
38
+ self.media_type = "image"
39
+ self.label_file, self.data_root = ann_file[:2]
40
+
41
+ logger.info('Load json file')
42
+ with open(self.label_file, 'r') as f:
43
+ self.anno = json.load(f)
44
+ self.num_examples = len(self.anno)
45
+ self.transform = transform
46
+ annos = []
47
+ for ann in self.anno:
48
+ filename = ann['video'] if 'video' in ann else ann['image']
49
+ if self.media_type =='video' and "webvid" in self.data_root:
50
+ video_id, extension = os.path.splitext(os.path.basename(filename))
51
+ if video_id not in self.keys_indexfile:
52
+ pass
53
+ else:
54
+ annos.append(ann)
55
+ else:
56
+
57
+ if filename is None or filename=="None":
58
+ pass
59
+ else:
60
+ if os.path.exists(os.path.join(self.data_root, filename)):
61
+ annos.append(ann)
62
+ else:
63
+ ...
64
+ self.anno = annos
65
+ self.num_examples = len(self.anno)
66
+
67
+
68
+ # prompt parameters
69
+ if system:
70
+ assert system[-1] == " ", "' ' should be add in the end of system, thus '###' will be tokenized into one token."
71
+ # currently not support add start_token and end_token in the system, since the msg should be added properly
72
+ self.begin_signal = [begin_signal for _ in role] if isinstance(begin_signal, str) else begin_signal
73
+ self.end_signal = [end_signal for _ in role] if isinstance(end_signal, str) else end_signal
74
+ self.start_token = start_token
75
+ self.end_token = end_token
76
+ self.system = system
77
+ self.role = role
78
+ self.random_shuffle = random_shuffle
79
+ # instruction location and number
80
+ logger.info(f"Random shuffle: {self.random_shuffle}")
81
+
82
+ def get_anno(self, index):
83
+ filename = self.anno[index][self.media_type]
84
+ qa = self.anno[index]["QA"]
85
+
86
+ if "start" in self.anno[index] and "end" in self.anno[index]:
87
+ anno = {
88
+ "image": os.path.join(self.data_root, filename), "qa": qa,
89
+ "start": self.anno[index]["start"], "end": self.anno[index]["end"],
90
+ }
91
+ else:
92
+ anno = {"image": os.path.join(self.data_root, filename), "qa": qa}
93
+ return anno
94
+
95
+ def __len__(self):
96
+ return self.num_examples
97
+
98
+ def process_qa(self, qa, msg=""):
99
+ cur_instruction = ""
100
+ # randomly shuffle qa for conversation
101
+ if self.random_shuffle and len(qa) > 1:
102
+ random.shuffle(qa)
103
+ if "i" in qa[0].keys() and qa[0]["i"] != "":
104
+ cur_instruction = qa[0]["i"] + self.end_signal[0]
105
+
106
+ conversation = self.system
107
+ # add instruction as system message
108
+ if cur_instruction:
109
+ conversation += cur_instruction
110
+
111
+ # rstrip() for the extra " " in msg
112
+ if self.mm_alone:
113
+ conversation += (
114
+ self.begin_signal[0] + self.role[0] +
115
+ self.start_token + self.end_token + msg.rstrip() + self.end_signal[0]
116
+ )
117
+
118
+ for i, sentence in enumerate(qa):
119
+ q = self.start_token + self.end_token+"\n"+ qa[0]["q"] if (not self.mm_alone) and (i == 0) else sentence["q"]
120
+ a = sentence["a"]
121
+ if q != "":
122
+ conversation += (self.begin_signal[0] + self.role[0] + q + self.end_signal[1])
123
+ else:
124
+ # no question, often in caption dataset
125
+ pass
126
+ conversation += (self.begin_signal[0] + self.role[1] + a + self.end_signal[1])
127
+
128
+
129
+ if cur_instruction:
130
+ cur_instruction += qa[0]["q"]
131
+ return conversation, cur_instruction.strip()
132
+
133
+ def __getitem__(self, index):
134
+ try:
135
+ ann = self.get_anno(index)
136
+ image, index = self.load_and_transform_media_data_image(index, ann["image"], clip_transform=self.clip_transform)
137
+ conversation, instruction = self.process_qa(ann["qa"])
138
+ return image, conversation, instruction, index
139
+ except Exception as e:
140
+ logger.warning(f"Caught exception {e} when loading image {ann['image']}")
141
+ index = np.random.randint(0, len(self))
142
+ return self.__getitem__(index)
143
+
144
+
145
+ class ITVidTrainDataset(ITImgTrainDataset):
146
+ media_type = "video"
147
+
148
+ def __init__(
149
+ self, ann_file, transform,
150
+ num_frames=4, video_reader_type="decord", sample_type="rand", num_tries=3,
151
+ mm_alone=True,
152
+ system="", role=("Human", "Assistant"),
153
+ start_token="<Video>", end_token="</Video>",
154
+ add_second_msg=True,
155
+ random_shuffle=True,
156
+ begin_signal=None,
157
+ end_signal=None,
158
+ clip_transform=False,
159
+ skip_short_sample=False,
160
+
161
+ ):
162
+ # "id index file for webvid"
163
+ if "webvid" in ann_file[1]:
164
+ with open("/mnt/bn/dq-storage-ckpt/xulin/datasets/videos/webvid_10m/keys_indexfile.json") as f:
165
+ self.keys_indexfile = json.load(f) # the correponding index file for each webvid id
166
+
167
+ super().__init__(
168
+ ann_file, transform,
169
+ system=system, role=role,
170
+ mm_alone=mm_alone,
171
+ start_token=start_token, end_token=end_token,
172
+ random_shuffle=random_shuffle,
173
+ begin_signal=begin_signal,
174
+ end_signal=end_signal,
175
+ clip_transform=clip_transform,
176
+ skip_short_sample=skip_short_sample,
177
+ )
178
+ self.num_frames = num_frames
179
+ self.video_reader_type = video_reader_type
180
+ self.video_reader = VIDEO_READER_FUNCS[video_reader_type]
181
+ self.sample_type = sample_type
182
+ self.num_tries = num_tries
183
+ self.add_second_msg = add_second_msg
184
+
185
+ logger.info(f"Use {video_reader_type} for data in {ann_file}")
186
+ if add_second_msg:
187
+ logger.info(f"Add second message: The video contains X frames sampled at T seconds.")
188
+
189
+ def __getitem__(self, index):
190
+ try:
191
+ ann = self.get_anno(index)
192
+
193
+ msg = ""
194
+ clip = None
195
+ if "start" in ann and "end" in ann:
196
+ clip = [ann["start"], ann["end"]]
197
+ video, index, sec = self.load_and_transform_media_data_video(index, ann["image"], return_fps=True, clip=clip, clip_transform=self.clip_transform)
198
+ if self.add_second_msg:
199
+ # " " should be added in the start and end
200
+ msg = f" The video contains {len(sec)} frames sampled at {', '.join(sec)} seconds. "
201
+ conversation, instruction = self.process_qa(ann["qa"], msg)
202
+ return video, conversation, instruction, index
203
+ except Exception as e:
204
+ logger.warning(f"Caught exception {e} when loading video {ann['image']}")
205
+ index = np.random.randint(0, len(self))
206
+ return self.__getitem__(index)
dataset/utils.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils.distributed import is_main_process, get_rank, get_world_size
2
+ import io
3
+ import json
4
+ import re
5
+ import numpy as np
6
+ from os.path import join
7
+ from tqdm import trange
8
+ from PIL import Image
9
+ from PIL import ImageFile
10
+ from torchvision.transforms import PILToTensor
11
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
12
+ Image.MAX_IMAGE_PIXELS = None
13
+
14
+
15
+ def load_image_from_path(image_path, client):
16
+ if image_path.startswith('s3') or image_path.startswith('p2'):
17
+ value = client.Get(image_path)
18
+ img_bytes = np.frombuffer(value, dtype=np.uint8)
19
+ buff = io.BytesIO(img_bytes)
20
+ image = Image.open(buff).convert('RGB')
21
+ else:
22
+ image = Image.open(image_path).convert('RGB') # PIL Image
23
+ image = PILToTensor()(image).unsqueeze(0) # (1, C, H, W), torch.uint8
24
+ return image
25
+
26
+ def pre_text(text, max_l=None, pre_text=True):
27
+ if pre_text:
28
+ text = re.sub(r"([,.'!?\"()*#:;~])", '', text.lower())
29
+ text = text.replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
30
+
31
+ text = re.sub(r"\s{2,}", ' ', text)
32
+ text = text.rstrip('\n').strip(' ')
33
+
34
+ if max_l: # truncate
35
+ words = text.split(' ')
36
+ if len(words) > max_l:
37
+ text = ' '.join(words[:max_l])
38
+ else:
39
+ pass
40
+ return text
41
+
dataset/video_utils.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Modified from https://github.com/m-bain/frozen-in-time/blob/22a91d78405ec6032fdf521ae1ff5573358e632f/base/base_dataset.py
3
+ """
4
+ import random
5
+ import io
6
+ import os
7
+ import av
8
+ import cv2
9
+ import decord
10
+ import imageio
11
+ from decord import VideoReader
12
+
13
+ # from dataloader import KVReader
14
+ import torch
15
+ import numpy as np
16
+ import math
17
+ # import tensorflow as tf
18
+ decord.bridge.set_bridge("torch")
19
+
20
+ import logging
21
+ logger = logging.getLogger(__name__)
22
+
23
+ def pts_to_secs(pts: int, time_base: float, start_pts: int) -> float:
24
+ """
25
+ Converts a present time with the given time base and start_pts offset to seconds.
26
+
27
+ Returns:
28
+ time_in_seconds (float): The corresponding time in seconds.
29
+
30
+ https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/data/utils.py#L54-L64
31
+ """
32
+ if pts == math.inf:
33
+ return math.inf
34
+
35
+ return int(pts - start_pts) * time_base
36
+
37
+
38
+ def get_pyav_video_duration(video_reader):
39
+ video_stream = video_reader.streams.video[0]
40
+ video_duration = pts_to_secs(
41
+ video_stream.duration,
42
+ video_stream.time_base,
43
+ video_stream.start_time
44
+ )
45
+ return float(video_duration)
46
+
47
+
48
+ def get_frame_indices_by_fps():
49
+ pass
50
+
51
+
52
+ def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
53
+ if sample in ["rand", "middle"]: # uniform sampling
54
+ acc_samples = min(num_frames, vlen)
55
+ # split the video into `acc_samples` intervals, and sample from each interval.
56
+ intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
57
+ ranges = []
58
+ for idx, interv in enumerate(intervals[:-1]):
59
+ ranges.append((interv, intervals[idx + 1] - 1))
60
+ if sample == 'rand':
61
+ try:
62
+ frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
63
+ except:
64
+ frame_indices = np.random.permutation(vlen)[:acc_samples]
65
+ frame_indices.sort()
66
+ frame_indices = list(frame_indices)
67
+ elif fix_start is not None:
68
+ frame_indices = [x[0] + fix_start for x in ranges]
69
+ elif sample == 'middle':
70
+ frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
71
+ else:
72
+ raise NotImplementedError
73
+
74
+ if len(frame_indices) < num_frames: # padded with last frame
75
+ padded_frame_indices = [frame_indices[-1]] * num_frames
76
+ padded_frame_indices[:len(frame_indices)] = frame_indices
77
+ frame_indices = padded_frame_indices
78
+ elif "fps" in sample: # fps0.5, sequentially sample frames at 0.5 fps
79
+ output_fps = float(sample[3:])
80
+ duration = float(vlen) / input_fps
81
+ delta = 1 / output_fps # gap between frames, this is also the clip length each frame represents
82
+ frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
83
+ frame_indices = np.around(frame_seconds * input_fps).astype(int)
84
+ frame_indices = [e for e in frame_indices if e < vlen]
85
+ if max_num_frames > 0 and len(frame_indices) > max_num_frames:
86
+ frame_indices = frame_indices[:max_num_frames]
87
+ # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames)
88
+ else:
89
+ raise ValueError
90
+ return frame_indices
91
+
92
+
93
+ def read_frames_av(
94
+ video_path, num_frames, sample='rand', fix_start=None,
95
+ max_num_frames=-1, client=None, clip=None,
96
+ ):
97
+ reader = av.open(video_path)
98
+ frames = [torch.from_numpy(f.to_rgb().to_ndarray()) for f in reader.decode(video=0)]
99
+ vlen = len(frames)
100
+ duration = get_pyav_video_duration(reader)
101
+ fps = vlen / float(duration)
102
+ frame_indices = get_frame_indices(
103
+ num_frames, vlen, sample=sample, fix_start=fix_start,
104
+ input_fps=fps, max_num_frames=max_num_frames
105
+ )
106
+ frames = torch.stack([frames[idx] for idx in frame_indices]) # (T, H, W, C), torch.uint8
107
+ frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8
108
+ return frames, frame_indices, fps
109
+
110
+
111
+ def read_frames_gif(
112
+ video_path, num_frames, sample='rand', fix_start=None,
113
+ max_num_frames=-1, client=None, clip=None,
114
+ ):
115
+ if video_path.startswith('s3') or video_path.startswith('p2'):
116
+ video_bytes = client.get(video_path)
117
+ gif = imageio.get_reader(io.BytesIO(video_bytes))
118
+ else:
119
+ gif = imageio.get_reader(video_path)
120
+ vlen = len(gif)
121
+ frame_indices = get_frame_indices(
122
+ num_frames, vlen, sample=sample, fix_start=fix_start,
123
+ max_num_frames=max_num_frames
124
+ )
125
+ frames = []
126
+ for index, frame in enumerate(gif):
127
+ # for index in frame_idxs:
128
+ if index in frame_indices:
129
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
130
+ frame = torch.from_numpy(frame).byte()
131
+ # # (H x W x C) to (C x H x W)
132
+ frame = frame.permute(2, 0, 1)
133
+ frames.append(frame)
134
+ frames = torch.stack(frames) # .float() / 255
135
+
136
+ return frames, frame_indices, 25. # for tgif
137
+
138
+
139
+ def read_frames_hdfs(ind_file, vid, num_frames, sample='rand',fix_start=None,
140
+ max_num_frames=-1, client=None, clip=None):
141
+ _context_features = {'title': tf.io.FixedLenFeature([], dtype=tf.string)}
142
+ _sequence_features = {'data': tf.io.FixedLenSequenceFeature([], dtype=tf.string)}
143
+ num_parallel_reader = 1
144
+ filename, extension = os.path.splitext(ind_file)
145
+ reader = KVReader(filename, num_parallel_reader)
146
+ key = vid
147
+ values = reader.read_many([key])
148
+ item = values[0]
149
+ contexts, sequences = tf.io.parse_single_sequence_example(
150
+ serialized=item,
151
+ context_features=_context_features,
152
+ sequence_features=_sequence_features)
153
+
154
+ # text = contexts['title'].numpy().decode("utf-8")
155
+ rawframes = sequences['data']
156
+ vlen = len(rawframes)
157
+ sample="rand"
158
+
159
+ frame_indices = get_frame_indices(num_frames, vlen, sample=sample,
160
+ fix_start=fix_start,
161
+ max_num_frames=max_num_frames)
162
+ def read_image(raw_data):
163
+ return tf.image.decode_jpeg(raw_data, channels=3, dct_method='INTEGER_ACCURATE').numpy()
164
+
165
+ frames = []
166
+ for index, frame in enumerate(rawframes):
167
+ if index in frame_indices:
168
+ frame = read_image(frame)
169
+ frame = torch.as_tensor(frame)
170
+ frames.append(frame)
171
+
172
+ frames = torch.stack(frames)
173
+ # print("in hdfs========>",frames[0])
174
+ frames = frames.permute(0, 3, 1, 2)
175
+ return frames, frame_indices, 25 # don't know the fps for index
176
+
177
+
178
+ def read_frames_decord(
179
+ video_path, num_frames, sample='rand', fix_start=None,
180
+ max_num_frames=-1, client=None, clip=None
181
+ ):
182
+ if video_path.startswith('s3') or video_path.startswith('p2'):
183
+ video_bytes = client.get(video_path)
184
+ video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1)
185
+ else:
186
+ video_reader = VideoReader(video_path, num_threads=1)
187
+ vlen = len(video_reader)
188
+ fps = video_reader.get_avg_fps()
189
+ duration = vlen / float(fps)
190
+
191
+ if clip:
192
+ start, end = clip
193
+ duration = end - start
194
+ vlen = int(duration * fps)
195
+ start_index = int(start * fps)
196
+
197
+ frame_indices = get_frame_indices(
198
+ num_frames, vlen, sample=sample, fix_start=fix_start,
199
+ input_fps=fps, max_num_frames=max_num_frames
200
+ )
201
+ if clip:
202
+ frame_indices = [f + start_index for f in frame_indices]
203
+
204
+ frames = video_reader.get_batch(frame_indices) # (T, H, W, C), torch.uint8
205
+ frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8
206
+ return frames, frame_indices, float(fps)
207
+
208
+
209
+ VIDEO_READER_FUNCS = {
210
+ 'av': read_frames_av,
211
+ 'decord': read_frames_decord,
212
+ 'gif': read_frames_gif,
213
+ 'hdfs': read_frames_hdfs,
214
+ }
docs/PoolLLaVA_Report.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b9f175bd915cdc6f9791a95149992fde1f48ebfffa6c8bff9e6365b7186c57d
3
+ size 3850702
example/1917.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:99f5f2a10985964ddc0555a8fa12b9d41f130b49ad62879a9e150d91834e93d5
3
+ size 1535936
example/bear.jpg ADDED

Git LFS Details

  • SHA256: 286b3a5693322edf01870a561e35016ed46a7cb4b9194c58e2f3526eab1f9efc
  • Pointer size: 131 Bytes
  • Size of remote file: 376 kB
example/cooking.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6a1395530cc13c0441ae99ce66477f533f6009ebdb913064aec91e38eaf3b8e9
3
+ size 876622
example/dog.png ADDED

Git LFS Details

  • SHA256: 919b6e24d3cc7d7998181029fb76e94d8149e6a9d2c4930445fa217f6715716d
  • Pointer size: 131 Bytes
  • Size of remote file: 563 kB
example/jesse_dance.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1fc41c6ebae0692726ea56b33ba711f21186fd4203ac54cd43a5cd898be4350
3
+ size 1221420
example/working.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:09372cdb6b0ea272868b4469d5067674670a948962f1236196e8f23e1f7ce764
3
+ size 4718899
example/yoga.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:74b65d9bec7f83e487b7f923076c01d476dd2ef7ed83928a696ab6f88c7751b7
3
+ size 776184
models/__init__.py ADDED
File without changes
models/pllava/__init__.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
17
+
18
+
19
+ _import_structure = {"configuration_pllava": ["PLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP", "PllavaConfig"]}
20
+
21
+ try:
22
+ if not is_torch_available():
23
+ raise OptionalDependencyNotAvailable()
24
+ except OptionalDependencyNotAvailable:
25
+ pass
26
+ else:
27
+ _import_structure["modeling_pllava"] = [
28
+ "PLLAVA_PRETRAINED_MODEL_ARCHIVE_LIST",
29
+ "PllavaForConditionalGeneration",
30
+ "PllavaPreTrainedModel",
31
+ ]
32
+ _import_structure["processing_pllava"] = ["PllavaProcessor"]
33
+
34
+
35
+ if TYPE_CHECKING:
36
+ from .configuration_pllava import PLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP, PllavaConfig
37
+
38
+ try:
39
+ if not is_torch_available():
40
+ raise OptionalDependencyNotAvailable()
41
+ except OptionalDependencyNotAvailable:
42
+ pass
43
+ else:
44
+ from .modeling_pllava import (
45
+ PLLAVA_PRETRAINED_MODEL_ARCHIVE_LIST,
46
+ PllavaForConditionalGeneration,
47
+ PllavaPreTrainedModel,
48
+ )
49
+ from .processing_pllava import PllavaProcessor
50
+
51
+
52
+ else:
53
+ import sys
54
+
55
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
models/pllava/configuration_pllava.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved.
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """ Llava model configuration"""
15
+
16
+ from transformers.configuration_utils import PretrainedConfig
17
+ from transformers.utils import logging
18
+ from transformers.models.auto import CONFIG_MAPPING
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+ PLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
24
+ "llava-hf/llava-v1.5-7b": "https://huggingface.co/llava-hf/llava-v1.5-7b/resolve/main/config.json",
25
+ }
26
+
27
+
28
+ class PllavaConfig(PretrainedConfig):
29
+ r"""
30
+ This is the configuration class to store the configuration of a [`LlavaForConditionalGeneration`]. It is used to instantiate an
31
+ Llava model according to the specified arguments, defining the model architecture. Instantiating a configuration
32
+ with the defaults will yield a similar configuration to that of the Llava-9B.
33
+
34
+ e.g. [llava-hf/llava-9b](https://huggingface.co/llava-hf/llava-9b)
35
+
36
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
37
+ documentation from [`PretrainedConfig`] for more information.
38
+
39
+ Args:
40
+ vision_config (`LlavaVisionConfig`, *optional*):
41
+ Custom vision config or dict
42
+ text_config (`Union[AutoConfig, dict]`, *optional*):
43
+ The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
44
+ ignore_index (`int`, *optional*, defaults to -100):
45
+ The ignore index for the loss function.
46
+ image_token_index (`int`, *optional*, defaults to 32000):
47
+ The image token index to encode the image prompt.
48
+ projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
49
+ The activation function used by the multimodal projector.
50
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
51
+ The feature selection strategy used to select the vision feature from the CLIP backbone.
52
+ vision_feature_layer (`int`, *optional*, defaults to -2):
53
+ The index of the layer to select the vision feature.
54
+ vocab_size (`int`, *optional*, defaults to 32000):
55
+ Vocabulary size of the Llava model. Defines the number of different tokens that can be represented by the
56
+ `inputs_ids` passed when calling [`~LlavaForConditionalGeneration`]
57
+
58
+ Example:
59
+
60
+ ```python
61
+ >>> from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig
62
+
63
+ >>> # Initializing a CLIP-vision config
64
+ >>> vision_config = CLIPVisionConfig()
65
+
66
+ >>> # Initializing a Llama config
67
+ >>> text_config = LlamaConfig()
68
+
69
+ >>> # Initializing a Llava llava-1.5-7b style configuration
70
+ >>> configuration = LlavaConfig(vision_config, text_config)
71
+
72
+ >>> # Initializing a model from the llava-1.5-7b style configuration
73
+ >>> model = LlavaForConditionalGeneration(configuration)
74
+
75
+ >>> # Accessing the model configuration
76
+ >>> configuration = model.config
77
+ ```"""
78
+
79
+ model_type = "llava"
80
+ is_composition = False
81
+
82
+ def __init__(
83
+ self,
84
+ vision_config=None,
85
+ text_config=None,
86
+ ignore_index=-100,
87
+ image_token_index=32000,
88
+ projector_hidden_act="gelu",
89
+ vision_feature_select_strategy="default",
90
+ vision_feature_layer=-2,
91
+ vocab_size=32000,
92
+ pooling_method='avg',
93
+ pooling_shape=(8, 16, 16),
94
+ frame_shape=(24, 24), # llava 1.5 pretrained frame shape
95
+ num_frames=1, # llava 1.5 pretrained frame shape
96
+ use_pooling=True,
97
+ gradient_checkpointing=False,
98
+ **kwargs,
99
+ ):
100
+ self.ignore_index = ignore_index
101
+ self.image_token_index = image_token_index
102
+ self.projector_hidden_act = projector_hidden_act
103
+ self.vision_feature_select_strategy = vision_feature_select_strategy
104
+ self.vision_feature_layer = vision_feature_layer
105
+ self.vocab_size = vocab_size
106
+ self.use_pooling = use_pooling
107
+ self.gradient_checkpointing = gradient_checkpointing
108
+
109
+ self.vision_config = vision_config
110
+
111
+ self.pooling_method = pooling_method # should be in 'max', 'avg'
112
+ self.pooling_shape = pooling_shape #
113
+ self.frame_shape = frame_shape #
114
+ self.num_frames = num_frames
115
+ if isinstance(self.vision_config, dict):
116
+ vision_config["model_type"] = (
117
+ vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
118
+ )
119
+ self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
120
+ elif vision_config is None:
121
+ self.vision_config = CONFIG_MAPPING["clip_vision_model"](
122
+ intermediate_size=4096,
123
+ hidden_size=1024,
124
+ patch_size=14,
125
+ image_size=336,
126
+ num_hidden_layers=24,
127
+ num_attention_heads=16,
128
+ vocab_size=32000,
129
+ projection_dim=768,
130
+ )
131
+ self.vocab_size = self.vocab_size
132
+
133
+ self.text_config = text_config
134
+
135
+ if isinstance(self.text_config, dict):
136
+ text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
137
+ self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
138
+ self.vocab_size = self.text_config.vocab_size
139
+ self.text_config.gradient_checkpointing = self.gradient_checkpointing
140
+
141
+ elif text_config is None:
142
+ tmp_config = {"_attn_implementation":"flash_attention_2",
143
+ "gradient_checkpointing": self.gradient_checkpointing}
144
+ self.text_config = CONFIG_MAPPING["llama"](**tmp_config)
145
+ self.text_config.gradient_checkpointing = self.gradient_checkpointing
146
+ # self.text_config["_attn_implementation"]="flash_attention_2" # xl: temporal hard code
147
+
148
+
149
+ super().__init__(**kwargs)
models/pllava/convert_pllava_weights_to_hf.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Not yet
models/pllava/modeling_pllava.py ADDED
@@ -0,0 +1,626 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch Llava model."""
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Tuple, Union
18
+ import math
19
+
20
+ import torch
21
+ import torch.utils.checkpoint
22
+ from torch import nn
23
+ import os
24
+ from transformers import PreTrainedModel
25
+ from transformers.activations import ACT2FN
26
+ from transformers.cache_utils import Cache
27
+ from transformers.modeling_outputs import ModelOutput
28
+ from transformers.utils import (
29
+ add_start_docstrings,
30
+ add_start_docstrings_to_model_forward,
31
+ logging,
32
+ replace_return_docstrings,
33
+ )
34
+ from transformers.models.auto import AutoModel, AutoModelForCausalLM
35
+ import einops
36
+
37
+ from .configuration_pllava import PllavaConfig
38
+ import pickle
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+ _CONFIG_FOR_DOC = "LlavaConfig"
43
+
44
+ PLLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = [
45
+ "",
46
+ "",
47
+ "",
48
+ # See all Llava models at https://huggingface.co/models?filter=llava
49
+ ]
50
+
51
+
52
+ @dataclass
53
+ # Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Llava
54
+ class PllavaCausalLMOutputWithPast(ModelOutput):
55
+ """
56
+ Base class for Llava causal language model (or autoregressive) outputs.
57
+
58
+ Args:
59
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
60
+ Language modeling loss (for next-token prediction).
61
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
62
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
63
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
64
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
65
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
66
+
67
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
68
+ `past_key_values` input) to speed up sequential decoding.
69
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
70
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
71
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
72
+
73
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
74
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
75
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
76
+ sequence_length)`.
77
+
78
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
79
+ heads.
80
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
81
+ Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
82
+ sequence_length, hidden_size)`.
83
+
84
+ image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
85
+ """
86
+
87
+ loss: Optional[torch.FloatTensor] = None
88
+ logits: torch.FloatTensor = None
89
+ past_key_values: Optional[List[torch.FloatTensor]] = None
90
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
91
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
92
+ image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
93
+
94
+ class PllavaMultiModalProjector(nn.Module):
95
+ supported_highres = ['pad_crop_four', 'slide', ]
96
+ def __init__(self, config: PllavaConfig):
97
+ super().__init__()
98
+ self.use_pooling = config.use_pooling
99
+ self.frame_shape=config.frame_shape
100
+ self.num_frames = config.num_frames
101
+ self.pooling_shape = config.pooling_shape
102
+
103
+ self.pooling = nn.AdaptiveAvgPool3d(config.pooling_shape)
104
+ self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
105
+ self.act = ACT2FN[config.projector_hidden_act]
106
+ self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
107
+
108
+ def convert_Fembeddings2video(self, input, num_videos, frame_shape):
109
+ input = einops.rearrange(input,
110
+ '(num_videos num_frames) (h w) embed_dims -> num_videos embed_dims num_frames h w',
111
+ num_videos=num_videos, h=frame_shape[0])
112
+ return input
113
+
114
+ def convert_video2Fembeddings(self, input):
115
+ input = einops.rearrange(input, 'num_videos embed_dims num_frames h w -> (num_videos num_frames) (h w) embed_dims ', )
116
+ return input
117
+
118
+ def convert_video2MMembeddings(self, input):
119
+ input = einops.rearrange(input, 'num_videos embed_dims num_frames h w -> num_videos (num_frames h w) embed_dims ', )
120
+ return input
121
+
122
+ def forward(self, image_features, media_type, batch_size=None, num_videos=None):
123
+ frame_shape = self.frame_shape
124
+ num_frames = self.num_frames
125
+ assert media_type in ( 'video', 'image'), f'only image or video, but got media_type {media_type}'
126
+ hidden_states = image_features
127
+
128
+ if media_type == 'image':
129
+ hidden_states = hidden_states.repeat(num_frames, 1, 1)
130
+
131
+ total_frames, spatial_seqlen, embed_dims = hidden_states.shape
132
+ #TODO: temporal code, should ensure num_frames == total frames in data loading later
133
+ if total_frames < num_frames and self.use_pooling: #
134
+ multiplier = int(num_frames/total_frames)+1
135
+ hidden_states= hidden_states.repeat_interleave(multiplier, dim=0)[:num_frames]
136
+ total_frames, spatial_seqlen, embed_dims = hidden_states.shape
137
+
138
+ assert total_frames % num_frames == 0
139
+ assert frame_shape[0] * frame_shape[1] == spatial_seqlen
140
+ hidden_states = self.linear_1(hidden_states)
141
+ hidden_states = self.act(hidden_states)
142
+ hidden_states = self.linear_2(hidden_states)
143
+ hidden_states_videos = self.convert_Fembeddings2video(hidden_states, num_videos * batch_size, frame_shape)
144
+ hidden_states_videos = self.pooling(hidden_states_videos)
145
+ hidden_states = einops.rearrange(hidden_states_videos, 'batch_size_num_videos embed_dims num_frames h w -> batch_size_num_videos num_frames (h w) embed_dims', )
146
+ hidden_states = einops.rearrange(hidden_states, 'batch_size_num_videos num_frames hw embed_dims -> batch_size_num_videos (num_frames hw) embed_dims ')
147
+ return hidden_states
148
+
149
+
150
+
151
+ PLLAVA_START_DOCSTRING = r"""
152
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
153
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
154
+ etc.)
155
+
156
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
157
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
158
+ and behavior.
159
+
160
+ Parameters:
161
+ config ([`LlavaConfig`] or [`LlavaVisionConfig`]):
162
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
163
+ load the weights associated with the model, only the configuration. Check out the
164
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
165
+ """
166
+
167
+
168
+ @add_start_docstrings(
169
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
170
+ PLLAVA_START_DOCSTRING,
171
+ )
172
+ class PllavaPreTrainedModel(PreTrainedModel):
173
+ config_class = PllavaConfig
174
+ base_model_prefix = "model"
175
+ supports_gradient_checkpointing = True
176
+ _no_split_modules = ["LlavaVisionAttention"]
177
+ _skip_keys_device_placement = "past_key_values"
178
+ _supports_flash_attn_2 = True
179
+
180
+ def _init_weights(self, module):
181
+ # important: this ported version of Llava isn't meant for training from scratch - only
182
+ # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
183
+ # https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose
184
+ std = (
185
+ self.config.initializer_range
186
+ if hasattr(self.config, "initializer_range")
187
+ else self.config.text_config.initializer_range
188
+ )
189
+
190
+ if hasattr(module, "class_embedding"):
191
+ module.class_embedding.data.normal_(mean=0.0, std=std)
192
+
193
+ # if isinstance(module, (nn.Linear, nn.Conv2d)):
194
+ # module.weight.data.normal_(mean=0.0, std=std)
195
+ # if module.bias is not None:
196
+ # module.bias.data.zero_()
197
+
198
+ elif isinstance(module, nn.Embedding):
199
+ module.weight.data.normal_(mean=0.0, std=std)
200
+ if module.padding_idx is not None:
201
+ module.weight.data[module.padding_idx].zero_()
202
+
203
+ elif isinstance(module, PllavaMultiModalProjector):
204
+ # module.register_embed.data.normal_(mean=0.0, std=std)
205
+ if self.config.register:
206
+ module.register_embed.data.zero_()
207
+
208
+ @property
209
+ def _supports_sdpa(self):
210
+ """
211
+ Retrieve language_model's attribute to check whether the model supports
212
+ SDPA or not.
213
+ """
214
+ return self.language_model._supports_sdpa
215
+
216
+
217
+ PLLAVA_INPUTS_DOCSTRING = r"""
218
+ Args:
219
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
220
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
221
+ it.
222
+
223
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
224
+ [`PreTrainedTokenizer.__call__`] for details.
225
+
226
+ [What are input IDs?](../glossary#input-ids)
227
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
228
+ The tensors corresponding to the input images. Pixel values can be obtained using
229
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`LlavaProcessor`] uses
230
+ [`CLIPImageProcessor`] for processing images).
231
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
232
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
233
+
234
+ - 1 for tokens that are **not masked**,
235
+ - 0 for tokens that are **masked**.
236
+
237
+ [What are attention masks?](../glossary#attention-mask)
238
+
239
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
240
+ [`PreTrainedTokenizer.__call__`] for details.
241
+
242
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
243
+ `past_key_values`).
244
+
245
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
246
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
247
+ information on the default strategy.
248
+
249
+ - 1 indicates the head is **not masked**,
250
+ - 0 indicates the head is **masked**.
251
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
252
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
253
+ config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
254
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
255
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
256
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
257
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
258
+
259
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
260
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
261
+
262
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
263
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
264
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
265
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
266
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
267
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
268
+ model's internal embedding lookup matrix.
269
+ use_cache (`bool`, *optional*):
270
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
271
+ `past_key_values`).
272
+ output_attentions (`bool`, *optional*):
273
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
274
+ tensors for more detail.
275
+ output_hidden_states (`bool`, *optional*):
276
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
277
+ more detail.
278
+ return_dict (`bool`, *optional*):
279
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
280
+ """
281
+
282
+
283
+ @add_start_docstrings(
284
+ """The LLAVA model which consists of a vision backbone and a language model.""",
285
+ PLLAVA_START_DOCSTRING,
286
+ )
287
+ class PllavaForConditionalGeneration(PllavaPreTrainedModel):
288
+ def __init__(self, config: PllavaConfig):
289
+ super().__init__(config)
290
+ self.config = config
291
+ self.vision_tower = AutoModel.from_config(config.vision_config)
292
+ self.multi_modal_projector = PllavaMultiModalProjector(config)
293
+ self.vocab_size = config.vocab_size
294
+ # self.language_model = AutoModelForCausalLM.from_config(config.text_config, torch_dtype=config.torch_dtype, attn_implementation="flash_attention_2")
295
+ self.language_model = AutoModelForCausalLM.from_config(config.text_config, torch_dtype=config.torch_dtype, attn_implementation="eager")
296
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else self.config.text_config.pad_token_id
297
+ assert self.pad_token_id is not None, 'provide the model with pad_token_id, this would be used to arranging new embedings'
298
+ self.post_init()
299
+
300
+ def get_input_embeddings(self):
301
+ return self.language_model.get_input_embeddings()
302
+
303
+ def set_input_embeddings(self, value):
304
+ self.language_model.set_input_embeddings(value)
305
+
306
+ def get_output_embeddings(self):
307
+ return self.language_model.get_output_embeddings()
308
+
309
+ def set_output_embeddings(self, new_embeddings):
310
+ self.language_model.set_output_embeddings(new_embeddings)
311
+
312
+ def set_decoder(self, decoder):
313
+ self.language_model.set_decoder(decoder)
314
+
315
+ def get_decoder(self):
316
+ return self.language_model.get_decoder()
317
+
318
+ def tie_weights(self):
319
+ return self.language_model.tie_weights()
320
+
321
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
322
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
323
+ # update vocab size
324
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
325
+ self.config.vocab_size = model_embeds.num_embeddings
326
+ self.vocab_size = model_embeds.num_embeddings
327
+ return model_embeds
328
+
329
+ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
330
+ num_images, num_image_patches, embed_dim = image_features.shape
331
+ batch_size, sequence_length = input_ids.shape
332
+ left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
333
+ # 1. Create a mask to know where special image tokens are
334
+ special_image_token_mask = input_ids == self.config.image_token_index
335
+ num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
336
+ # Compute the maximum embed dimension
337
+ max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
338
+ batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
339
+
340
+ # 2. Compute the positions where text should be written
341
+ # Calculate new positions for text tokens in merged image-text sequence.
342
+ # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
343
+ # `torch.cumsum` computes how each image token shifts subsequent text token positions.
344
+ # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
345
+ new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
346
+ nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
347
+ if left_padding:
348
+ new_token_positions += nb_image_pad[:, None] # offset for left padding
349
+ text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
350
+
351
+ # 3. Create the full embedding, already padded to the maximum position
352
+ final_embedding = torch.zeros(
353
+ batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
354
+ )
355
+ final_attention_mask = torch.zeros(
356
+ batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
357
+ )
358
+ if labels is not None:
359
+ final_labels = torch.full(
360
+ (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
361
+ )
362
+ # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
363
+ # set the corresponding tensors into their correct target device.
364
+ target_device = inputs_embeds.device
365
+ batch_indices, non_image_indices, text_to_overwrite = (
366
+ batch_indices.to(target_device),
367
+ non_image_indices.to(target_device),
368
+ text_to_overwrite.to(target_device),
369
+ )
370
+ attention_mask = attention_mask.to(target_device)
371
+
372
+ # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
373
+ # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
374
+ final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
375
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
376
+ if labels is not None:
377
+ final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
378
+
379
+ # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
380
+ image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
381
+ image_to_overwrite &= image_to_overwrite.cumsum(-1) > nb_image_pad[:, None].to(target_device)
382
+
383
+ # # somthing really weird here.
384
+ # temp1 = (image_to_overwrite.cumsum(-1) > nb_image_pad[:, None].to(target_device)) & image_to_overwrite
385
+ # # this is for right padding
386
+ # temp2 = (image_to_overwrite.cumsum(-1) <= num_special_image_tokens.max() * num_image_patches - nb_image_pad[:, None]) & image_to_overwrite
387
+
388
+ if image_to_overwrite.sum() != image_features.shape[:-1].numel():
389
+ raise ValueError(
390
+ f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
391
+ f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
392
+ )
393
+
394
+ final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
395
+ final_attention_mask |= image_to_overwrite
396
+ position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
397
+
398
+ if labels is None:
399
+ final_labels = None
400
+
401
+ return final_embedding, final_attention_mask, final_labels, position_ids
402
+
403
+ @add_start_docstrings_to_model_forward(PLLAVA_INPUTS_DOCSTRING)
404
+ @replace_return_docstrings(output_type=PllavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
405
+ def forward(
406
+ self,
407
+ input_ids: torch.LongTensor = None,
408
+ pixel_values: torch.FloatTensor = None,
409
+ attention_mask: Optional[torch.Tensor] = None,
410
+ media_type: str = None,
411
+ position_ids: Optional[torch.LongTensor] = None,
412
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
413
+ inputs_embeds: Optional[torch.FloatTensor] = None,
414
+ vision_feature_layer: Optional[int] = None,
415
+ vision_feature_select_strategy: Optional[str] = None,
416
+ labels: Optional[torch.LongTensor] = None,
417
+ use_cache: Optional[bool] = None,
418
+ output_attentions: Optional[bool] = None,
419
+ output_hidden_states: Optional[bool] = None,
420
+ return_dict: Optional[bool] = None,
421
+ ) -> Union[Tuple, PllavaCausalLMOutputWithPast]:
422
+ r"""
423
+ Args:
424
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
425
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
426
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
427
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
428
+
429
+ Returns:
430
+
431
+ Example:
432
+
433
+ ```python
434
+ >>> from PIL import Image
435
+ >>> import requests
436
+ >>> from transformers import AutoProcessor, LlavaForConditionalGeneration
437
+
438
+ >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
439
+ >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
440
+
441
+ >>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
442
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
443
+ >>> image = Image.open(requests.get(url, stream=True).raw)
444
+
445
+ >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
446
+
447
+ >>> # Generate
448
+ >>> generate_ids = model.generate(**inputs, max_length=30)
449
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
450
+ "\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"
451
+ ```"""
452
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
453
+ output_hidden_states = (
454
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
455
+ )
456
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
457
+ vision_feature_layer = (
458
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
459
+ )
460
+ vision_feature_select_strategy = (
461
+ vision_feature_select_strategy
462
+ if vision_feature_select_strategy is not None
463
+ else self.config.vision_feature_select_strategy
464
+ )
465
+
466
+ if inputs_embeds is None:
467
+ # 1. Extra the input embeddings
468
+ no_img_input_ids = torch.where(input_ids!=self.config.image_token_index, input_ids, self.pad_token_id) # some model used up all the embeddings
469
+ inputs_embeds = self.get_input_embeddings()(no_img_input_ids)
470
+ batch_size = inputs_embeds.shape[0]
471
+ # 2. Merge text and images
472
+ if pixel_values is not None and input_ids.shape[1] != 1:
473
+ image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
474
+ # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
475
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer] # ( b, img_seqlen, embed_dim)
476
+ if vision_feature_select_strategy == "default":
477
+ selected_image_feature = selected_image_feature[:, 1:]
478
+ elif vision_feature_select_strategy == "full":
479
+ raise ValueError("not implemented")
480
+ selected_image_feature = selected_image_feature
481
+ else:
482
+ raise ValueError(
483
+ f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
484
+ )
485
+
486
+ image_features = self.multi_modal_projector(selected_image_feature,
487
+ media_type,
488
+ batch_size=batch_size,
489
+ num_videos=pixel_values.shape[0]//self.config.num_frames//batch_size,)
490
+
491
+ inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
492
+ image_features, inputs_embeds, input_ids, attention_mask, labels
493
+ )
494
+ if labels is None:
495
+ labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
496
+ else:
497
+ # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
498
+ # generation with cache
499
+ if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
500
+ # Retrieve the first layer to inspect the logits and mask out the hidden states
501
+ # that are set to 0
502
+ first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
503
+
504
+ # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
505
+ batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
506
+
507
+ # Get the target length
508
+ target_seqlen = first_layer_past_key_value.shape[-1] + 1
509
+
510
+ extended_attention_mask = torch.ones(
511
+ (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
512
+ dtype=attention_mask.dtype,
513
+ device=attention_mask.device,
514
+ )
515
+
516
+ # Filter out only the tokens that can be un-attended, this can happen
517
+ # if one uses Llava + Fused modules where the cache on the
518
+ # first iteration is already big enough, or if one passes custom cache
519
+ valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
520
+ new_batch_index = batch_index[valid_indices]
521
+ new_non_attended_tokens = non_attended_tokens[valid_indices]
522
+
523
+ # Zero-out the places where we don't need to attend
524
+ extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
525
+
526
+ attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
527
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
528
+
529
+ outputs = self.language_model(
530
+ attention_mask=attention_mask,
531
+ position_ids=position_ids,
532
+ past_key_values=past_key_values,
533
+ inputs_embeds=inputs_embeds,
534
+ use_cache=use_cache,
535
+ output_attentions=output_attentions,
536
+ output_hidden_states=output_hidden_states,
537
+ return_dict=return_dict,
538
+ )
539
+
540
+ logits = outputs[0]
541
+
542
+ loss = None
543
+ if labels is not None:
544
+ # Shift so that tokens < n predict n
545
+ if attention_mask is not None:
546
+ shift_attention_mask = attention_mask[..., 1:]
547
+ shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
548
+ shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
549
+ else:
550
+ shift_logits = logits[..., :-1, :].contiguous()
551
+ shift_labels = labels[..., 1:].contiguous()
552
+ # Flatten the tokens
553
+ loss_fct = nn.CrossEntropyLoss()
554
+ loss = loss_fct(
555
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
556
+ )
557
+
558
+ if not return_dict:
559
+ output = (logits,) + outputs[1:]
560
+ return (loss,) + output if loss is not None else output
561
+
562
+ return PllavaCausalLMOutputWithPast(
563
+ loss=loss,
564
+ logits=logits,
565
+ past_key_values=outputs.past_key_values,
566
+ hidden_states=outputs.hidden_states,
567
+ attentions=outputs.attentions,
568
+ )
569
+
570
+ def prepare_inputs_for_generation(
571
+ self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
572
+ ):
573
+ if past_key_values is not None:
574
+ if isinstance(past_key_values, Cache):
575
+ cache_length = past_key_values.get_seq_length()
576
+ past_length = past_key_values.seen_tokens
577
+ else:
578
+ cache_length = past_length = past_key_values[0][0].shape[2]
579
+
580
+ # Keep only the unprocessed tokens:
581
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
582
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
583
+ # input)
584
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
585
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
586
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
587
+ # input_ids based on the past_length.
588
+ elif past_length < input_ids.shape[1]:
589
+ input_ids = input_ids[:, past_length:]
590
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
591
+ elif self.config.image_token_index in input_ids:
592
+ input_ids = input_ids[:, input_ids.shape[1] - 1 :]
593
+ # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
594
+ # older attention values, as their corresponding values are not part of the input.
595
+ if cache_length < past_length and attention_mask is not None:
596
+ attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
597
+
598
+ position_ids = kwargs.get("position_ids", None)
599
+ if attention_mask is not None and position_ids is None:
600
+ # create position_ids on the fly for batch generation
601
+ position_ids = attention_mask.long().cumsum(-1) - 1
602
+ position_ids.masked_fill_(attention_mask == 0, 1)
603
+ if past_key_values:
604
+ position_ids = position_ids[:, -input_ids.shape[1] :]
605
+
606
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
607
+ if inputs_embeds is not None and past_key_values is None:
608
+ model_inputs = {"inputs_embeds": inputs_embeds}
609
+ else:
610
+ model_inputs = {"input_ids": input_ids}
611
+ media_type = kwargs.get('media_type', None)
612
+
613
+ model_inputs.update(
614
+ {
615
+ "position_ids": position_ids,
616
+ "past_key_values": past_key_values,
617
+ "use_cache": kwargs.get("use_cache"),
618
+ "attention_mask": attention_mask,
619
+ "pixel_values": pixel_values,
620
+ "media_type": media_type,
621
+ }
622
+ )
623
+ return model_inputs
624
+
625
+ def _reorder_cache(self, *args, **kwargs):
626
+ return self.language_model._reorder_cache(*args, **kwargs)
models/pllava/processing_pllava.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Processor class for Llava.
17
+ """
18
+
19
+
20
+ import itertools
21
+ from typing import List, Optional, Union
22
+ import PIL.Image
23
+ import numpy as np
24
+
25
+ from transformers import AutoTokenizer
26
+ from transformers.feature_extraction_utils import BatchFeature
27
+ from transformers.image_utils import (
28
+ ImageInput,
29
+ make_list_of_images,
30
+ valid_images,
31
+ infer_channel_dimension_format,
32
+ to_numpy_array,
33
+ get_image_size,
34
+ ChannelDimension,
35
+ )
36
+ from transformers.image_processing_utils import get_size_dict
37
+ from transformers.image_utils import PILImageResampling
38
+ from transformers.processing_utils import ProcessorMixin
39
+ from transformers.image_transforms import resize, pad, PaddingMode, to_channel_dimension_format, get_resize_output_image_size
40
+ from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
41
+ from transformers.utils import TensorType
42
+
43
+
44
+ class PllavaProcessor(ProcessorMixin):
45
+ r"""
46
+ Constructs a Llava processor which wraps a Llava image processor and a Llava tokenizer into a single processor.
47
+
48
+ [`LlavaProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`LlamaTokenizerFast`]. See the
49
+ [`~LlavaProcessor.__call__`] and [`~LlavaProcessor.decode`] for more information.
50
+
51
+ Args:
52
+ image_processor ([`CLIPImageProcessor`], *optional*):
53
+ The image processor is a required input.
54
+ tokenizer ([`LlamaTokenizerFast`], *optional*):
55
+ The tokenizer is a required input.
56
+ """
57
+
58
+ attributes = ["image_processor", "tokenizer"]
59
+ image_processor_class = "CLIPImageProcessor"
60
+ tokenizer_class = "AutoTokenizer"
61
+
62
+ def __init__(self, image_processor=None, tokenizer=None,
63
+ shortest_edge=336,
64
+ longest_edge=762,
65
+ center_pad=False):
66
+ self.shortest_edge = shortest_edge
67
+ self.longest_edge = longest_edge
68
+ self.center_pad = center_pad
69
+ super().__init__(image_processor, tokenizer)
70
+
71
+ def resize_crop_longshort(self, videos: list[list[np.ndarray]], input_data_format):
72
+ video_spatial_sizes = [get_image_size(images[0], input_data_format) for images in videos]
73
+ long_short_rates = [max(size) / min(size) for size in video_spatial_sizes]
74
+ min_long_short_rate = min(long_short_rates)
75
+ min_long_short_video_idx = long_short_rates.index(min_long_short_rate)
76
+
77
+ clip_resolution = self.image_processor.size['shortest_edge']
78
+ out_video_spatial_size = video_spatial_sizes[min_long_short_video_idx]
79
+ out_videos_short_edge = max(min(size) for size in video_spatial_sizes)
80
+ resize_longest_edge = max(max(size) for size in video_spatial_sizes)
81
+ resize_longest_edge = min(640, resize_longest_edge)
82
+ out_videos_short_edge = min(out_videos_short_edge, int(resize_longest_edge / min_long_short_rate))
83
+ out_videos_short_edge = max(out_videos_short_edge, clip_resolution)
84
+
85
+
86
+ if out_video_spatial_size[0] > out_video_spatial_size[1]: # h > w:
87
+ out_video_spatial_size = (int(out_videos_short_edge * min_long_short_rate), out_videos_short_edge )
88
+ else:
89
+ out_video_spatial_size = ( out_videos_short_edge, int(out_videos_short_edge * min_long_short_rate) )
90
+ videos = [
91
+ [self.resize(frame, input_data_format=input_data_format, shortest_edge=out_videos_short_edge, longest_edge=9999) for frame in frames]
92
+ for frames in videos
93
+ ]
94
+ out_videos = []
95
+ for frames in videos:
96
+ out_frames = []
97
+ video_spatial_size = get_image_size(frames[0], input_data_format)
98
+ assert min(video_spatial_size) == out_videos_short_edge
99
+ overhead = (max(video_spatial_size) - max(out_video_spatial_size)) // 2
100
+ slice_start, slice_end = overhead // 2, overhead // 2 + max(out_video_spatial_size)
101
+ hslice, wslice = (slice(slice_start, slice_end), slice(None, None)) if video_spatial_size[0] > video_spatial_size[1] \
102
+ else (slice(None, None), slice(slice_start, slice_end)) # h > w
103
+ for frame in frames:
104
+ if input_data_format == ChannelDimension.FIRST:
105
+ out_frames.append(frame[..., hslice, wslice])
106
+ elif input_data_format == ChannelDimension.LAST:
107
+ out_frames.append(frame[..., hslice, wslice, :])
108
+ out_videos.append(out_frames)
109
+
110
+ return out_videos
111
+
112
+ @staticmethod
113
+ def _compute_num_blocks_and_overlaps(input_shape, resolution):
114
+ input_shape = np.array(input_shape)
115
+ resolution = np.array(resolution)
116
+ assert input_shape.max() >= resolution
117
+ num_blocks = np.ceil(input_shape / resolution).astype(np.int32).tolist()
118
+ overlaps = [0 if size % resolution==0
119
+ else int(np.floor((resolution - size % resolution) / (num_block - 1))) for num_block, size in zip(num_blocks, input_shape)]
120
+ return num_blocks, overlaps
121
+
122
+ def resize(
123
+ self,
124
+ image: np.ndarray,
125
+ resample: PILImageResampling = PILImageResampling.BICUBIC, # type: ignore
126
+ data_format: Optional[Union[str, ChannelDimension]] = None,
127
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
128
+ shortest_edge: int = None,
129
+ longest_edge: int = None,
130
+ **kwargs,
131
+ ) -> np.ndarray:
132
+ """
133
+ Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
134
+ resized to keep the input aspect ratio.
135
+
136
+ Args:
137
+ image (`np.ndarray`):
138
+ Image to resize.
139
+ size (`Dict[str, int]`):
140
+ Size of the output image.
141
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
142
+ Resampling filter to use when resiizing the image.
143
+ data_format (`str` or `ChannelDimension`, *optional*):
144
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
145
+ input_data_format (`ChannelDimension` or `str`, *optional*):
146
+ The channel dimension format of the input image. If not provided, it will be inferred.
147
+ """
148
+ shortest_edge = getattr(self, 'shortest_edge', None) if shortest_edge is None else shortest_edge
149
+ longest_edge = getattr(self, 'longest_edge', None) if longest_edge is None else longest_edge
150
+ default_to_square = False
151
+ output_size = get_resize_output_image_size(
152
+ image,
153
+ size=shortest_edge,
154
+ default_to_square=default_to_square,
155
+ max_size=longest_edge,
156
+ input_data_format=input_data_format,
157
+ )
158
+ clip_resolution = self.image_processor.size['shortest_edge']
159
+ if min(output_size) < clip_resolution:
160
+ output_size = get_resize_output_image_size(
161
+ image,
162
+ size=shortest_edge,
163
+ default_to_square=default_to_square,
164
+ input_data_format=input_data_format,
165
+ )
166
+ return resize(
167
+ image,
168
+ size=output_size,
169
+ resample=resample,
170
+ data_format=data_format,
171
+ input_data_format=input_data_format,
172
+ **kwargs,
173
+ )
174
+
175
+ def __call__(
176
+ self,
177
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
178
+ images: ImageInput = None,
179
+ center_pad = None,
180
+ padding: Union[bool, str, PaddingStrategy] = False,
181
+ truncation: Union[bool, str, TruncationStrategy] = None,
182
+ max_length=None,
183
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
184
+ ) -> BatchFeature:
185
+ """
186
+ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
187
+ and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
188
+ the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
189
+ CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
190
+ of the above two methods for more information.
191
+
192
+ Args:
193
+ text (`str`, `List[str]`, `List[List[str]]`):
194
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
195
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
196
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
197
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
198
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
199
+ tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
200
+ number of channels, H and W are image height and width.
201
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
202
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
203
+ index) among:
204
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
205
+ sequence if provided).
206
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
207
+ acceptable input length for the model if that argument is not provided.
208
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
209
+ lengths).
210
+ max_length (`int`, *optional*):
211
+ Maximum length of the returned list and optionally padding length (see above).
212
+ truncation (`bool`, *optional*):
213
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
214
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
215
+ If set, will return tensors of a particular framework. Acceptable values are:
216
+
217
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
218
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
219
+ - `'np'`: Return NumPy `np.ndarray` objects.
220
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
221
+
222
+ Returns:
223
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
224
+
225
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
226
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
227
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
228
+ `None`).
229
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
230
+ """
231
+ data=dict()
232
+ if images is not None:
233
+ if isinstance(images, list) and isinstance(images[0], PIL.Image.Image):
234
+ videos = [images] # one video
235
+ else:
236
+ videos = images
237
+
238
+ pixel_values_list = []
239
+ videos = [[to_numpy_array(image) for image in make_list_of_images(images)] for images in videos]
240
+ # images = [self.resize(image, ) if min(get_image_size(image, input_data_format)) < clip_resolution else image for image in images]
241
+ input_data_format = infer_channel_dimension_format(videos[0][0])
242
+ videos = self.resize_crop_longshort(videos, input_data_format)
243
+
244
+ for images in videos:
245
+ if not valid_images(images):
246
+ raise ValueError(
247
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
248
+ "torch.Tensor, tf.Tensor or jax.ndarray."
249
+ )
250
+
251
+ center_pad = center_pad if center_pad is not None else self.center_pad
252
+ if center_pad:
253
+ images = [self.pad_to_square(image, 0, input_data_format, input_data_format) for image in images]
254
+
255
+ pixel_values = self.image_processor(images, return_tensors='np')["pixel_values"]
256
+ pixel_values_list.append(pixel_values)
257
+
258
+ pixel_values = np.concatenate(pixel_values_list)
259
+ data.update(pixel_values=pixel_values)
260
+
261
+ else:
262
+ data.update(pixel_values = None)
263
+
264
+ if text is not None:
265
+ text_inputs = self.tokenizer(
266
+ text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
267
+ )
268
+ data.update(**text_inputs)
269
+ return BatchFeature(data, tensor_type=return_tensors)
270
+
271
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
272
+ def batch_decode(self, *args, **kwargs):
273
+ """
274
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
275
+ refer to the docstring of this method for more information.
276
+ """
277
+ return self.tokenizer.batch_decode(*args, **kwargs)
278
+
279
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
280
+ def decode(self, *args, **kwargs):
281
+ """
282
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
283
+ the docstring of this method for more information.
284
+ """
285
+ return self.tokenizer.decode(*args, **kwargs)
286
+
287
+ @property
288
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
289
+ def model_input_names(self):
290
+ tokenizer_input_names = self.tokenizer.model_input_names
291
+ image_processor_input_names = self.image_processor.model_input_names
292
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
python_scripts/hf.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import os
3
+ import re
4
+ import multiprocessing
5
+ import functools
6
+ import huggingface_hub
7
+ from huggingface_hub import snapshot_download
8
+
9
+
10
+ def upload(repo_id, local_dir, path_in_repo, repo_type, token):
11
+ huggingface_hub.upload_folder(
12
+ repo_id=repo_id,
13
+ folder_path=local_dir,
14
+ path_in_repo=path_in_repo,
15
+ token=token,
16
+ repo_type=repo_type
17
+ )
18
+
19
+ def download(repo_id, local_dir, repo_type, token, filter_re=None):
20
+ files = huggingface_hub.list_repo_files(repo_id, repo_type=repo_type, token=token)
21
+ if filter_re is not None:
22
+ files = [file for file in files if re.search(filter_re, file) is not None]
23
+ pool = multiprocessing.Pool(8)
24
+ download_func = functools.partial(
25
+ huggingface_hub.hf_hub_download,
26
+ repo_id,
27
+ repo_type=repo_type,
28
+ local_dir=local_dir,
29
+ local_dir_use_symlinks=True,
30
+ token=token
31
+ )
32
+ pool.map(download_func, files)
33
+ print(f'downloaded files {files}')
34
+
35
+
36
+ def upload_file(repo_id, file_path, repo_type, token):
37
+ huggingface_hub.upload_file(
38
+ repo_id=repo_id,
39
+ path_or_fileobj=file_path,
40
+ path_in_repo=file_path,
41
+ token=token,
42
+ repo_type=repo_type,
43
+ )
44
+
45
+ if __name__ == '__main__':
46
+ read_token = '...'
47
+ write_token = '...'
48
+ repo_id = '...'
49
+ local_dir = '...'
50
+ repo_type = '...'
51
+
52
+
53
+ # #############
54
+ # # Examples on most simple hf usage
55
+ # # downlaod
56
+ # filters = []
57
+ # for filter_re in filters:
58
+ # download(repo_id,
59
+ # local_dir,
60
+ # repo_type,
61
+ # filter_re)
62
+
63
+ # # upload
64
+ # upload(repo_id, local_dir, local_dir, repo_type, write_token)
65
+ # #############
66
+
67
+ # download models
68
+ repo_ids = [
69
+ 'ermu2001/pllava-7b',
70
+ 'ermu2001/pllava-13b',
71
+ ]
72
+ for repo_id in repo_ids:
73
+ local_dir = repo_id.replace('ermu2001', 'MODELS')
74
+ snapshot_download(
75
+ repo_id,
76
+ local_dir=local_dir,
77
+ repo_type='model',
78
+ local_dir_use_symlinks=True,
79
+ token=read_token,
80
+ )
requirements.no_torch.txt ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.1.0
2
+ accelerate==0.26.1
3
+ addict==2.4.0
4
+ aiofiles==23.2.1
5
+ aliyun-python-sdk-core==2.15.0
6
+ aliyun-python-sdk-kms==2.16.2
7
+ altair==5.2.0
8
+ annotated-types==0.6.0
9
+ antlr4-python3-runtime==4.9.3
10
+ anyio==4.3.0
11
+ anykeystore==0.2
12
+ apex==0.9.10.dev0
13
+ appdirs==1.4.4
14
+ argcomplete==3.2.3
15
+ attrs==23.2.0
16
+ av==10.0.0
17
+ beautifulsoup4==4.12.3
18
+ blessed==1.20.0
19
+ blessings==1.7
20
+ boto3==1.34.63
21
+ botocore==1.34.63
22
+ Brotli==1.1.0
23
+ cachetools==5.3.3
24
+ certifi==2024.2.2
25
+ cffi==1.16.0
26
+ charset-normalizer==3.3.2
27
+ click==8.1.7
28
+ colorama==0.4.6
29
+ contourpy==1.2.0
30
+ crcmod==1.7
31
+ cryptacular==1.6.2
32
+ cryptography==42.0.5
33
+ cycler==0.12.1
34
+ dacite==1.7.0
35
+ decorator==4.4.2
36
+ decord==0.6.0
37
+ deepspeed==0.14.0
38
+ defusedxml==0.7.1
39
+ Deprecated==1.2.14
40
+ dill==0.3.8
41
+ distro==1.9.0
42
+ dnspython==2.6.1
43
+ docker-pycreds==0.4.0
44
+ einops==0.6.1
45
+ exceptiongroup==1.2.0
46
+ fastapi==0.110.0
47
+ ffmpeg==1.4
48
+ ffmpy==0.3.2
49
+ fiftyone==0.23.6
50
+ fiftyone-brain==0.16.1
51
+ fiftyone_db==1.1.2
52
+ filelock==3.9.0
53
+ flash-attn==2.5.6
54
+ fonttools==4.49.0
55
+ fsspec==2024.2.0
56
+ ftfy==6.1.3
57
+ future==1.0.0
58
+ fvcore==0.1.5.post20221221
59
+ gdown==5.1.0
60
+ gitdb==4.0.11
61
+ GitPython==3.1.42
62
+ glob2==0.7
63
+ google-auth==2.28.2
64
+ google-auth-oauthlib==1.2.0
65
+ gpustat==1.1.1
66
+ gradio==4.21.0
67
+ gradio_client==0.12.0
68
+ graphql-core==3.2.3
69
+ greenlet==3.0.3
70
+ grpcio==1.62.1
71
+ h11==0.14.0
72
+ h2==4.1.0
73
+ hjson==3.1.0
74
+ hpack==4.0.0
75
+ httpcore==1.0.4
76
+ httpx==0.27.0
77
+ huggingface-hub==0.21.4
78
+ humanize==4.9.0
79
+ hupper==1.12.1
80
+ Hypercorn==0.16.0
81
+ hyperframe==6.0.1
82
+ idna==3.6
83
+ idscheck==2.3.0
84
+ imageio==2.27.0
85
+ imageio-ffmpeg==0.4.9
86
+ importlib_metadata==7.0.2
87
+ importlib_resources==6.3.0
88
+ inflate64==1.0.0
89
+ iopath==0.1.10
90
+ Jinja2==3.1.2
91
+ jmespath==0.10.0
92
+ joblib==1.3.2
93
+ jsonlines==4.0.0
94
+ jsonschema==4.21.1
95
+ jsonschema-specifications==2023.12.1
96
+ kaleido==0.2.1
97
+ kiwisolver==1.4.5
98
+ lazy_loader==0.3
99
+ Markdown==3.6
100
+ markdown-it-py==3.0.0
101
+ MarkupSafe==2.1.3
102
+ matplotlib==3.8.3
103
+ mdurl==0.1.2
104
+ mmcv-full==1.7.2
105
+ model-index==0.1.11
106
+ mongoengine==0.24.2
107
+ motor==3.3.2
108
+ moviepy==1.0.3
109
+ mpmath==1.3.0
110
+ multivolumefile==0.2.3
111
+ networkx==3.2.1
112
+ ninja==1.11.1.1
113
+ numpy
114
+ oauthlib==3.2.2
115
+ omegaconf==2.3.0
116
+ openai==1.14.0
117
+ opencv-python==4.9.0.80
118
+ opencv-python-headless==4.9.0.80
119
+ opendatalab==0.0.10
120
+ openmim==0.3.9
121
+ openxlab==0.0.36
122
+ ordered-set==4.1.0
123
+ orjson==3.9.15
124
+ oss2==2.17.0
125
+ packaging==24.0
126
+ pandas==1.5.3
127
+ PasteDeploy==3.1.0
128
+ pathtools==0.1.2
129
+ pbkdf2==1.3
130
+ peft==0.10.0
131
+ pillow==10.2.0
132
+ plaster==1.1.2
133
+ plaster-pastedeploy==1.0.1
134
+ platformdirs==4.2.0
135
+ plotly==5.20.0
136
+ portalocker==2.8.2
137
+ pprintpp==0.4.0
138
+ priority==2.0.0
139
+ proglog==0.1.10
140
+ protobuf==4.23.4
141
+ psutil==5.9.4
142
+ py-cpuinfo==9.0.0
143
+ py7zr==0.21.0
144
+ pyasn1==0.5.1
145
+ pyasn1-modules==0.3.0
146
+ pybcj==1.0.2
147
+ pycparser==2.21
148
+ pycryptodome==3.20.0
149
+ pycryptodomex==3.20.0
150
+ pydantic==2.6.4
151
+ pydantic_core==2.16.3
152
+ pydub==0.25.1
153
+ Pygments==2.17.2
154
+ pymongo==4.6.2
155
+ pynvml==11.5.0
156
+ pyparsing==3.1.2
157
+ pyppmd==1.1.0
158
+ pyramid==2.0.2
159
+ pyramid-mailer==0.15.1
160
+ PySocks==1.7.1
161
+ python-dateutil==2.9.0.post0
162
+ python-multipart==0.0.9
163
+ python3-openid==3.2.0
164
+ pytz==2023.4
165
+ PyYAML==6.0
166
+ pyzstd==0.15.9
167
+ rarfile==4.1
168
+ referencing==0.33.0
169
+ regex==2023.12.25
170
+ repoze.sendmail==4.4.1
171
+ requests==2.28.2
172
+ requests-oauthlib==1.4.0
173
+ retrying==1.3.4
174
+ rich==13.4.2
175
+ rpds-py==0.18.0
176
+ rsa==4.9
177
+ ruff==0.3.2
178
+ s3transfer==0.10.1
179
+ safetensors==0.4.2
180
+ scikit-image==0.22.0
181
+ scikit-learn==1.4.1.post1
182
+ scipy==1.10.1
183
+ semantic-version==2.10.0
184
+ sentencepiece==0.2.0
185
+ sentry-sdk==1.42.0
186
+ setproctitle==1.3.3
187
+ shellingham==1.5.4
188
+ six==1.16.0
189
+ smmap==5.0.1
190
+ sniffio==1.3.1
191
+ sortedcontainers==2.4.0
192
+ soupsieve==2.5
193
+ SQLAlchemy==2.0.28
194
+ sse-starlette==0.10.3
195
+ sseclient-py==1.8.0
196
+ starlette==0.36.3
197
+ strawberry-graphql==0.138.1
198
+ sympy==1.12
199
+ tabulate==0.9.0
200
+ taskgroup==0.0.0a4
201
+ tenacity==8.2.3
202
+ tensorboard==2.15.1
203
+ tensorboard-data-server==0.7.2
204
+ tensorboardX==2.6.2.2
205
+ termcolor==2.3.0
206
+ texttable==1.7.0
207
+ threadpoolctl==3.3.0
208
+ tifffile==2024.2.12
209
+ timm==0.6.12
210
+ tokenizers==0.15.2
211
+ tomli==2.0.1
212
+ tomlkit==0.12.0
213
+ toolz==0.12.1
214
+ tqdm==4.65.2
215
+ transaction==4.0
216
+ transformers==4.37.1
217
+ translationstring==1.4
218
+ triton==2.2.0
219
+ typer==0.9.0
220
+ typing_extensions==4.8.0
221
+ tzdata==2024.1
222
+ tzlocal==5.2
223
+ universal-analytics-python3==1.1.1
224
+ urllib3==1.26.18
225
+ uvicorn==0.28.0
226
+ velruse==1.1.1
227
+ venusian==3.1.0
228
+ voxel51-eta==0.12.6
229
+ wandb==0.14.0
230
+ wcwidth==0.2.13
231
+ WebOb==1.8.7
232
+ websockets==11.0.3
233
+ Werkzeug==3.0.1
234
+ wrapt==1.16.0
235
+ wsproto==1.2.0
236
+ WTForms==3.1.2
237
+ wtforms-recaptcha==0.3.2
238
+ xmltodict==0.13.0
239
+ yacs==0.1.8
240
+ yapf==0.40.2
241
+ zipp==3.18.1
242
+ zope.deprecation==5.0
243
+ zope.interface==6.2
244
+ zope.sqlalchemy==3.1
requirements.torch.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ --index-url https://download.pytorch.org/whl/cu118
2
+ torch==2.2.1
3
+ torchaudio==2.2.1
4
+ torchvision==0.17.1
requirements.txt ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.1.0
2
+ accelerate==0.26.1
3
+ addict==2.4.0
4
+ aiofiles==23.2.1
5
+ aliyun-python-sdk-core==2.15.0
6
+ aliyun-python-sdk-kms==2.16.2
7
+ altair==5.2.0
8
+ annotated-types==0.6.0
9
+ antlr4-python3-runtime==4.9.3
10
+ anyio==4.3.0
11
+ anykeystore==0.2
12
+ apex==0.9.10.dev0
13
+ appdirs==1.4.4
14
+ argcomplete==3.2.3
15
+ attrs==23.2.0
16
+ av==10.0.0
17
+ beautifulsoup4==4.12.3
18
+ blessed==1.20.0
19
+ blessings==1.7
20
+ boto3==1.34.63
21
+ botocore==1.34.63
22
+ Brotli==1.1.0
23
+ cachetools==5.3.3
24
+ certifi==2024.2.2
25
+ cffi==1.16.0
26
+ charset-normalizer==3.3.2
27
+ click==8.1.7
28
+ colorama==0.4.6
29
+ contourpy==1.2.0
30
+ crcmod==1.7
31
+ cryptacular==1.6.2
32
+ cryptography==42.0.5
33
+ cycler==0.12.1
34
+ dacite==1.7.0
35
+ decorator==4.4.2
36
+ decord==0.6.0
37
+ deepspeed==0.14.0
38
+ defusedxml==0.7.1
39
+ Deprecated==1.2.14
40
+ dill==0.3.8
41
+ distro==1.9.0
42
+ dnspython==2.6.1
43
+ docker-pycreds==0.4.0
44
+ einops==0.6.1
45
+ exceptiongroup==1.2.0
46
+ fastapi==0.110.0
47
+ ffmpeg==1.4
48
+ ffmpy==0.3.2
49
+ fiftyone==0.23.6
50
+ fiftyone-brain==0.16.1
51
+ fiftyone_db==1.1.2
52
+ filelock==3.9.0
53
+ fonttools==4.49.0
54
+ fsspec==2024.2.0
55
+ ftfy==6.1.3
56
+ future==1.0.0
57
+ fvcore==0.1.5.post20221221
58
+ gdown==5.1.0
59
+ gitdb==4.0.11
60
+ GitPython==3.1.42
61
+ glob2==0.7
62
+ google-auth==2.28.2
63
+ google-auth-oauthlib==1.2.0
64
+ gpustat==1.1.1
65
+ gradio==4.21.0
66
+ gradio_client==0.12.0
67
+ graphql-core==3.2.3
68
+ greenlet==3.0.3
69
+ grpcio==1.62.1
70
+ h11==0.14.0
71
+ h2==4.1.0
72
+ hjson==3.1.0
73
+ hpack==4.0.0
74
+ httpcore==1.0.4
75
+ httpx==0.27.0
76
+ huggingface-hub==0.21.4
77
+ humanize==4.9.0
78
+ hupper==1.12.1
79
+ Hypercorn==0.16.0
80
+ hyperframe==6.0.1
81
+ idna==3.6
82
+ idscheck==2.3.0
83
+ imageio==2.27.0
84
+ imageio-ffmpeg==0.4.9
85
+ importlib_metadata==7.0.2
86
+ importlib_resources==6.3.0
87
+ inflate64==1.0.0
88
+ iopath==0.1.10
89
+ Jinja2==3.1.2
90
+ jmespath==0.10.0
91
+ joblib==1.3.2
92
+ jsonlines==4.0.0
93
+ jsonschema==4.21.1
94
+ jsonschema-specifications==2023.12.1
95
+ kaleido==0.2.1
96
+ kiwisolver==1.4.5
97
+ lazy_loader==0.3
98
+ Markdown==3.6
99
+ markdown-it-py==3.0.0
100
+ MarkupSafe==2.1.3
101
+ matplotlib==3.8.3
102
+ mdurl==0.1.2
103
+ mmcv-full==1.7.2
104
+ model-index==0.1.11
105
+ mongoengine==0.24.2
106
+ motor==3.3.2
107
+ moviepy==1.0.3
108
+ mpmath==1.3.0
109
+ multivolumefile==0.2.3
110
+ networkx==3.2.1
111
+ ninja==1.11.1.1
112
+ numpy==1.23.5
113
+ oauthlib==3.2.2
114
+ omegaconf==2.3.0
115
+ openai==1.14.0
116
+ opencv-python==4.9.0.80
117
+ opencv-python-headless==4.9.0.80
118
+ opendatalab==0.0.10
119
+ openmim==0.3.9
120
+ openxlab==0.0.36
121
+ ordered-set==4.1.0
122
+ orjson==3.9.15
123
+ oss2==2.17.0
124
+ packaging==24.0
125
+ pandas==1.5.3
126
+ PasteDeploy==3.1.0
127
+ pathtools==0.1.2
128
+ pbkdf2==1.3
129
+ peft==0.10.0
130
+ pillow==10.2.0
131
+ plaster==1.1.2
132
+ plaster-pastedeploy==1.0.1
133
+ platformdirs==4.2.0
134
+ plotly==5.20.0
135
+ portalocker==2.8.2
136
+ pprintpp==0.4.0
137
+ priority==2.0.0
138
+ proglog==0.1.10
139
+ protobuf==4.23.4
140
+ psutil==5.9.4
141
+ py-cpuinfo==9.0.0
142
+ py7zr==0.21.0
143
+ pyasn1==0.5.1
144
+ pyasn1-modules==0.3.0
145
+ pybcj==1.0.2
146
+ pycparser==2.21
147
+ pycryptodome==3.20.0
148
+ pycryptodomex==3.20.0
149
+ pydantic==2.6.4
150
+ pydantic_core==2.16.3
151
+ pydub==0.25.1
152
+ Pygments==2.17.2
153
+ pymongo==4.6.2
154
+ pynvml==11.5.0
155
+ pyparsing==3.1.2
156
+ pyppmd==1.1.0
157
+ pyramid==2.0.2
158
+ pyramid-mailer==0.15.1
159
+ PySocks==1.7.1
160
+ python-dateutil==2.9.0.post0
161
+ python-multipart==0.0.9
162
+ python3-openid==3.2.0
163
+ pytz==2023.4
164
+ PyYAML==6.0
165
+ pyzstd==0.15.9
166
+ rarfile==4.1
167
+ referencing==0.33.0
168
+ regex==2023.12.25
169
+ repoze.sendmail==4.4.1
170
+ requests==2.28.2
171
+ requests-oauthlib==1.4.0
172
+ retrying==1.3.4
173
+ rich==13.4.2
174
+ rpds-py==0.18.0
175
+ rsa==4.9
176
+ ruff==0.3.2
177
+ s3transfer==0.10.1
178
+ safetensors==0.4.2
179
+ scikit-image==0.22.0
180
+ scikit-learn==1.4.1.post1
181
+ scipy==1.10.1
182
+ semantic-version==2.10.0
183
+ sentencepiece==0.2.0
184
+ sentry-sdk==1.42.0
185
+ setproctitle==1.3.3
186
+ shellingham==1.5.4
187
+ six==1.16.0
188
+ smmap==5.0.1
189
+ sniffio==1.3.1
190
+ sortedcontainers==2.4.0
191
+ soupsieve==2.5
192
+ SQLAlchemy==2.0.28
193
+ sse-starlette==0.10.3
194
+ sseclient-py==1.8.0
195
+ starlette==0.36.3
196
+ strawberry-graphql==0.138.1
197
+ sympy==1.12
198
+ tabulate==0.9.0
199
+ taskgroup==0.0.0a4
200
+ tenacity==8.2.3
201
+ tensorboard==2.15.1
202
+ tensorboard-data-server==0.7.2
203
+ tensorboardX==2.6.2.2
204
+ termcolor==2.3.0
205
+ texttable==1.7.0
206
+ threadpoolctl==3.3.0
207
+ tifffile==2024.2.12
208
+ timm==0.6.12
209
+ tokenizers==0.15.2
210
+ tomli==2.0.1
211
+ tomlkit==0.12.0
212
+ toolz==0.12.1
213
+ torch==2.2.1
214
+ torchaudio==2.2.1
215
+ torchvision==0.17.1
216
+ tqdm==4.65.2
217
+ transaction==4.0
218
+ transformers
219
+ translationstring==1.4
220
+ triton==2.2.0
221
+ typer==0.9.0
222
+ typing_extensions==4.8.0
223
+ tzdata==2024.1
224
+ tzlocal==5.2
225
+ universal-analytics-python3==1.1.1
226
+ urllib3==1.26.18
227
+ uvicorn==0.28.0
228
+ velruse==1.1.1
229
+ venusian==3.1.0
230
+ voxel51-eta==0.12.6
231
+ wandb==0.14.0
232
+ wcwidth==0.2.13
233
+ WebOb==1.8.7
234
+ websockets==11.0.3
235
+ Werkzeug==3.0.1
236
+ wrapt==1.16.0
237
+ wsproto==1.2.0
238
+ WTForms==3.1.2
239
+ wtforms-recaptcha==0.3.2
240
+ xmltodict==0.13.0
241
+ yacs==0.1.8
242
+ yapf==0.40.2
243
+ zipp==3.18.1
244
+ zope.deprecation==5.0
245
+ zope.interface==6.2
246
+ zope.sqlalchemy==3.1
scripts/accel_config_deepspeed_zero2.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ deepspeed_config:
4
+ gradient_accumulation_steps: 8
5
+ offload_optimizer_device: none
6
+ offload_param_device: none
7
+ zero3_init_flag: false
8
+ zero_stage: 2
9
+ distributed_type: DEEPSPEED
10
+ downcast_bf16: 'no'
11
+ machine_rank: 0
12
+ main_training_function: main
13
+ mixed_precision: bf16
14
+ num_machines: 1
15
+ num_processes: 4
16
+ rdzv_backend: static
17
+ same_network: true
18
+ tpu_env: []
19
+ tpu_use_cluster: false
20
+ tpu_use_sudo: false
21
+ use_cpu: false
scripts/accel_config_deepspeed_zero3_offload.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ deepspeed_config:
4
+ gradient_accumulation_steps: 2
5
+ offload_optimizer_device: cpu
6
+ offload_param_device: cpu
7
+ zero3_init_flag: true
8
+ zero3_save_16bit_model: true
9
+ zero_stage: 3
10
+ distributed_type: DEEPSPEED
11
+ downcast_bf16: 'no'
12
+ machine_rank: 0
13
+ main_training_function: main
14
+ mixed_precision: bf16
15
+ num_machines: 1
16
+ num_processes: 8
17
+ rdzv_backend: static
18
+ same_network: true
19
+ tpu_env: []
20
+ tpu_use_cluster: false
21
+ tpu_use_sudo: false
22
+ use_cpu: false
scripts/accel_config_deepspeed_zero3_offload_multinode.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ deepspeed_config:
4
+ deepspeed_multinode_launcher: standard
5
+ gradient_accumulation_steps: 2
6
+ offload_optimizer_device: cpu
7
+ offload_param_device: cpu
8
+ zero3_init_flag: true
9
+ zero3_save_16bit_model: true
10
+ zero_stage: 3
11
+ distributed_type: DEEPSPEED
12
+ downcast_bf16: 'no'
13
+ machine_rank: 0
14
+ main_process_ip: fdbd:dc61:18:8::20
15
+ main_process_port: 6876
16
+ main_training_function: main
17
+ mixed_precision: bf16
18
+ num_machines: 2
19
+ num_processes: 16
20
+ rdzv_backend: static
21
+ same_network: true
22
+ tpu_env: []
23
+ tpu_use_cluster: false
24
+ tpu_use_sudo: false
25
+ use_cpu: false
scripts/accel_config_deepspeed_zero3_offload_multinode_1.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ deepspeed_config:
4
+ deepspeed_multinode_launcher: standard
5
+ gradient_accumulation_steps: 2
6
+ offload_optimizer_device: cpu
7
+ offload_param_device: cpu
8
+ zero3_init_flag: true
9
+ zero3_save_16bit_model: true
10
+ zero_stage: 3
11
+ distributed_type: DEEPSPEED
12
+ downcast_bf16: 'no'
13
+ machine_rank: 0
14
+ main_process_ip: fdbd:dc61:18:8::20
15
+ main_process_port: 6876
16
+ main_training_function: main
17
+ mixed_precision: bf16
18
+ num_machines: 2
19
+ num_processes: 16
20
+ rdzv_backend: static
21
+ same_network: true
22
+ tpu_env: []
23
+ tpu_use_cluster: false
24
+ tpu_use_sudo: false
25
+ use_cpu: false
scripts/accel_config_deepspeed_zero3_offload_multinode_2.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ deepspeed_config:
4
+ deepspeed_multinode_launcher: standard
5
+ gradient_accumulation_steps: 2
6
+ offload_optimizer_device: cpu
7
+ offload_param_device: cpu
8
+ zero3_init_flag: true
9
+ zero3_save_16bit_model: true
10
+ zero_stage: 3
11
+ distributed_type: DEEPSPEED
12
+ downcast_bf16: 'no'
13
+ machine_rank: 1
14
+ main_process_ip: fdbd:dc61:18:8::20
15
+ main_process_port: 6876
16
+ main_training_function: main
17
+ mixed_precision: bf16
18
+ num_machines: 2
19
+ num_processes: 16
20
+ rdzv_backend: static
21
+ same_network: true
22
+ tpu_env: []
23
+ tpu_use_cluster: false
24
+ tpu_use_sudo: false
25
+ use_cpu: false
scripts/accel_config_deepspeed_zero3_offload_singlegpu.yaml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ deepspeed_config:
4
+ gradient_accumulation_steps: 16
5
+ gradient_clipping: 1.0
6
+ offload_optimizer_device: cpu
7
+ offload_param_device: cpu
8
+ zero3_init_flag: true
9
+ zero3_save_16bit_model: true
10
+ zero_stage: 3
11
+ distributed_type: DEEPSPEED
12
+ downcast_bf16: 'no'
13
+ machine_rank: 0
14
+ main_training_function: main
15
+ mixed_precision: bf16
16
+ num_machines: 1
17
+ num_processes: 1
18
+ rdzv_backend: static
19
+ same_network: true
20
+ tpu_env: []
21
+ tpu_use_cluster: false
22
+ tpu_use_sudo: false
23
+ use_cpu: false
scripts/accel_config_multigpu.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: MULTI_GPU
4
+ downcast_bf16: 'no'
5
+ gpu_ids: 2,3,4,5
6
+ machine_rank: 0
7
+ main_training_function: main
8
+ mixed_precision: bf16
9
+ num_machines: 1
10
+ num_processes: 4
11
+ rdzv_backend: static
12
+ same_network: true
13
+ tpu_env: []
14
+ tpu_use_cluster: false
15
+ tpu_use_sudo: false
16
+ use_cpu: false
scripts/accel_config_multinode.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: MULTI_GPU
4
+ downcast_bf16: 'no'
5
+ gpu_ids: all
6
+ machine_rank: 1
7
+ main_process_ip: 10.193.16.150
8
+ main_process_port: 6784
9
+ main_training_function: main
10
+ mixed_precision: bf16
11
+ num_machines: 2
12
+ num_processes: 16
13
+ rdzv_backend: static
14
+ same_network: true
15
+ tpu_env: []
16
+ tpu_use_cluster: false
17
+ tpu_use_sudo: false
18
+ use_cpu: false
scripts/accel_config_singlegpu.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: 'NO'
4
+ downcast_bf16: 'no'
5
+ gpu_ids: '0'
6
+ machine_rank: 0
7
+ main_training_function: main
8
+ mixed_precision: bf16
9
+ num_machines: 1
10
+ num_processes: 1
11
+ rdzv_backend: static
12
+ same_network: true
13
+ tpu_env: []
14
+ tpu_use_cluster: false
15
+ tpu_use_sudo: false
16
+ use_cpu: false
scripts/demo.sh ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_dir=${1:-"MODELS/pllava-7b"}
2
+ weight_dir=${2:-"${model_dir}"}
3
+ num_frames=16
4
+ lora_alpha=4
5
+
6
+ echo Running DEMO from model_dir: ${model_dir}
7
+ echo Running DEMO from weights_dir: ${weight_dir}
8
+ echo Running DEMO On Devices: ${CUDA_VISIBLE_DEVICES}
9
+
10
+
11
+ # # 34B Need to Use dispatch for this large.
12
+ # CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} python -m tasks.eval.demo.pllava_demo \
13
+ # --pretrained_model_name_or_path ${model_dir} \
14
+ # --num_frames ${num_frames} \
15
+ # --use_lora \
16
+ # --weight_dir ${weight_dir} \
17
+ # --lora_alpha ${lora_alpha} \
18
+ # --conv_mode eval_vcg_llava_next \
19
+ # --use_multi_gpus \
20
+
21
+
22
+ # 7B and 13B, There are problem if Model was split around A100 40G... Probably because some unkown bug in accelerate dispatch
23
+ CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-"0,1"} python -m tasks.eval.demo.pllava_demo \
24
+ --pretrained_model_name_or_path ${model_dir} \
25
+ --num_frames ${num_frames} \
26
+ --use_lora \
27
+ --weight_dir ${weight_dir} \
28
+ --lora_alpha ${lora_alpha} \
29
+ --conv_mode plain \
30
+ --use_multi_gpus
31
+
32
+
scripts/eval.sh ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # export CUDA_VISIBLE_DEVICES=2,6,7
2
+ export OPENAI_API_KEY=...
3
+ num_frames=16
4
+ test_ratio=1
5
+
6
+ # 13b, uses offload thus saving the full model
7
+ model_dir=MODELS/pllava-13b
8
+ weight_dir=MODELS/pllava-13b
9
+ SAVE_DIR=test_results/test_pllava_13b
10
+ lora_alpha=4
11
+ conv_mode=eval_vcgbench
12
+ python -m tasks.eval.vcgbench.pllava_eval_vcgbench \
13
+ --pretrained_model_name_or_path ${model_dir} \
14
+ --save_path ${SAVE_DIR}/vcgbench \
15
+ --num_frames ${num_frames} \
16
+ --use_lora \
17
+ --lora_alpha ${lora_alpha} \
18
+ --weight_dir ${weight_dir} \
19
+ --pooling_shape 16-12-12 \
20
+ --test_ratio ${test_ratio} \
21
+ --conv_mode ${conv_mode}
22
+
23
+ conv_mode=eval_mvbench
24
+ python -m tasks.eval.mvbench.pllava_eval_mvbench \
25
+ --pretrained_model_name_or_path ${model_dir} \
26
+ --save_path ${SAVE_DIR}/mvbench \
27
+ --use_lora \
28
+ --lora_alpha ${lora_alpha} \
29
+ --num_frames ${num_frames} \
30
+ --weight_dir ${weight_dir} \
31
+ --pooling_shape 16-12-12 \
32
+ --conv_mode ${conv_mode}
33
+
34
+ onv_mode=eval_videoqabench
35
+ python -m tasks.eval.videoqabench.pllava_eval_videoqabench \
36
+ --pretrained_model_name_or_path ${model_dir} \
37
+ --save_path ${SAVE_DIR}/videoqabench \
38
+ --num_frames ${num_frames} \
39
+ --use_lora \
40
+ --lora_alpha ${lora_alpha} \
41
+ --weight_dir ${weight_dir} \
42
+ --test_ratio ${test_ratio} \
43
+ --conv_mode ${conv_mode}
44
+
45
+
46
+ conv_mode=eval_recaption
47
+ python -m tasks.eval.recaption.pllava_recaption \
48
+ --pretrained_model_name_or_path ${model_dir} \
49
+ --save_path ${SAVE_DIR}/recaption \
50
+ --num_frames ${num_frames} \
51
+ --use_lora \
52
+ --weight_dir ${weight_dir} \
53
+ --lora_alpha ${lora_alpha} \
54
+ --test_ratio ${test_ratio} \
55
+ --conv_mode ${conv_mode}
56
+
57
+
58
+ model_dir=MODELS/pllava-7b
59
+ weight_dir=MODELS/pllava-7b
60
+ SAVE_DIR=test_results/test_pllava_7b
61
+ lora_alpha=4
62
+
63
+ conv_mode=eval_vcgbench
64
+ python -m tasks.eval.vcgbench.pllava_eval_vcgbench \
65
+ --pretrained_model_name_or_path ${model_dir} \
66
+ --save_path ${SAVE_DIR}/vcgbench \
67
+ --num_frames ${num_frames} \
68
+ --use_lora \
69
+ --lora_alpha ${lora_alpha} \
70
+ --weight_dir ${weight_dir} \
71
+ --pooling_shape 16-12-12 \
72
+ --test_ratio ${test_ratio}
73
+
74
+
75
+ conv_mode=eval_mvbench
76
+ python -m tasks.eval.mvbench.pllava_eval_mvbench \
77
+ --pretrained_model_name_or_path ${model_dir} \
78
+ --save_path ${SAVE_DIR}/mvbench \
79
+ --use_lora \
80
+ --lora_alpha ${lora_alpha} \
81
+ --num_frames ${num_frames} \
82
+ --weight_dir ${weight_dir} \
83
+ --pooling_shape 16-12-12
84
+
85
+
86
+ onv_mode=eval_videoqabench
87
+ python -m tasks.eval.videoqabench.pllava_eval_videoqabench \
88
+ --pretrained_model_name_or_path ${model_dir} \
89
+ --save_path ${SAVE_DIR}/videoqabench \
90
+ --num_frames ${num_frames} \
91
+ --use_lora \
92
+ --lora_alpha ${lora_alpha} \
93
+ --weight_dir ${weight_dir} \
94
+ --test_ratio ${test_ratio}
95
+
96
+ conv_mode=eval_recaption
97
+ python -m tasks.eval.recaption.pllava_recaption \
98
+ --pretrained_model_name_or_path ${model_dir} \
99
+ --save_path ${SAVE_DIR}/recaption \
100
+ --num_frames ${num_frames} \
101
+ --use_lora \
102
+ --lora_alpha ${lora_alpha} \
103
+ --weight_dir ${weight_dir} \
104
+ --test_ratio ${test_ratio}
scripts/eval_yiprompt.sh ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # export CUDA_VISIBLE_DEVICES=0,3,4,5,6,7
2
+ export OPENAI_API_KEY=...
3
+ num_frames=16
4
+ test_ratio=200
5
+
6
+ model_dir=MODELS/pllava-34b
7
+ weight_dir=MODELS/pllava-34b
8
+ SAVE_DIR=test_results/test_pllava_34b
9
+ lora_alpha=4
10
+ conv_mode=eval_vcg_llavanext
11
+ python -m tasks.eval.vcgbench.pllava_eval_vcgbench \
12
+ --pretrained_model_name_or_path ${model_dir} \
13
+ --save_path ${SAVE_DIR}/vcgbench \
14
+ --num_frames ${num_frames} \
15
+ --use_lora \
16
+ --lora_alpha ${lora_alpha} \
17
+ --weight_dir ${weight_dir} \
18
+ --pooling_shape 16-12-12 \
19
+ --test_ratio ${test_ratio} \
20
+ --conv_mode $conv_mode
21
+
22
+ conv_mode=eval_mvbench_llavanext
23
+ python -m tasks.eval.mvbench.pllava_eval_mvbench \
24
+ --pretrained_model_name_or_path ${model_dir} \
25
+ --save_path ${SAVE_DIR}/mvbench \
26
+ --use_lora \
27
+ --lora_alpha ${lora_alpha} \
28
+ --num_frames ${num_frames} \
29
+ --weight_dir ${weight_dir} \
30
+ --pooling_shape 16-12-12 \
31
+ --conv_mode $conv_mode
32
+
33
+ conv_mode=eval_videoqa_llavanext
34
+ python -m tasks.eval.videoqabench.pllava_eval_videoqabench \
35
+ --pretrained_model_name_or_path ${model_dir} \
36
+ --save_path ${SAVE_DIR}/videoqabench \
37
+ --num_frames ${num_frames} \
38
+ --use_lora \
39
+ --lora_alpha ${lora_alpha} \
40
+ --weight_dir ${weight_dir} \
41
+ --test_ratio ${test_ratio} \
42
+ --conv_mode ${conv_mode}
43
+
44
+ conv_mode=eval_recaption_llavanext
45
+ python -m tasks.eval.recaption.pllava_recaption \
46
+ --pretrained_model_name_or_path ${model_dir} \
47
+ --save_path ${SAVE_DIR}/recaption \
48
+ --num_frames ${num_frames} \
49
+ --use_lora \
50
+ --weight_dir ${weight_dir} \
51
+ --lora_alpha ${lora_alpha} \
52
+ --test_ratio ${test_ratio} \
53
+ --conv_mode $conv_mode
scripts/gallery.sh ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export OPENAI_API_KEY=...
2
+ SAVE_DIR=${1:-"test_results"}
3
+
4
+ # # gallery view
5
+ # python -m tasks.eval.show_gallery \
6
+ # --root_dir ${SAVE_DIR}
7
+
8
+ # # compare view
9
+ python -m tasks.eval.demo.show_compare \
10
+ --root_dir ${SAVE_DIR}
11
+
scripts/train_pllava.sh ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ echo "PYTHONPATH: ${PYTHONPATH}"
2
+ which_python=$(which python)
3
+ echo "which python: ${which_python}"
4
+ export PYTHONPATH=${PYTHONPATH}:${which_python}
5
+ export PYTHONPATH=${PYTHONPATH}:.
6
+ echo "PYTHONPATH: ${PYTHONPATH}"
7
+
8
+ OUTPUT_DIR=./pllava_video_outputs/test_train_7b_reconstruct
9
+
10
+ # # Naive Env
11
+ # rm -rf ${OUTPUT_DIR}
12
+ pooling_shape=(16,12,12)
13
+ accelerate launch --main_process_port 6876 --config_file scripts/accel_config_multigpu.yaml tasks/train/train_pllava_nframe_accel.py \
14
+ tasks/train/config_pllava_nframe.py \
15
+ output_dir ${OUTPUT_DIR} \
16
+ train_corpus videochat2_video \
17
+ save_steps 10000 \
18
+ num_workers 8 \
19
+ num_frames 16 \
20
+ model.pooling_method avg \
21
+ model.repo_id llava-hf/llava-v1.6-vicuna-7b-hf \
22
+ model.use_lora True \
23
+ model.pooling_shape $pooling_shape \
24
+ optimizer.lr 2e-5 \
25
+ scheduler.epochs 3 \
26
+ scheduler.warmup_ratio 0.2 \
27
+ scheduler.min_lr_multi 0.25 \
28
+ scheduler.is_videochat2_custom True \
29
+ preprocess.mm_alone False \
30
+ preprocess.random_shuffle False \
31
+ preprocess.add_second_msg False \
32
+ train_corpus videochat2_instruction_debug
33
+
34
+
scripts/train_pllava_13b.sh ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ echo "PYTHONPATH: ${PYTHONPATH}"
2
+ which_python=$(which python)
3
+ echo "which python: ${which_python}"
4
+ export PYTHONPATH=${PYTHONPATH}:${which_python}
5
+ export PYTHONPATH=${PYTHONPATH}:.
6
+ echo "PYTHONPATH: ${PYTHONPATH}"
7
+
8
+ OUTPUT_DIR=./pllava_video_outputs/pllava_13b
9
+
10
+
11
+ pooling_shape=(16,12,12)
12
+ num_save_samples=80000
13
+ num_gpus=8
14
+ full_batch_size=128
15
+ batch_size=8
16
+ save_steps=$[$num_save_samples/($batch_size*$num_gpus)]
17
+ ckpt_steps=$[$save_steps/10]
18
+ gradient_accumulation_steps=$[$full_batch_size/($batch_size*$num_gpus)]
19
+ echo $batch_size
20
+ echo $gradient_accumulation_steps
21
+ repo_id=llava-hf/llava-v1.6-vicuna-13b-hf
22
+ accelerate launch --main_process_port 6876 --config_file scripts/accel_config_deepspeed_zero3_offload.yaml tasks/train/train_pllava_nframe_accel.py \
23
+ tasks/train/config_pllava_nframe.py \
24
+ output_dir ${OUTPUT_DIR} \
25
+ train_corpus videochat2_instruction_debug \
26
+ save_steps $save_steps \
27
+ ckpt_steps $ckpt_steps \
28
+ num_workers 8 \
29
+ num_frames 16 \
30
+ gradient_accumulation_steps $gradient_accumulation_steps \
31
+ batch_size $batch_size \
32
+ deepspeed True \
33
+ model.pooling_method avg \
34
+ model.use_lora True \
35
+ model.use_pooling True \
36
+ model.repo_id $repo_id \
37
+ gradient_checkpointing True \
38
+ preprocess.center_pad False \
39
+ preprocess.clip_transform False \
40
+ optimizer.lr 2e-5 \
41
+ scheduler.epochs 3 \
42
+ scheduler.warmup_ratio 0.2 \
43
+ scheduler.min_lr_multi 0.25 \
44
+ model.pooling_shape $pooling_shape \
45
+ scheduler.is_videochat2_custom True \
46
+ preprocess.mm_alone False \
47
+ preprocess.random_shuffle False \
48
+ preprocess.add_second_msg False
49
+
50
+
scripts/train_pllava_34b.sh ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ echo "PYTHONPATH: ${PYTHONPATH}"
2
+ which_python=$(which python)
3
+ echo "which python: ${which_python}"
4
+ export PYTHONPATH=${PYTHONPATH}:${which_python}
5
+ export PYTHONPATH=${PYTHONPATH}:.
6
+ echo "PYTHONPATH: ${PYTHONPATH}"
7
+
8
+ machine_rank=${1:-"0"} # machine rank
9
+
10
+ OUTPUT_DIR=./pllava_video_outputs/pllava_34b_videchat2-video
11
+
12
+ pooling_shape=(16,12,12)
13
+ num_save_samples=80000
14
+ num_gpus=8
15
+ full_batch_size=128
16
+ batch_size=4
17
+ save_steps=$[$num_save_samples/($batch_size*$num_gpus)]
18
+ ckpt_steps=$[$save_steps/10]
19
+ gradient_accumulation_steps=$[$full_batch_size/($batch_size*$num_gpus)]
20
+ echo $batch_size
21
+ echo $gradient_accumulation_steps
22
+ repo_id=llava-hf/llava-v1.6-34b-hf
23
+ accelerate launch --main_process_port 6876 --config_file scripts/accel_config_deepspeed_zero3_offload.yaml tasks/train/train_pllava_nframe_accel.py \
24
+ tasks/train/config_pllava_nframe_yiprompt.py \
25
+ output_dir ${OUTPUT_DIR} \
26
+ train_corpus videochat2_instruction_debug \
27
+ save_steps $save_steps \
28
+ ckpt_steps $ckpt_steps \
29
+ num_workers 8 \
30
+ num_frames 16 \
31
+ deepspeed True \
32
+ gradient_accumulation_steps $gradient_accumulation_steps \
33
+ batch_size $batch_size \
34
+ model.pooling_method avg \
35
+ model.use_lora True \
36
+ model.use_pooling True \
37
+ model.repo_id $repo_id \
38
+ gradient_checkpointing True \
39
+ preprocess.center_pad False \
40
+ preprocess.clip_transform True \
41
+ optimizer.lr 2e-5 \
42
+ scheduler.epochs 3 \
43
+ scheduler.warmup_ratio 0.2 \
44
+ scheduler.min_lr_multi 0.25 \
45
+ model.pooling_shape $pooling_shape \
46
+ scheduler.is_videochat2_custom True \
47
+ preprocess.image_token_index 64002 \
48
+ preprocess.mm_alone False \
49
+ preprocess.random_shuffle False \
50
+ preprocess.add_second_msg False