omni-research committed on
Commit
97a05c0
1 Parent(s): 3e84302
.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ __pycache__
2
+ tmp*
app.py CHANGED
@@ -1,64 +1,229 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
 
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
-
9
-
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
-
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
-
26
- messages.append({"role": "user", "content": message})
27
-
28
- response = ""
29
-
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
  ):
37
- token = message.choices[0].delta.content
38
 
39
- response += token
40
- yield response
41
 
42
 
43
  """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- demo = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
58
- ),
59
- ],
60
- )
61
 
62
 
63
- if __name__ == "__main__":
64
- demo.launch()
1
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # copied and modified from: https://github.com/OpenGVLab/Ask-Anything/blob/main/video_chat2/demo/demo.py
16
+ import spaces
17
+ from copy import deepcopy
18
  import gradio as gr
19
+ from gradio.themes.utils import colors, fonts, sizes
20
+ from tools.conversation import Chat, conv_templates
21
+ from tools.utils import load_model_and_processor, file_to_base64
22
+ from dataset.processor import Processor
23
+ import os
24
+ import torch
25
 
26
+ device = 'cuda'
27
+ model_path = os.getenv("MODEL_PATH", "/home/user/checkpoints/Tarsier-7b")
28
+ max_n_frames = int(os.getenv("MAX_N_FRAMES", 8))
29
+ debug = True
30
+
31
+ # ========================================
32
+ # Model Initialization
33
+ # ========================================
34
+ def init_model():
35
+ print("Start Initialization...")
36
+ # if torch.cuda.is_available():
37
+ if not debug:
38
+ model, processor = load_model_and_processor(model_path, max_n_frames)
39
+ else:
40
+ print("No valid GPU! Launching in debug mode!")
41
+ processor = Processor(model_path, max_n_frames)
42
+ model = None
43
+ chat = Chat(model, processor, device, debug)
44
+ print('Initialization Finished')
45
+ return chat
46
+
47
+
48
+ # ========================================
49
+ # Gradio Setting
50
+ # ========================================
51
+ def gradio_reset(chat_state, img_file, img_list):
52
+ if chat_state is not None:
53
+ chat_state.messages = []
54
+ img_file = None
55
+ if img_list is not None:
56
+ img_list = []
57
+ return None, gr.update(value=None, interactive=True), gr.update(value=None, interactive=True), gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your video first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_file, img_list
58
+
59
+
60
+ def upload_img(gr_img, gr_video, gr_gif, chat_state, num_frames):
61
+ print(gr_img, gr_video)
62
+ conv_type = ''
63
+ if 'tarsier2-7b' in model_path.lower():
64
+ conv_type = 'tarsier2-7b'
65
+ elif '7b' in model_path.lower():
66
+ conv_type = 'tarsier-7b'
67
+ elif '13b' in model_path.lower():
68
+ conv_type = 'tarsier-13b'
69
+ elif '34b' in model_path.lower():
70
+ conv_type = 'tarsier-34b'
71
+ else:
72
+ raise ValueError(f"Unknown model: {model_path}")
73
+ chat_state = deepcopy(conv_templates[conv_type])
74
+
75
+ img_list = []
76
+ if gr_img is None and gr_video is None and gr_gif is None:
77
+ return None, None, None, gr.update(interactive=True), gr.update(interactive=True, placeholder='Please upload video/image first!'), chat_state, None, None
78
+ if gr_video or gr_img or gr_gif:
79
+ for img_file in [gr_img, gr_video, gr_gif]:  # pick the first provided input (image / video / GIF)
80
+ if img_file is not None:
81
+ break
82
+ return gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_file, img_list
83
+
84
+
85
+ def gradio_ask(user_message, chatbot, chat_state):
86
+ if len(user_message) == 0:
87
+ return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
88
+ chat_state = chat.ask(user_message, chat_state)
89
+ chatbot = chatbot + [[user_message, None]]
90
+ return '', chatbot, chat_state
91
+
92
+ @spaces.GPU(duration=120)
93
+ def gradio_answer(chatbot, chat_state, img_file, img_list, top_p, temperature, n_frames=None):
94
+ llm_message, chat_state, img_list = chat.answer(conv=chat_state, visual_data_file=img_file, images=img_list, n_frames=n_frames, max_new_tokens=512, num_beams=1, temperature=temperature, top_p=top_p)
95
+ chatbot[-1][1] = llm_message
96
+ print(chat_state)
97
+ print(f"Answer: {llm_message}")
98
+ return chatbot, chat_state, img_list
99
+
100
+
101
+ class OpenGVLab(gr.themes.base.Base):
102
+ def __init__(
103
+ self,
104
+ *,
105
+ primary_hue=colors.blue,
106
+ secondary_hue=colors.sky,
107
+ neutral_hue=colors.gray,
108
+ spacing_size=sizes.spacing_md,
109
+ radius_size=sizes.radius_sm,
110
+ text_size=sizes.text_md,
111
+ font=(
112
+ fonts.GoogleFont("Noto Sans"),
113
+ "ui-sans-serif",
114
+ "sans-serif",
115
+ ),
116
+ font_mono=(
117
+ fonts.GoogleFont("IBM Plex Mono"),
118
+ "ui-monospace",
119
+ "monospace",
120
+ ),
121
  ):
122
+ super().__init__(
123
+ primary_hue=primary_hue,
124
+ secondary_hue=secondary_hue,
125
+ neutral_hue=neutral_hue,
126
+ spacing_size=spacing_size,
127
+ radius_size=radius_size,
128
+ text_size=text_size,
129
+ font=font,
130
+ font_mono=font_mono,
131
+ )
132
+ super().set(
133
+ body_background_fill="*neutral_50",
134
+ )
135
 
136
 
137
+ gvlabtheme = OpenGVLab(primary_hue=colors.blue,
138
+ secondary_hue=colors.sky,
139
+ neutral_hue=colors.gray,
140
+ spacing_size=sizes.spacing_md,
141
+ radius_size=sizes.radius_sm,
142
+ text_size=sizes.text_md,
143
+ )
144
 
145
+ logo_b64 = file_to_base64("assets/figures/tarsier_logo.jpg")
146
+ title = f"""<center><a href="https://github.com/bytedance/tarsier"><img src="data:image/jpeg;base64,{logo_b64}" alt="Tarsier" border="0" style="margin: 0 auto; height: 140px;" /></a></center>"""
147
+ description = """<center><p><a href='https://github.com/bytedance/tarsier'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p></center>
148
  """
149
 
150
 
151
+ with gr.Blocks(title="Tarsier", theme=gvlabtheme, css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: hidden}") as demo:
152
+ gr.Markdown(title)
153
+ gr.Markdown(description)
154
+ with gr.Row():
155
+ with gr.Column(scale=0.5, visible=True) as video_upload:
156
+ with gr.Column(elem_id="image", scale=0.5) as img_part:
157
+ with gr.Tab("Video", elem_id='video_tab'):
158
+ up_video = gr.Video(interactive=True, include_audio=True, elem_id="video_upload", height=360)
159
+ with gr.Tab("Image", elem_id='image_tab'):
160
+ up_image = gr.Image(type="filepath", interactive=True, elem_id="image_upload", height=360)
161
+ with gr.Tab("GIF", elem_id='gif_tab'):
162
+ up_gif = gr.File(type="filepath", file_count="single", file_types=["gif"], interactive=True, elem_id="gif_upload", height=360)
163
+ upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
164
+ clear = gr.Button("Restart")
165
+
166
+ # num_beams = gr.Slider(
167
+ # minimum=1,
168
+ # maximum=10,
169
+ # value=1,
170
+ # step=1,
171
+ # interactive=True,
172
+ # label="beam search numbers)",
173
+ # )
174
+
175
+ temperature = gr.Slider(
176
+ minimum=0.0,
177
+ maximum=1.0,
178
+ value=0.0,
179
+ step=0.1,
180
+ interactive=True,
181
+ label="Temperature",
182
+ )
183
+
184
+ top_p = gr.Slider(
185
+ minimum=0.1,
186
+ maximum=1.0,
187
+ value=1.0,
188
+ step=0.1,
189
+ interactive=True,
190
+ label="Top_p",
191
+ )
192
+
193
+ num_frames = gr.Slider(
194
+ minimum=4,
195
+ maximum=16,
196
+ value=8,
197
+ step=2,
198
+ interactive=True,
199
+ label="#Frames",
200
+ )
201
+
202
+ with gr.Column(visible=True) as input_raws:
203
+ chat_state = gr.State()
204
+ img_list = gr.State()
205
+ img_file = gr.State()
206
+ chatbot = gr.Chatbot(elem_id="chatbot",label='VideoChat')
207
+ with gr.Row():
208
+ with gr.Column(scale=0.7):
209
+ text_input = gr.Textbox(show_label=False, placeholder='Please upload your video first', interactive=False, container=False)
210
+ with gr.Column(scale=0.15, min_width=0):
211
+ run = gr.Button("💭Send")
212
+ with gr.Column(scale=0.15, min_width=0):
213
+ clear = gr.Button("🔄 Clear")
214
+
215
+ chat = init_model()
216
+ upload_button.click(upload_img, [up_image, up_video, up_gif, chat_state, num_frames], [up_image, up_video, up_gif, text_input, upload_button, chat_state, img_file, img_list])
217
+
218
+ text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
219
+ gradio_answer, [chatbot, chat_state, img_file, img_list, top_p, temperature, num_frames], [chatbot, chat_state, img_list]
220
+ )
221
+ run.click(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
222
+ gradio_answer, [chatbot, chat_state, img_file, img_list, top_p, temperature, num_frames], [chatbot, chat_state, img_list]
223
+ )
224
+ run.click(lambda: "", None, text_input)
225
+ clear.click(gradio_reset, [chat_state, img_file, img_list], [chatbot, up_image, up_video, up_gif, text_input, upload_button, chat_state, img_file, img_list], queue=False)
226
+
227
+
228
+ demo.launch()
229
+ # demo.launch(server_name="0.0.0.0", server_port=11451)
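Editor's note: a minimal, hedged sketch (not part of the commit) of how this demo might be configured when run locally; the checkpoint path below is a placeholder.

import os

os.environ["MODEL_PATH"] = "/path/to/checkpoints/Tarsier-7b"  # placeholder; read by app.py at import time
os.environ["MAX_N_FRAMES"] = "8"                              # upper bound on frames sampled per video

# then run `python app.py`; with debug=True the model is not loaded and only the UI is exercised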
assets/figures/tarsier_logo.jpg ADDED
dataset/mm_dataset.py ADDED
@@ -0,0 +1,62 @@
1
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataset.utils import get_visual_type, sample_frame_indices
15
+ from .processor import Processor
16
+ from tools.rw_utils import read_jsonlines
17
+
18
+ class MMDataset(object):
19
+ def __init__(self, ann_path="", anns=None, processor:Processor=None):
20
+ self.processor = processor
21
+ if anns is None:
22
+ self.anns = []
23
+ if isinstance(ann_path, str):
24
+ ann_path = [ann_path]
25
+ for path in ann_path:
26
+ self.anns.extend(read_jsonlines(path))
27
+ else:
28
+ self.anns = anns
29
+
30
+ def __len__(self):
31
+ return len(self.anns)
32
+
33
+ def __getitem__(self, index):
34
+ try:
35
+ ann = self.anns[index]
36
+
37
+ prompt = ann['text']['prompt']
38
+
39
+ video_file = ann['video_file']
40
+ visual_files = []
41
+ start_time = ann.get("start_time", 0)
42
+ end_time = ann.get("end_time", -1)
43
+ if isinstance(video_file, list):
44
+ # This is for MVBench/Episodic Reasoning
45
+ # The video_file is a list of sorted frames extracted from the target video
46
+ for img_file in video_file:
47
+ if get_visual_type(img_file) == 'image':
48
+ visual_files.append(img_file)
49
+ frame_indices = sample_frame_indices(start_frame=0, total_frames=len(visual_files), n_frames=min(len(visual_files), self.processor.max_n_frames))
50
+ visual_files = [v for i,v in enumerate(visual_files) if i in frame_indices]
51
+ else:
52
+ if get_visual_type(video_file) in ['image', 'video', 'gif']:
53
+ visual_files.append(video_file)
54
+ assert len(visual_files) > 0, f"Failed to load valid visual file from anns[{index}]!"
55
+ images = []
56
+ for v_f in visual_files:
57
+ images.extend(self.processor.load_images(v_f, start_time=start_time, end_time=end_time))
58
+ model_inputs = self.processor(prompt, images=images, edit_prompt=True, return_prompt=True)
59
+ except Exception as e:
60
+ print(f"Load data error: {e}")
61
+ return ann, None
62
+ return ann, model_inputs
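Editor's note: a hedged usage sketch (not part of the commit) showing how MMDataset might be iterated; the checkpoint and annotation paths are placeholders, and the annotation schema follows the fields read in __getitem__ ('text.prompt', 'video_file', optional 'start_time'/'end_time').

from dataset.mm_dataset import MMDataset
from dataset.processor import Processor

processor = Processor("/path/to/Tarsier-7b", max_n_frames=8)        # placeholder checkpoint path
dataset = MMDataset(ann_path="/path/to/annotations.jsonl", processor=processor)
for i in range(len(dataset)):
    ann, model_inputs = dataset[i]
    if model_inputs is None:                                        # (ann, None) is returned on load errors
        continue
    print(model_inputs["prompt"])                                   # also carries "input_ids" and "pixel_values"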
dataset/processor.py ADDED
@@ -0,0 +1,160 @@
1
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from PIL import Image
15
+ from typing import List
16
+ import torch
17
+ from transformers import DataCollatorForSeq2Seq
18
+ from transformers.models.llava import LlavaProcessor
19
+ import re
20
+
21
+ from .utils import sample_image, sample_video, sample_gif, get_visual_type
22
+
23
+ ext2sampler = {
24
+ 'image': sample_image,
25
+ 'gif': sample_gif,
26
+ 'video': sample_video
27
+ }
28
+
29
+ class CustomImageProcessor:
30
+ def __init__(self, processor) -> None:
31
+ self.processor = processor
32
+
33
+ def __call__(self, images: List[Image.Image], do_padding=False) -> torch.Tensor:
34
+ if do_padding:
35
+ images = [self.expand2square(
36
+ img,
37
+ tuple(int(x * 255) for x in self.processor.image_processor.image_mean)
38
+ ) for img in images]
39
+ else:
40
+ images = [self.resize2square(img) for img in images]
41
+ images_pixel = self.processor(text="", images=images, return_tensors="pt")['pixel_values']
42
+ return images_pixel # [num_images, 3, 336, 336]
43
+
44
+ def expand2square(self, pil_img, background_color):
45
+ width, height = pil_img.size
46
+ if width == height:
47
+ return pil_img
48
+ elif width > height:
49
+ result = Image.new(pil_img.mode, (width, width), background_color)
50
+ result.paste(pil_img, (0, (width - height) // 2))
51
+ return result
52
+ else:
53
+ result = Image.new(pil_img.mode, (height, height), background_color)
54
+ result.paste(pil_img, ((height - width) // 2, 0))
55
+ return result
56
+
57
+ def resize2square(self, pil_img: Image.Image):
58
+ width, height = pil_img.size
59
+ pil_img = pil_img.resize((max(width, height), max(width, height)))
60
+ return pil_img
61
+
62
+ class Processor(object):
63
+ def __init__(
64
+ self,
65
+ model_name_or_path,
66
+ max_n_frames=8,
67
+ max_seq_len=None,
68
+ add_sep=False,
69
+ do_image_padding=False,
70
+ ):
71
+ self.max_n_frames = max_n_frames
72
+ self.max_seq_len = max_seq_len
73
+ self.add_sep = add_sep
74
+ self.do_image_padding = do_image_padding
75
+ if not self.do_image_padding:
76
+ print("### do_image_padding is set to False, images will be resized directly!")
77
+
78
+ self.setup(model_name_or_path)
79
+
80
+
81
+ def setup(self, model_name_or_path):
82
+ sub_processor = LlavaProcessor.from_pretrained(
83
+ model_name_or_path,
84
+ padding_side='left',
85
+ trust_remote_code=True,
86
+ )
87
+ self.processor = CustomImageProcessor(sub_processor)
88
+ self.tokenizer = sub_processor.tokenizer
89
+ # self.pad_collator = DataCollatorForSeq2Seq(self.tokenizer, padding='longest')
90
+ self.sep_id = self.tokenizer.sep_token_id
91
+ self.pad_id = self.tokenizer.pad_token_id
92
+ self.eos_id = self.tokenizer.eos_token_id
93
+
94
+ if self.sep_id is None:
95
+ self.add_sep = False
96
+ if not self.max_seq_len:
97
+ self.max_seq_len = self.tokenizer.model_max_length
98
+
99
+ def process_prompt(self, prompt, images: List[Image.Image]=None):
100
+ if not images:
101
+ prompt = prompt.replace("<image>", "").replace("<video>", "")
102
+ elif images is not None:
103
+ prompt = prompt.replace("<video>", "<image>"*len(images))
104
+ image_token_num = len(re.findall('<image>', prompt, re.S))
105
+ if image_token_num == 0:
106
+ prompt_parts = re.findall(r'USER:(.*)ASSISTANT:(.*)', prompt, re.S)
107
+ if prompt_parts and len(prompt_parts[0]) == 2:  # re.findall returns a list of (user, assistant) tuples
108
+ p1, p2 = prompt_parts[0]
109
+ else:
110
+ p1 = prompt
111
+ p2 = ''
112
+ prompt = f"USER: {'<image>'*len(images) + ' ' + p1.strip()} ASSISTANT: {p2.strip()}"
113
+ assert image_token_num == len(images)
114
+
115
+ if not re.findall(r'USER:(.*)ASSISTANT:(.*)', prompt, re.S):
116
+ prompt = f'USER: {prompt} ASSISTANT: '
117
+ return prompt
118
+
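# Editor's note (illustrative, not part of the commit): what process_prompt produces for a 2-frame input:
#   process_prompt("<video>\nDescribe the video in detail.", images=[img0, img1])
#   -> "USER: <image><image>\nDescribe the video in detail. ASSISTANT: "
# i.e. "<video>" is expanded to one "<image>" token per frame, and the text is wrapped into the
# USER/ASSISTANT template when it is not already in that form.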
119
+ def select_frames_sampler(self, visual_data_path):
120
+ visual_type = get_visual_type(visual_data_path)
121
+ if visual_type in ext2sampler:
122
+ return ext2sampler[visual_type]
123
+ else:
124
+ raise ValueError(f"Unsupported data format: {visual_data_path}")
125
+
126
+ def load_images(self, visual_data_path, n_frames=None, start_time=0, end_time=-1):
127
+ sampler = self.select_frames_sampler(visual_data_path)
128
+ return sampler(visual_data_path, n_frames=min(n_frames, self.max_n_frames) if n_frames else self.max_n_frames, start_time=start_time, end_time=end_time)
129
+
130
+ def get_pixel_values(self, images):
131
+ if images is not None and len(images) > 0:
132
+ pixel_values = self.processor(images=images, do_padding=self.do_image_padding)
133
+ else:
134
+ pixel_values = None
135
+ return pixel_values
136
+
137
+ def get_text_inputs(self, text):
138
+ prompt_ids = self.tokenizer.encode(text, add_special_tokens=True) # will add <s>
139
+ if self.add_sep:
140
+ prompt_ids = prompt_ids + [self.sep_id]
141
+ prompt_ids = torch.tensor(prompt_ids, dtype=torch.long).unsqueeze(dim=0)
142
+ return prompt_ids
143
+
144
+ def get_inputs(self, prompt, visual_data_file=None, images=None, n_frames=None, edit_prompt=False, return_prompt=False):
145
+ if images is None:
146
+ images = self.load_images(visual_data_file, n_frames) if visual_data_file else None
147
+ if edit_prompt:
148
+ prompt = self.process_prompt(prompt, images)
149
+ text_inputs = self.get_text_inputs(prompt)
150
+ pixel_values = self.get_pixel_values(images)
151
+ inputs = {
152
+ "input_ids": text_inputs,
153
+ "pixel_values": pixel_values
154
+ }
155
+ if return_prompt:
156
+ inputs['prompt'] = prompt
157
+ return inputs
158
+
159
+ def __call__(self, prompt, visual_data_file=None, images=None, n_frames=None, edit_prompt=False, return_prompt=False):
160
+ return self.get_inputs(prompt, visual_data_file, images, n_frames, edit_prompt, return_prompt)
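Editor's note: a hedged end-to-end sketch (not part of the commit) of calling the Processor on a single video; the checkpoint and video paths are placeholders.

from dataset.processor import Processor

processor = Processor("/path/to/Tarsier-7b", max_n_frames=8)   # placeholder path
inputs = processor(
    "<video>\nDescribe the video in detail.",
    visual_data_file="/path/to/clip.mp4",                      # placeholder path
    edit_prompt=True,
    return_prompt=True,
)
# inputs["input_ids"]:    torch.LongTensor of shape [1, seq_len]
# inputs["pixel_values"]: torch.FloatTensor of shape [n_frames, 3, 336, 336]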
dataset/utils.py ADDED
@@ -0,0 +1,128 @@
1
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import List
15
+ import os
16
+ from PIL import Image, ImageSequence
17
+ import decord
18
+
19
+ VALID_DATA_FORMAT_STRING = "Input data must be {'.jpg', '.jpeg', '.png', '.tif'} for images, or {'.mp4', '.avi', '.webm', '.mov', '.mkv', '.wmv', '.gif'} for videos!"
20
+
21
+ # Uniformly sample frames; the first and last frames are always included.
22
+ def sample_frame_indices(start_frame, total_frames: int, n_frames: int):
23
+ if n_frames == 1:
24
+ return [0] # sample the first frame by default
25
+ sample_ids = [round(i * (total_frames - 1) / (n_frames - 1)) for i in range(n_frames)]
26
+ sample_ids = [i + start_frame for i in sample_ids]
27
+ return sample_ids
28
+
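# Editor's note (worked example, not part of the commit):
#   sample_frame_indices(start_frame=0, total_frames=31, n_frames=4)
#   -> [round(i * 30 / 3) for i in range(4)] -> [0, 10, 20, 30]
# i.e. indices are spaced evenly across the clip and the first and last frames are always included.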
29
+ def sample_video(
30
+ video_path: str,
31
+ n_frames: int = None,
32
+ start_time: int = 0,
33
+ end_time: int = -1
34
+ ) -> List[Image.Image]:
35
+
36
+ assert os.path.exists(video_path), f"File not found: {video_path}"
37
+ vr = decord.VideoReader(video_path, num_threads=1, ctx=decord.cpu(0))
38
+ vr.seek(0)
39
+ total_frames = len(vr)
40
+ fps = vr.get_avg_fps()
41
+
42
+ start_frame = 0
43
+ end_frame = total_frames - 1
44
+ if start_time > 0:
45
+ start_frame = min((total_frames-1), int(fps*start_time))
46
+ if end_time > 0:
47
+ end_frame = max(start_frame, int(fps*end_time))
48
+ end_frame = min(end_frame, (total_frames-1))
49
+ frame_indices = sample_frame_indices(
50
+ start_frame=start_frame,
51
+ total_frames=end_frame - start_frame + 1,
52
+ n_frames=n_frames,
53
+ )
54
+
55
+ frames = vr.get_batch(frame_indices).asnumpy()
56
+ frames = [Image.fromarray(f).convert('RGB') for f in frames]
57
+ return frames
58
+
59
+ def sample_gif(
60
+ gif_path: str,
61
+ n_frames:int = None,
62
+ start_time: int = 0,
63
+ end_time: int = -1
64
+ ) -> List[Image.Image]:
65
+
66
+ assert os.path.exists(gif_path), f"File not found: {gif_path}"
67
+
68
+ gif_frames = Image.open(gif_path)
69
+
70
+ start_frame = 0
71
+ end_frame = gif_frames.n_frames - 1
72
+ frame_indices = sample_frame_indices(
73
+ start_frame=start_frame,
74
+ total_frames=end_frame - start_frame + 1,
75
+ n_frames=n_frames,
76
+ )
77
+
78
+ frames = []
79
+ i = 0
80
+ for frame in ImageSequence.Iterator(gif_frames):
81
+ if i in frame_indices:
82
+ frames.append(frame.convert('RGB'))
83
+ i += 1
84
+ return frames
85
+
86
+ def sample_image(
87
+ image_path: str,
88
+ n_frames: int = None,
89
+ start_time: int = 0,
90
+ end_time: int = -1
91
+ ):
92
+ assert os.path.exists(image_path), f"File not found: {image_path}"
93
+ image = Image.open(image_path).convert('RGB')
94
+ return [image]
95
+
96
+ def get_visual_type(input_file):
97
+ ext = os.path.splitext(input_file)[-1]
98
+ if ext in {'.gif'}:
99
+ return 'gif'
100
+ elif ext in {'.mp4', '.avi', '.webm', '.mov', '.mkv', '.wmv'}:
101
+ return 'video'
102
+ elif ext in {'.jpg', '.jpeg', '.png', '.tif'}:
103
+ return 'image'
104
+ else:
105
+ print(f"{VALID_DATA_FORMAT_STRING} But found {ext}!")
106
+ return 'unk'
107
+
108
+ def get_benchmarks(benchmarks):
109
+ final_benchmarks = []
110
+ type2bm = {
111
+ 'dream': ['dream'],
112
+ 'caption': ['msvd-caption', 'msr-vtt-caption', 'vatex-caption'],
113
+ 'mc_qa': ['next-qa', 'egoschema', 'mvbench', 'video-mme'],
114
+ 'oe_qa': ['msvd-qa', 'msr-vtt-qa', 'tgif-qa', 'anet-qa'],
115
+ }
116
+ for bm in benchmarks:
117
+ bm = bm.lower()
118
+ if bm in final_benchmarks:
119
+ continue
120
+ if bm == 'all':
121
+ for v in type2bm.values():
122
+ final_benchmarks.extend(v)
123
+ return final_benchmarks
124
+ if bm in type2bm:
125
+ final_benchmarks.extend(type2bm[bm])
126
+ else:
127
+ final_benchmarks.append(bm)
128
+ return final_benchmarks
models/modeling_tarsier.py ADDED
@@ -0,0 +1,757 @@
1
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # copied and modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
16
+ """ PyTorch Llava model."""
17
+ from dataclasses import dataclass
18
+ from typing import List, Optional, Tuple, Union
19
+ import math
20
+ import numpy as np
21
+
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ import torch.nn.functional as F
26
+
27
+ from transformers import PreTrainedModel
28
+ from transformers.activations import ACT2FN
29
+ from transformers.cache_utils import Cache
30
+ from transformers.modeling_outputs import ModelOutput
31
+ from transformers.utils import (
32
+ add_start_docstrings,
33
+ add_start_docstrings_to_model_forward,
34
+ logging,
35
+ replace_return_docstrings,
36
+ )
37
+ from transformers.models.auto import AutoModel, AutoModelForCausalLM, CONFIG_MAPPING
38
+ from transformers import LlamaForCausalLM
39
+ from transformers.configuration_utils import PretrainedConfig
40
+
41
+
42
+ logger = logging.get_logger(__name__)
43
+
44
+ LLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
45
+ "llava-hf/llava-v1.5-7b": "https://huggingface.co/llava-hf/llava-v1.5-7b/resolve/main/config.json",
46
+ }
47
+
48
+ class LlavaConfig(PretrainedConfig):
49
+ r"""
50
+ This is the configuration class to store the configuration of a [`LlavaForConditionalGeneration`]. It is used to instantiate an
51
+ Llava model according to the specified arguments, defining the model architecture. Instantiating a configuration
52
+ with the defaults will yield a similar configuration to that of the Llava-9B.
53
+
54
+ e.g. [llava-hf/llava-9b](https://huggingface.co/llava-hf/llava-9b)
55
+
56
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
57
+ documentation from [`PretrainedConfig`] for more information.
58
+
59
+ Args:
60
+ vision_config (`LlavaVisionConfig`, *optional*):
61
+ Custom vision config or dict
62
+ text_config (`Union[AutoConfig, dict]`, *optional*):
63
+ The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
64
+ ignore_index (`int`, *optional*, defaults to -100):
65
+ The ignore index for the loss function.
66
+ image_token_index (`int`, *optional*, defaults to 32000):
67
+ The image token index to encode the image prompt.
68
+ projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
69
+ The activation function used by the multimodal projector.
70
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
71
+ The feature selection strategy used to select the vision feature from the CLIP backbone.
72
+ vision_feature_layer (`int`, *optional*, defaults to -2):
73
+ The index of the layer to select the vision feature.
74
+ vocab_size (`int`, *optional*, defaults to 32000):
75
+ Vocabulary size of the Llava model. Defines the number of different tokens that can be represented by the
76
+ `inputs_ids` passed when calling [`~LlavaForConditionalGeneration`]
77
+
78
+ Example:
79
+
80
+ ```python
81
+ >>> from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig
82
+
83
+ >>> # Initializing a CLIP-vision config
84
+ >>> vision_config = CLIPVisionConfig()
85
+
86
+ >>> # Initializing a Llama config
87
+ >>> text_config = LlamaConfig()
88
+
89
+ >>> # Initializing a Llava llava-1.5-7b style configuration
90
+ >>> configuration = LlavaConfig(vision_config, text_config)
91
+
92
+ >>> # Initializing a model from the llava-1.5-7b style configuration
93
+ >>> model = LlavaForConditionalGeneration(configuration)
94
+
95
+ >>> # Accessing the model configuration
96
+ >>> configuration = model.config
97
+ ```"""
98
+
99
+ model_type = "llava"
100
+ is_composition = False
101
+
102
+ def __init__(
103
+ self,
104
+ vision_config=None,
105
+ text_config=None,
106
+ ignore_index=-100,
107
+ image_token_index=32000,
108
+ projector_hidden_act="gelu",
109
+ vision_feature_select_strategy="default",
110
+ vision_feature_layer=-2,
111
+ vocab_size=32000,
112
+ image_newline_idx=32002,
113
+ image_new_idx=32003,
114
+ **kwargs,
115
+ ):
116
+ self.ignore_index = ignore_index
117
+ self.image_token_index = image_token_index
118
+ self.projector_hidden_act = projector_hidden_act
119
+ self.vision_feature_select_strategy = vision_feature_select_strategy
120
+ self.vision_feature_layer = vision_feature_layer
121
+ self.vocab_size = vocab_size
122
+ self.image_newline_idx = image_newline_idx
123
+ self.image_new_idx = image_new_idx
124
+
125
+ self.vision_config = vision_config
126
+
127
+ if isinstance(self.vision_config, dict):
128
+ vision_config["model_type"] = (
129
+ vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
130
+ )
131
+ self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
132
+ elif vision_config is None:
133
+ self.vision_config = CONFIG_MAPPING["clip_vision_model"](
134
+ intermediate_size=4096,
135
+ hidden_size=1024,
136
+ patch_size=14,
137
+ image_size=336,
138
+ num_hidden_layers=24,
139
+ num_attention_heads=16,
140
+ vocab_size=32000,
141
+ projection_dim=768,
142
+ )
143
+ self.vocab_size = self.vocab_size
144
+
145
+ self.text_config = text_config
146
+
147
+ if isinstance(self.text_config, dict):
148
+ text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
149
+ self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
150
+ self.vocab_size = self.text_config.vocab_size
151
+ elif text_config is None:
152
+ self.text_config = CONFIG_MAPPING["llama"]()
153
+
154
+ super().__init__(**kwargs)
155
+
156
+
157
+ logger = logging.get_logger(__name__)
158
+
159
+ _CONFIG_FOR_DOC = "LlavaConfig"
160
+
161
+ LLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = [
162
+ "llava-hf/llava-1.5-7b-hf",
163
+ "llava-hf/llava-1.5-13b-hf",
164
+ "llava-hf/bakLlava-v1-hf",
165
+ # See all Llava models at https://huggingface.co/models?filter=llava
166
+ ]
167
+
168
+
169
+ class Llava3DPositionalEncoding(nn.Module):
170
+ def __init__(self, num_pos, dim) -> None:
171
+ super().__init__()
172
+ dim1, dim2, dim3 = self.split_dim(dim)
173
+ frame_position_encodings = self.create_sinusoidal_positions(num_pos, dim1)
174
+ height_position_encodings = self.create_sinusoidal_positions(num_pos, dim2)
175
+ width_position_encodings = self.create_sinusoidal_positions(num_pos, dim3)
176
+
177
+ self.register_buffer('frame_position_encodings', frame_position_encodings, persistent=False)
178
+ self.register_buffer('height_position_encodings', height_position_encodings, persistent=False)
179
+ self.register_buffer('width_position_encodings', width_position_encodings, persistent=False)
180
+
181
+ def split_dim(self, dim):
182
+ dim1 = dim // 3
183
+ if dim1 % 2 != 0:
184
+ dim1 -= 1
185
+
186
+ dim2 = dim // 3
187
+ if dim2 % 2 != 0:
188
+ dim2 -= 1
189
+
190
+ dim3 = dim - dim1 - dim2
191
+ return dim1, dim2, dim3
192
+
193
+ def create_sinusoidal_positions(self, num_pos: int, dim: int) -> torch.Tensor:
194
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
195
+ sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.float), inv_freq).float()
196
+ return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
197
+
198
+ def forward(self, frame_position_ids, height_position_ids, width_position_ids):
199
+ frame_position_embeds = F.embedding(frame_position_ids, self.frame_position_encodings)
200
+ height_position_embeds = F.embedding(height_position_ids, self.height_position_encodings)
201
+ width_position_embeds = F.embedding(width_position_ids, self.width_position_encodings)
202
+
203
+ return torch.cat([frame_position_embeds, height_position_embeds, width_position_embeds], dim = -1)
204
+
205
+
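# Editor's note (illustrative, not part of the commit): the 3D positional encoding splits the hidden
# size into three near-equal even chunks for frame / height / width, builds a sinusoidal table per
# chunk, and concatenates the three lookups. For example, with dim=4096:
#   split_dim(4096) -> (1364, 1364, 1368)   # the remainder goes to the width chunk
# so forward(...) returns embeddings of size 1364 + 1364 + 1368 = 4096 on the last dimension.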
206
+ @dataclass
207
+ # Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Llava
208
+ class LlavaCausalLMOutputWithPast(ModelOutput):
209
+ """
210
+ Base class for Llava causal language model (or autoregressive) outputs.
211
+
212
+ Args:
213
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
214
+ Language modeling loss (for next-token prediction).
215
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
216
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
217
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
218
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
219
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
220
+
221
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
222
+ `past_key_values` input) to speed up sequential decoding.
223
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
224
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
225
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
226
+
227
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
228
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
229
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
230
+ sequence_length)`.
231
+
232
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
233
+ heads.
234
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
235
+ Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
236
+ sequence_length, hidden_size)`.
237
+
238
+ image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
239
+ """
240
+
241
+ loss: Optional[torch.FloatTensor] = None
242
+ logits: torch.FloatTensor = None
243
+ past_key_values: Optional[List[torch.FloatTensor]] = None
244
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
245
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
246
+ image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
247
+ vision_outputs: Optional[torch.FloatTensor] = None
248
+ llm_attn_mask: Optional[Tuple[torch.FloatTensor]] = None
249
+
250
+
251
+ class LlavaMultiModalProjector(nn.Module):
252
+ def __init__(self, config: LlavaConfig):
253
+ super().__init__()
254
+
255
+ self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
256
+ self.act = ACT2FN[config.projector_hidden_act]
257
+ self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
258
+
259
+ def forward(self, image_features):
260
+ hidden_states = self.linear_1(image_features)
261
+ hidden_states = self.act(hidden_states)
262
+ hidden_states = self.linear_2(hidden_states)
263
+ return hidden_states
264
+
265
+
266
+ TARSIER_START_DOCSTRING = r"""
267
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
268
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
269
+ etc.)
270
+
271
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
272
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
273
+ and behavior.
274
+
275
+ Parameters:
276
+ config ([`LlavaConfig`] or [`LlavaVisionConfig`]):
277
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
278
+ load the weights associated with the model, only the configuration. Check out the
279
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
280
+ """
281
+
282
+
283
+ @add_start_docstrings(
284
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
285
+ TARSIER_START_DOCSTRING,
286
+ )
287
+ class TarsierPreTrainedModel(PreTrainedModel):
288
+ config_class = LlavaConfig
289
+ base_model_prefix = "model"
290
+ supports_gradient_checkpointing = True
291
+ _no_split_modules = ["LlavaVisionAttention"]
292
+ _skip_keys_device_placement = "past_key_values"
293
+ _supports_flash_attn_2 = True
294
+
295
+ def _init_weights(self, module):
296
+ # important: this ported version of Llava isn't meant for training from scratch - only
297
+ # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
298
+ # https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose
299
+ std = (
300
+ self.config.initializer_range
301
+ if hasattr(self.config, "initializer_range")
302
+ else self.config.text_config.initializer_range
303
+ )
304
+
305
+ if hasattr(module, "class_embedding"):
306
+ module.class_embedding.data.normal_(mean=0.0, std=std)
307
+
308
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
309
+ module.weight.data.normal_(mean=0.0, std=std)
310
+ if module.bias is not None:
311
+ module.bias.data.zero_()
312
+ elif isinstance(module, nn.Embedding):
313
+ module.weight.data.normal_(mean=0.0, std=std)
314
+ if module.padding_idx is not None:
315
+ module.weight.data[module.padding_idx].zero_()
316
+
317
+ @property
318
+ def _supports_sdpa(self):
319
+ """
320
+ Retrieve language_model's attribute to check whether the model supports
321
+ SDPA or not.
322
+ """
323
+ return self.language_model._supports_sdpa
324
+
325
+
326
+ TARSIER_INPUTS_DOCSTRING = r"""
327
+ Args:
328
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
329
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
330
+ it.
331
+
332
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
333
+ [`PreTrainedTokenizer.__call__`] for details.
334
+
335
+ [What are input IDs?](../glossary#input-ids)
336
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
337
+ The tensors corresponding to the input images. Pixel values can be obtained using
338
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`LlavaProcessor`] uses
339
+ [`CLIPImageProcessor`] for processing images).
340
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
341
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
342
+
343
+ - 1 for tokens that are **not masked**,
344
+ - 0 for tokens that are **masked**.
345
+
346
+ [What are attention masks?](../glossary#attention-mask)
347
+
348
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
349
+ [`PreTrainedTokenizer.__call__`] for details.
350
+
351
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
352
+ `past_key_values`).
353
+
354
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
355
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
356
+ information on the default strategy.
357
+
358
+ - 1 indicates the head is **not masked**,
359
+ - 0 indicates the head is **masked**.
360
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
361
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
362
+ config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
363
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
364
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
365
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
366
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
367
+
368
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
369
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
370
+
371
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
372
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
373
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
374
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
375
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
376
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
377
+ model's internal embedding lookup matrix.
378
+ use_cache (`bool`, *optional*):
379
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
380
+ `past_key_values`).
381
+ output_attentions (`bool`, *optional*):
382
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
383
+ tensors for more detail.
384
+ output_hidden_states (`bool`, *optional*):
385
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
386
+ more detail.
387
+ return_dict (`bool`, *optional*):
388
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
389
+ """
390
+
391
+
392
+ @add_start_docstrings(
393
+ """The LLAVA model which consists of a vision backbone and a language model.""",
394
+ TARSIER_INPUTS_DOCSTRING,
395
+ )
396
+ class TarsierForConditionalGeneration(TarsierPreTrainedModel):
397
+ def __init__(self, config: LlavaConfig):
398
+ super().__init__(config)
399
+ self.vision_tower = AutoModel.from_config(config.vision_config, trust_remote_code=True)
400
+ self.multi_modal_projector = LlavaMultiModalProjector(config)
401
+ self.vocab_size = config.vocab_size
402
+ self.language_model = AutoModelForCausalLM.from_config(config.text_config, attn_implementation="flash_attention_2")
403
+ image_newline_idx = torch.tensor([config.image_newline_idx], dtype=torch.long)
404
+ image_new_idx = torch.tensor([config.image_new_idx], dtype=torch.long)
405
+ self.register_buffer('image_newline_idx', image_newline_idx, persistent=False)
406
+ self.register_buffer('image_new_idx', image_new_idx, persistent=False)
407
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
408
+ self.post_init()
409
+
410
+ def get_input_embeddings(self):
411
+ return self.language_model.get_input_embeddings()
412
+
413
+ def set_input_embeddings(self, value):
414
+ self.language_model.set_input_embeddings(value)
415
+
416
+ def get_output_embeddings(self):
417
+ return self.language_model.get_output_embeddings()
418
+
419
+ def set_output_embeddings(self, new_embeddings):
420
+ self.language_model.set_output_embeddings(new_embeddings)
421
+
422
+ def set_decoder(self, decoder):
423
+ self.language_model.set_decoder(decoder)
424
+
425
+ def get_decoder(self):
426
+ return self.language_model.get_decoder()
427
+
428
+ def tie_weights(self):
429
+ return self.language_model.tie_weights()
430
+
431
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
432
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
433
+ # update vocab size
434
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
435
+ self.config.vocab_size = model_embeds.num_embeddings
436
+ self.vocab_size = model_embeds.num_embeddings
437
+ return model_embeds
438
+
439
+ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
440
+ num_images, num_image_patches, embed_dim = image_features.shape
441
+
442
+ batch_size, sequence_length = input_ids.shape
443
+ left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
444
+ # 1. Create a mask to know where special image tokens are
445
+ special_image_token_mask = input_ids == self.config.image_token_index
446
+ num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
447
+ # Compute the maximum embed dimension
448
+ max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
449
+ batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
450
+
451
+ # 2. Compute the positions where text should be written
452
+ # Calculate new positions for text tokens in merged image-text sequence.
453
+ # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
454
+ # `torch.cumsum` computes how each image token shifts subsequent text token positions.
455
+ # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
456
+ new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
457
+ nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
458
+ if left_padding:
459
+ new_token_positions += nb_image_pad[:, None] # offset for left padding
460
+ text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
461
+
462
+ # 3. Create the full embedding, already padded to the maximum position
463
+ final_embedding = torch.zeros(
464
+ batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
465
+ )
466
+ final_attention_mask = torch.zeros(
467
+ batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
468
+ )
469
+ if labels is not None:
470
+ final_labels = torch.full(
471
+ (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
472
+ )
473
+ # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
474
+ # set the corresponding tensors into their correct target device.
475
+ target_device = inputs_embeds.device
476
+ batch_indices, non_image_indices, text_to_overwrite = (
477
+ batch_indices.to(target_device),
478
+ non_image_indices.to(target_device),
479
+ text_to_overwrite.to(target_device),
480
+ )
481
+ attention_mask = attention_mask.to(target_device)
482
+
483
+ # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
484
+ # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
485
+ final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
486
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
487
+ if labels is not None:
488
+ final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
489
+
490
+ # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
491
+ image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
492
+ image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
493
+
494
+ if image_to_overwrite.sum() != image_features.shape[:-1].numel():
495
+ raise ValueError(
496
+ f"The inputs provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
497
+ f" the number of images given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
498
+ )
499
+
500
+ final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
501
+ final_attention_mask |= image_to_overwrite
502
+ position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
503
+
504
+ if labels is None:
505
+ final_labels = None
506
+
507
+ return final_embedding, final_attention_mask, final_labels, position_ids
508
+
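# Editor's note (worked example of the index math above, not part of the commit), assuming a single
# <image> token that expands to num_image_patches = 3 embeddings:
#   input_ids                = [t0, <image>, t1, t2]          # sequence_length = 4
#   special_image_token_mask = [0, 1, 0, 0]
#   new_token_positions      = cumsum([1, 3, 1, 1]) - 1 = [0, 3, 4, 5]
#   max_embed_dim            = 1 * (3 - 1) + 4 = 6
# so the three text embeddings land at positions {0, 4, 5} and the 3 image patch embeddings fill the
# remaining slots {1, 2, 3} of the final [batch, 6, embed_dim] sequence.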
509
+ def add_split_tokens(self, image_features):
510
+ num_images, num_image_patches, embed_dim = image_features.shape
511
+ num_height_patches, num_width_patches = int(math.sqrt(num_image_patches)), int(math.sqrt(num_image_patches))
512
+
513
+ # add image_newline
514
+ image_newline = self.get_input_embeddings()(self.image_newline_idx).squeeze()
515
+ image_features = image_features.view(num_images, num_height_patches, num_width_patches, embed_dim)
516
+ image_features = torch.cat([
517
+ image_features,
518
+ image_newline.expand((num_images, num_height_patches, 1, embed_dim)).to(device=image_features.device)
519
+ ], dim=2)
520
+ num_image_patches += num_height_patches
521
+ image_features = image_features.view(num_images, num_image_patches, embed_dim)
522
+
523
+ # add image_new
524
+ image_new = self.get_input_embeddings()(self.image_new_idx).squeeze()
525
+ image_features = torch.cat([
526
+ image_features,
527
+ image_new.expand((num_images, 1, embed_dim)).to(device=image_features.device)
528
+ ], dim = 1)
529
+
530
+ return image_features
531
+
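# Editor's note (illustrative, not part of the commit): with the default 336x336 / patch-14 vision
# tower each frame yields 576 patch embeddings (a 24x24 grid). add_split_tokens appends one
# image_newline embedding per grid row (576 + 24 = 600) and one trailing image_new embedding, so
# every frame contributes 601 visual embeddings to the merged sequence.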
532
+ @add_start_docstrings_to_model_forward(TARSIER_INPUTS_DOCSTRING)
533
+ @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
534
+ def forward(
535
+ self,
536
+ input_ids: torch.LongTensor = None,
537
+ pixel_values: torch.FloatTensor = None,
538
+ attention_mask: Optional[torch.Tensor] = None,
539
+ position_ids: Optional[torch.LongTensor] = None,
540
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
541
+ inputs_embeds: Optional[torch.FloatTensor] = None,
542
+ vision_feature_layer: Optional[int] = None,
543
+ vision_feature_select_strategy: Optional[str] = None,
544
+ labels: Optional[torch.LongTensor] = None,
545
+ use_cache: Optional[bool] = None,
546
+ output_attentions: Optional[bool] = None,
547
+ output_hidden_states: Optional[bool] = None,
548
+ return_dict: Optional[bool] = None,
549
+ **kwargs,
550
+ ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
551
+ r"""
552
+ Args:
553
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
554
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
555
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
556
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
557
+
558
+ Returns:
559
+
560
+ Example:
561
+
562
+ ```python
563
+ >>> from PIL import Image
564
+ >>> import requests
565
+ >>> from transformers import AutoProcessor, LlavaForConditionalGeneration
566
+
567
+ >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
568
+ >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
569
+
570
+ >>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
571
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
572
+ >>> image = Image.open(requests.get(url, stream=True).raw)
573
+
574
+ >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
575
+
576
+ >>> # Generate
577
+ >>> generate_ids = model.generate(**inputs, max_length=30)
578
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
579
+ "\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"
580
+ ```"""
581
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
582
+ output_hidden_states = (
583
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
584
+ )
585
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
586
+ vision_feature_layer = (
587
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
588
+ )
589
+ vision_feature_select_strategy = (
590
+ vision_feature_select_strategy
591
+ if vision_feature_select_strategy is not None
592
+ else self.config.vision_feature_select_strategy
593
+ )
594
+
595
+ image_features = None
596
+ if inputs_embeds is None:
597
+ # 1. Extract the input embeddings
598
+ inputs_embeds = self.get_input_embeddings()(input_ids)
599
+
600
+ # 2. Merge text and images
601
+ if pixel_values is not None and input_ids.shape[1] != 1:
602
+ pixel_values = pixel_values.to(dtype=self.vision_tower.dtype)
603
+ image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
604
+ # This is not memory efficient at all: output_hidden_states=True will save all the hidden states.
605
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
606
+
607
+ if vision_feature_select_strategy == "default":
608
+ selected_image_feature = selected_image_feature[:, 1:]
609
+ elif vision_feature_select_strategy == "full":
610
+ selected_image_feature = selected_image_feature
611
+ else:
612
+ raise ValueError(
613
+ f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
614
+ )
615
+
616
+ image_features = self.multi_modal_projector(selected_image_feature)
617
+
618
+ special_image_token_mask = input_ids == self.config.image_token_index
619
+ num_special_image_tokens = torch.sum(special_image_token_mask, dim = -1)
620
+
621
+ image_features = self.add_split_tokens(image_features)
622
+
623
+ if sum(num_special_image_tokens) > 0:
624
+ # print(f'num_special_image_tokens: {num_special_image_tokens}')
625
+ inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
626
+ image_features, inputs_embeds, input_ids, attention_mask, labels
627
+ )
628
+ else:
629
+ inputs_embeds = image_features.sum(dim=(0,1))[None, None, :] * 0. + inputs_embeds
630
+
+                 if labels is None:
+                     labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
+             else:
+                 # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
+                 # generation with cache
+                 if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
+                     # Retrieve the first layer to inspect the logits and mask out the hidden states
+                     # that are set to 0
+                     first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+
+                     # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+                     batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+
+                     # Get the target length
+                     target_seqlen = first_layer_past_key_value.shape[-1] + 1
+                     extended_attention_mask = torch.ones(
+                         (attention_mask.shape[0], target_seqlen),
+                         dtype=attention_mask.dtype,
+                         device=attention_mask.device,
+                     )
+
+                     extended_attention_mask[batch_index, non_attended_tokens] = 0
+
+                     valid_indices = torch.ones_like(attention_mask)
+                     valid_indices[:, 0] = target_seqlen - extended_attention_mask.sum(dim=-1)
+                     valid_indices = torch.cumsum(valid_indices, dim=-1)
+                     extended_attention_mask = extended_attention_mask.scatter(1, valid_indices, attention_mask)
+                     attention_mask = extended_attention_mask
+                     position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+         outputs = self.language_model(
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             # use_rmpad=kwargs.get("use_rmpad", False),
+             return_dict=return_dict,
+         )
+
+         logits = outputs[0]
+
+         loss = None
+         if labels is not None:
+             # Shift so that tokens < n predict n
+             if attention_mask is not None:
+                 shift_attention_mask = attention_mask[..., 1:]
+                 shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
+                 shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
+             else:
+                 shift_logits = logits[..., :-1, :].contiguous()
+                 shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(
+                 shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
+             )
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return LlavaCausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+             llm_attn_mask=attention_mask
+         )
+
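The loss branch in `forward` above only scores positions that are both attended to and labelled; everything else is dropped before the cross-entropy. A minimal sketch of that shift-and-mask behaviour with toy tensors (illustrative only, not repo code):

```python
import torch
import torch.nn as nn

logits = torch.randn(1, 5, 11)                    # (batch, seq_len, vocab)
labels = torch.tensor([[7, 3, -100, 4, 9]])       # -100 = ignored by CrossEntropyLoss
attention_mask = torch.tensor([[1, 1, 1, 1, 0]])  # last position is padding

# Shift so that tokens < n predict n, then keep only attended positions
shift_attention_mask = attention_mask[..., 1:]
shift_logits = logits[..., :-1, :][shift_attention_mask != 0]   # shape (3, 11)
shift_labels = labels[..., 1:][shift_attention_mask != 0]       # tensor([3, -100, 4])

loss = nn.CrossEntropyLoss()(shift_logits, shift_labels)        # the -100 label is skipped too
```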
+     def prepare_inputs_for_generation(
+         self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
+     ):
+         if past_key_values is not None:
+             if isinstance(past_key_values, Cache):
+                 cache_length = past_key_values.get_seq_length()
+                 past_length = past_key_values.seen_tokens
+             else:
+                 cache_length = past_length = past_key_values[0][0].shape[2]
+
+             # Keep only the unprocessed tokens:
+             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+             #     some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+             #     input)
+             if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                 input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
+             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+             #     input_ids based on the past_length.
+             elif past_length < input_ids.shape[1]:
+                 input_ids = input_ids[:, past_length:]
+             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+             elif self.config.image_token_index in input_ids:
+                 input_ids = input_ids[:, input_ids.shape[1] - 1:]
+             # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
+             # older attention values, as their corresponding values are not part of the input.
+             if cache_length < past_length and attention_mask is not None:
+                 attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]):]
+
+         position_ids = kwargs.get("position_ids", None)
+         if attention_mask is not None and position_ids is None:
+             # create position_ids on the fly for batch generation
+             position_ids = attention_mask.long().cumsum(-1) - 1
+             position_ids.masked_fill_(attention_mask == 0, 1)
+             if past_key_values:
+                 position_ids = position_ids[:, -input_ids.shape[1]:]
+
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and past_key_values is None:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             model_inputs = {"input_ids": input_ids}
+
+         model_inputs.update(
+             {
+                 "position_ids": position_ids,
+                 "past_key_values": past_key_values,
+                 "use_cache": kwargs.get("use_cache"),
+                 "attention_mask": attention_mask,
+                 "pixel_values": pixel_values,
+             }
+         )
+         return model_inputs
+
+     def _reorder_cache(self, *args, **kwargs):
+         return self.language_model._reorder_cache(*args, **kwargs)
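`prepare_inputs_for_generation` trims `input_ids` down to the tokens the cache has not yet seen and rebuilds `position_ids` from the attention mask, so left padding does not shift token positions. A small sketch of that `position_ids` arithmetic on a toy left-padded batch (illustrative only, not repo code):

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded sequence
                               [1, 1, 1, 1, 1]])  # full-length sequence
position_ids = attention_mask.long().cumsum(-1) - 1  # [-1,-1,0,1,2] and [0,1,2,3,4]
position_ids.masked_fill_(attention_mask == 0, 1)     # padded slots get a dummy position
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
# At cached decode steps only the last columns are kept: position_ids[:, -input_ids.shape[1]:]
```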
requirements.txt CHANGED
@@ -1 +1,23 @@
- huggingface_hub==0.25.2
+ accelerate==1.0.1
+ Pillow==9.3.0
+ decord==0.6.0
+ gradio==4.31.5
+ ninja==1.11.1.1
+ omegaconf==2.3.0
+ openai==1.14.2
+ pathos==0.3.2
+ prettytable==3.10.0
+ protobuf==3.20.3
+ pycocoevalcap==1.2
+ pycocotools==2.0.8
+ requests==2.31.0
+ safetensors==0.4.2
+ scikit-learn==1.4.1.post1
+ scipy==1.13.0
+ tiktoken==0.6.0
+ torch==2.1.0
+ torchvision==0.16.0
+ torchaudio==2.1.0
+ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.7/flash_attn-2.5.7+cu122torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
+ transformers==4.44.2
+ triton==2.1.0
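These pins are tight (the flash-attn wheel targets CUDA 12.2, torch 2.1 and Python 3.10), so a mismatched environment tends to fail at import time. A small, illustrative sanity check you could run before launching the demo (the selection of packages below is an assumption, not part of the Space):

```python
from importlib.metadata import PackageNotFoundError, version

pins = {"torch": "2.1.0", "transformers": "4.44.2", "gradio": "4.31.5"}
for name, pinned in pins.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        installed = "not installed"
    status = "ok" if installed == pinned else "check"
    print(f"{name}: pinned {pinned}, found {installed} [{status}]")
```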
tools/conversation.py ADDED
@@ -0,0 +1,216 @@
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # copy and modify from: https://github.com/OpenGVLab/Ask-Anything/blob/main/video_chat2/conversation.py
+ from PIL import Image
+ import torch
+ from transformers import StoppingCriteria, StoppingCriteriaList
+
+ from enum import auto, Enum
+ import os
+ from dataset.processor import Processor
+ import re
+
+
+ IMAGE_TOKEN = "<image>"
+ VIDEO_TOKEN = "<video>"
+
+ class SeparatorStyle(Enum):
+     """Different separator styles."""
+     SINGLE = auto()
+     TWO = auto()
+
+ def get_prompt(conv):
+     ret = ""
+     if conv.system:
+         ret = conv.system + conv.sep1
+     for i, (role, message) in enumerate(conv.messages):
+         if message:
+             # In the current version, the image should be added in the first conversation round,
+             # so we need to remove the special image tokens from subsequent user inputs.
+             if i > 0:
+                 message = re.sub(f"({IMAGE_TOKEN}|{VIDEO_TOKEN})\n*", "", message)
+             ret += role + ": " + message
+             if i % 2:
+                 ret += conv.sep2
+             else:
+                 ret += conv.sep1
+         else:
+             ret += role + ":"
+     return ret
+
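For reference, this is the prompt `get_prompt` produces for a single round, assuming the `conv_tarsier` template defined at the bottom of this file (`sep1=" "`, `sep2="</s>"`) and `deepcopy` from the standard library (illustrative only):

```python
from copy import deepcopy

conv = deepcopy(conv_templates["tarsier-7b"])   # defined later in this file
conv.messages.append(["USER", "<image>\nDescribe the video in detail."])
conv.messages.append(["ASSISTANT", None])
print(get_prompt(conv))
# USER: <image>
# Describe the video in detail. ASSISTANT:
```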
+
+ class StoppingCriteriaSub(StoppingCriteria):
+     def __init__(self, stops=[], encounters=1):
+         super().__init__()
+         self.stops = stops
+
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+         for stop in self.stops:
+             if torch.all((stop == input_ids[0][-len(stop):])).item():
+                 return True
+         return False
+
+
+ class Chat:
+     def __init__(self, model, processor: Processor, device='cuda', debug=False):
+         self.model = model
+         self.processor = processor
+         self.device = device
+         self.debug = debug
+         stop_words_ids = [torch.tensor([self.processor.tokenizer.eos_token_id]).to(device)]
+         self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
+
+     def ask(self, text, conv):
+         conv.messages.append([conv.roles[0], text])
+         return conv
+
+     def prepare_model_inputs(self, conv, visual_data_file=None, images=None, n_frames=None):
+         conv.messages.append([conv.roles[1], None])
+         conv.messages[0][1] = re.sub(f"({IMAGE_TOKEN}|{VIDEO_TOKEN})\n*", "", conv.messages[0][1])
+
+         if images is None or isinstance(images, list) and len(images) == 0:
+             if isinstance(visual_data_file, str) and os.path.exists(visual_data_file):
+                 images = self.processor.load_images(visual_data_file, n_frames)
+             elif isinstance(visual_data_file, Image.Image):
+                 images = [visual_data_file]
+             elif visual_data_file is None or visual_data_file == "":
+                 images = None
+             else:
+                 raise NotImplementedError
+
+         if isinstance(images, list) and len(images) > 0:
+             conv.messages[0][1] = IMAGE_TOKEN*len(images) + '\n' + conv.messages[0][1]
+
+         prompt = get_prompt(conv)
+         if self.debug:
+             print(f"visual_data_file: {visual_data_file}")
+             print(f"Prompt: {prompt}", flush=True)
+
+         inputs = self.processor(prompt, images=images, edit_prompt=False, return_prompt=False)
+         inputs = {k: v.to(self.device) for k, v in inputs.items() if v is not None}
+         return inputs, conv, images
+
+     def answer(self, conv, visual_data_file=None, images=None, n_frames=None, max_new_tokens=512, num_beams=1, min_length=1, top_p=1.0,
+                repetition_penalty=1.0, length_penalty=1, temperature=0):
+         inputs, conv, images = self.prepare_model_inputs(conv, visual_data_file, images, n_frames)
+         if self.model is not None:
+             outputs = self.model.generate(
+                 **inputs,
+                 max_new_tokens=max_new_tokens,
+                 stopping_criteria=self.stopping_criteria,
+                 num_beams=num_beams,
+                 do_sample=True if temperature > 0 else False,
+                 min_length=min_length,
+                 top_p=top_p,
+                 repetition_penalty=repetition_penalty,
+                 length_penalty=length_penalty,
+                 temperature=temperature,
+             )
+             output_text = self.processor.tokenizer.decode(outputs[0][inputs['input_ids'][0].shape[0]:], skip_special_tokens=True)
+         else:
+             output_text = "Fake response as launched in debug mode!"
+         conv.messages[-1][1] = output_text
+         return output_text, conv, images
+
+ class EasyDict(dict):
+     """
+     Get attributes
+
+     >>> d = EasyDict({'foo':3})
+     >>> d['foo']
+     3
+     >>> d.foo
+     3
+     >>> d.bar
+     Traceback (most recent call last):
+     ...
+     AttributeError: 'EasyDict' object has no attribute 'bar'
+
+     Works recursively
+
+     >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
+     >>> isinstance(d.bar, dict)
+     True
+     >>> d.bar.x
+     1
+     """
+
+     def __init__(self, d=None, **kwargs):
+         if d is None:
+             d = {}
+         if kwargs:
+             d.update(**kwargs)
+         for k, v in d.items():
+             setattr(self, k, v)
+         # Class attributes
+         for k in self.__class__.__dict__.keys():
+             if not (k.startswith("__") and k.endswith("__")) and not k in ("update", "pop"):
+                 setattr(self, k, getattr(self, k))
+
+     def __setattr__(self, name, value):
+         if isinstance(value, (list, tuple)):
+             value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
+         elif isinstance(value, dict) and not isinstance(value, self.__class__):
+             value = self.__class__(value)
+         super(EasyDict, self).__setattr__(name, value)
+         super(EasyDict, self).__setitem__(name, value)
+
+     __setitem__ = __setattr__
+
+     def update(self, e=None, **f):
+         d = e or dict()
+         d.update(f)
+         for k in d:
+             setattr(self, k, d[k])
+
+     def pop(self, k, d=None):
+         if hasattr(self, k):
+             delattr(self, k)
+         return super(EasyDict, self).pop(k, d)
+
+ conv_tarsier = EasyDict({
+     "system": "",
+     "roles": ("USER", "ASSISTANT"),
+     "messages": [],
+     "sep1": " ",
+     "sep2": "</s>",
+     }
+ )
+
+ conv_tarsier_yi = EasyDict({
+     "system": "",
+     "roles": ("USER", "ASSISTANT"),
+     "messages": [],
+     "sep1": " ",
+     "sep2": "<|endoftext|>",
+     }
+ )
+
+ conv_tarsier_qwen2 = EasyDict({
+     "system": "",
+     "roles": ("USER", "ASSISTANT"),
+     "messages": [],
+     "sep1": " ",
+     "sep2": "<|endoftext|>",
+     }
+ )
+
+ conv_templates = {
+     "tarsier-7b": conv_tarsier,
+     "tarsier-13b": conv_tarsier,
+     "tarsier-34b": conv_tarsier_yi,
+     "tarsier2-7b": conv_tarsier_qwen2
+ }
+
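A rough end-to-end sketch of how `Chat` and `conv_templates` are meant to be wired together with `load_model_and_processor` from `tools/utils.py` (the checkpoint and video paths below are placeholders, not repo defaults). The template is deep-copied because its `messages` list is mutable and shared:

```python
from copy import deepcopy
from tools.conversation import Chat, conv_templates
from tools.utils import load_model_and_processor

model, processor = load_model_and_processor("checkpoints/Tarsier-7b", max_n_frames=8)
chat = Chat(model, processor, device="cuda", debug=False)

conv = deepcopy(conv_templates["tarsier-7b"])
conv = chat.ask("<video>\nDescribe the video in detail.", conv)
answer, conv, frames = chat.answer(conv, visual_data_file="examples/test.mp4", n_frames=8)
print(answer)
```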
tools/utils.py ADDED
@@ -0,0 +1,65 @@
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from models.modeling_tarsier import TarsierForConditionalGeneration, LlavaConfig
+ from dataset.processor import Processor
+ import torch
+ import base64
+
+ class Color:
+
+     @staticmethod
+     def red(x):
+         return '\33[31m' + x + '\033[0m'
+
+     @staticmethod
+     def green(x):
+         return '\33[32m' + x + '\033[0m'
+
+     @staticmethod
+     def yellow(x):
+         return '\33[33m' + x + '\033[0m'
+
+     @staticmethod
+     def blue(x):
+         return '\33[34m' + x + '\033[0m'
+
+     @staticmethod
+     def violet(x):
+         return '\33[35m' + x + '\033[0m'
+
+ def file_to_base64(img_path):
+     with open(img_path, 'rb') as video_file:
+         video_b64_str = base64.b64encode(video_file.read()).decode()
+     return video_b64_str
+
+ def load_model_and_processor(model_name_or_path, max_n_frames=8):
+     print(Color.red(f"Load model and processor from: {model_name_or_path}; with max_n_frames={max_n_frames}"), flush=True)
+     processor = Processor(
+         model_name_or_path,
+         max_n_frames=max_n_frames,
+     )
+     model_config = LlavaConfig.from_pretrained(
+         model_name_or_path,
+         trust_remote_code=True,
+     )
+     model = TarsierForConditionalGeneration.from_pretrained(
+         model_name_or_path,
+         config=model_config,
+         device_map='auto',
+         torch_dtype=torch.float16,
+         trust_remote_code=True
+     )
+     model.eval()
+     return model, processor
+
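`file_to_base64` returns a bare base64 string; wrapping it in a data URI so the video can be shown inline in the Gradio page is one plausible use, sketched below (the HTML wrapping and the path are assumptions, not repo code):

```python
from tools.utils import file_to_base64

video_b64 = file_to_base64("examples/test.mp4")   # placeholder path
video_html = f'<video controls src="data:video/mp4;base64,{video_b64}"></video>'
# e.g. hand `video_html` to a gr.HTML component in the Gradio UI
```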