THUdyh commited on
Commit
d9c19b7
·
verified ·
1 Parent(s): 745d2d6

update space

Browse files
oryx/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .model import OryxLlamaForCausalLM
oryx/constants.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
2
+ WORKER_HEART_BEAT_INTERVAL = 15
3
+
4
+ LOGDIR = "."
5
+
6
+ # Model Constants
7
+ IGNORE_INDEX = -100
8
+ IMAGE_TOKEN_INDEX = -200
9
+ DEFAULT_IMAGE_TOKEN = "<image>"
10
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11
+ DEFAULT_IM_START_TOKEN = "<im_start>"
12
+ DEFAULT_IM_END_TOKEN = "<im_end>"
oryx/conversation.py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Any, Dict, Union, Tuple
4
+ import re
5
+ import base64
6
+ from io import BytesIO
7
+ from PIL import Image
8
+ from transformers import AutoTokenizer
9
+
10
+ import os
11
+ if 'EVALUATION' in os.environ:
12
+ # highresxpatch
13
+ EVALUATION = True
14
+ print(f"EVALUATION is set")
15
+ else:
16
+ EVALUATION = False
17
+
18
+ class SeparatorStyle(Enum):
19
+ """Different separator style."""
20
+
21
+ SINGLE = auto()
22
+ TWO = auto()
23
+ MPT = auto()
24
+ PLAIN = auto()
25
+ CHATML = auto()
26
+ LLAMA_2 = auto()
27
+ LLAMA_3 = auto()
28
+ QWEN2 = auto()
29
+ QWEN = auto()
30
+
31
+
32
+ @dataclasses.dataclass
33
+ class Conversation:
34
+ """A class that keeps all conversation history."""
35
+
36
+ system: str
37
+ roles: List[str]
38
+ messages: List[List[str]]
39
+ offset: int
40
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
41
+ sep: str = "###"
42
+ sep2: str = None
43
+ version: str = "Unknown"
44
+
45
+ tokenizer_id: str = ""
46
+ tokenizer: Any = None
47
+ # Stop criteria (the default one is EOS token)
48
+ stop_str: Union[str, List[str]] = None
49
+ # Stops generation if meeting any token in this list
50
+ stop_token_ids: List[int] = None
51
+
52
+ skip_next: bool = False
53
+
54
+ def get_prompt(self):
55
+ messages = self.messages
56
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
57
+ messages = self.messages.copy()
58
+ init_role, init_msg = messages[0].copy()
59
+ init_msg = init_msg[0]
60
+ if "mmtag" in self.version:
61
+ init_msg = init_msg.replace("<image>", "").strip()
62
+ messages[0] = (init_role, init_msg)
63
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
64
+ messages.insert(1, (self.roles[1], "Received."))
65
+ elif not init_msg.startswith("<image>"):
66
+ init_msg = init_msg.replace("<image>", "").strip()
67
+ messages[0] = (init_role, "<image>\n" + init_msg)
68
+ else:
69
+ messages[0] = (init_role, init_msg)
70
+
71
+ if self.sep_style == SeparatorStyle.SINGLE:
72
+ ret = self.system + self.sep
73
+ for role, message in messages:
74
+ if message:
75
+ if type(message) is tuple:
76
+ message, _, _ = message
77
+ ret += role + ": " + message + self.sep
78
+ else:
79
+ ret += role + ":"
80
+
81
+ elif self.sep_style == SeparatorStyle.TWO:
82
+ seps = [self.sep, self.sep2]
83
+ ret = self.system + seps[0]
84
+ for i, (role, message) in enumerate(messages):
85
+ if message:
86
+ if type(message) is tuple:
87
+ message, _, _ = message
88
+ ret += role + ": " + message + seps[i % 2]
89
+ else:
90
+ ret += role + ":"
91
+
92
+ elif self.sep_style == SeparatorStyle.CHATML:
93
+ ret = "" if self.system == "" else self.system + self.sep + "\n"
94
+ for role, message in messages:
95
+ if message:
96
+ if type(message) is tuple:
97
+ message, images = message
98
+ message = "<image>" * len(images) + message
99
+ ret += role + "\n" + message + self.sep + "\n"
100
+ else:
101
+ ret += role + "\n"
102
+ return ret
103
+
104
+ elif self.sep_style == SeparatorStyle.LLAMA_3:
105
+ chat_template_messages = [{"role": "system", "content": self.system}]
106
+ for role, message in messages:
107
+ if message:
108
+ if type(message) is tuple:
109
+ message, images = message
110
+ message = "<image>" * len(images) + message
111
+ chat_template_messages.append({"role": role, "content": message})
112
+ if EVALUATION:
113
+ return self.tokenizer.apply_chat_template(chat_template_messages, tokenize=False, add_generation_prompt=True)
114
+ else:
115
+ return self.tokenizer.apply_chat_template(chat_template_messages, tokenize=False, add_generation_prompt=False)
116
+
117
+ elif self.sep_style == SeparatorStyle.MPT:
118
+ ret = self.system + self.sep
119
+ for role, message in messages:
120
+ if message:
121
+ if type(message) is tuple:
122
+ message, _, _ = message
123
+ ret += role + message + self.sep
124
+ else:
125
+ ret += role
126
+
127
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
128
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
129
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
130
+ ret = ""
131
+
132
+ for i, (role, message) in enumerate(messages):
133
+ if i == 0:
134
+ assert message, "first message should not be none"
135
+ assert role == self.roles[0], "first message should come from user"
136
+ if message:
137
+ if type(message) is tuple:
138
+ message, _, _ = message
139
+ if i == 0:
140
+ message = wrap_sys(self.system) + message
141
+ if i % 2 == 0:
142
+ message = wrap_inst(message)
143
+ ret += self.sep + message
144
+ else:
145
+ ret += " " + message + " " + self.sep2
146
+ else:
147
+ ret += ""
148
+ ret = ret.lstrip(self.sep)
149
+
150
+ elif self.sep_style == SeparatorStyle.PLAIN:
151
+ seps = [self.sep, self.sep2]
152
+ ret = self.system
153
+ for i, (role, message) in enumerate(messages):
154
+ if message:
155
+ if type(message) is tuple:
156
+ message, _, _ = message
157
+ ret += message + seps[i % 2]
158
+ else:
159
+ ret += ""
160
+ elif self.sep_style == SeparatorStyle.QWEN2:
161
+ start = '<|im_start|>'
162
+ end = '<|im_end|>\n'
163
+ ret = start + 'system\n' + self.system + end
164
+ for i, (role, message) in enumerate(messages):
165
+ if message:
166
+ if type(message) is tuple:
167
+ message, _, _ = message
168
+
169
+ if message.endswith('<|endoftext|>'):
170
+ message = message.replace('<|endoftext|>', '')
171
+ ret += start + role + "\n" + message + end + '<|endoftext|>'
172
+ else:
173
+ assert not '<|endoftext|>' in message, f"Invalid message: {message}"
174
+ ret += start + role + "\n" + message + end
175
+ else:
176
+ ret += start + role + "\n"
177
+ else:
178
+ raise ValueError(f"Invalid style: {self.sep_style}")
179
+
180
+ return ret
181
+
182
+ def append_message(self, role, message):
183
+ self.messages.append([role, message])
184
+
185
+ def process_image(self, image, image_process_mode, return_pil=False, image_format="PNG"):
186
+ if image_process_mode == "Pad":
187
+
188
+ def expand2square(pil_img, background_color=(122, 116, 104)):
189
+ width, height = pil_img.size
190
+ if width == height:
191
+ return pil_img
192
+ elif width > height:
193
+ result = Image.new(pil_img.mode, (width, width), background_color)
194
+ result.paste(pil_img, (0, (width - height) // 2))
195
+ return result
196
+ else:
197
+ result = Image.new(pil_img.mode, (height, height), background_color)
198
+ result.paste(pil_img, ((height - width) // 2, 0))
199
+ return result
200
+
201
+ image = expand2square(image)
202
+ elif image_process_mode in ["Default", "Crop"]:
203
+ pass
204
+ elif image_process_mode == "Resize":
205
+ image = image.resize((336, 336))
206
+ else:
207
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
208
+
209
+ if type(image) is not Image.Image:
210
+ image = Image.open(image).convert("RGB")
211
+
212
+ max_hw, min_hw = max(image.size), min(image.size)
213
+ aspect_ratio = max_hw / min_hw
214
+ max_len, min_len = 672, 448
215
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
216
+ longest_edge = int(shortest_edge * aspect_ratio)
217
+ W, H = image.size
218
+ if H > W:
219
+ H, W = longest_edge, shortest_edge
220
+ else:
221
+ H, W = shortest_edge, longest_edge
222
+ image = image.resize((W, H))
223
+ if return_pil:
224
+ return image
225
+ else:
226
+ buffered = BytesIO()
227
+ image.save(buffered, format=image_format)
228
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
229
+ return img_b64_str
230
+
231
+ def get_images(self, return_pil=False, return_path=False):
232
+ images = []
233
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
234
+ if i % 2 == 0:
235
+ if type(msg) is tuple:
236
+ msg, image, image_process_mode = msg
237
+ if type(image) != list:
238
+ image = [image]
239
+ for img in image:
240
+ if not return_path:
241
+ img = self.process_image(img, image_process_mode, return_pil=return_pil)
242
+ else:
243
+ images.append(img)
244
+ return images
245
+
246
+ def to_gradio_chatbot(self):
247
+ ret = []
248
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
249
+ if i % 2 == 0:
250
+ if type(msg) is tuple:
251
+ msg, image, image_process_mode = msg
252
+ if type(image) != list:
253
+ image = [image]
254
+ if len(image) == 1:
255
+ msg = "<image>\n" + msg.replace("<image>", "").strip()
256
+ else:
257
+ msg = re.sub(r"(<image>)\n(?=<image>)", r"\1 ", msg)
258
+ for img in image:
259
+ img_b64_str = self.process_image(img, "Default", return_pil=False, image_format="JPEG")
260
+ img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}"/>'
261
+ msg = msg.replace("<image>", img_str, 1).strip()
262
+ if len(msg) > 0:
263
+ ret.append([msg, None])
264
+ else:
265
+ ret.append([msg, None])
266
+ else:
267
+ ret[-1][-1] = msg
268
+ return ret
269
+
270
+ def copy(self):
271
+ return Conversation(system=self.system, roles=self.roles, messages=[[x, y] for x, y in self.messages], offset=self.offset, sep_style=self.sep_style, sep=self.sep, sep2=self.sep2, version=self.version)
272
+
273
+ def dict(self):
274
+ if len(self.get_images()) > 0:
275
+ return {
276
+ "system": self.system,
277
+ "roles": self.roles,
278
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
279
+ "offset": self.offset,
280
+ "sep": self.sep,
281
+ "sep2": self.sep2,
282
+ }
283
+ return {
284
+ "system": self.system,
285
+ "roles": self.roles,
286
+ "messages": self.messages,
287
+ "offset": self.offset,
288
+ "sep": self.sep,
289
+ "sep2": self.sep2,
290
+ }
291
+
292
+
293
+ conv_vicuna_v0 = Conversation(
294
+ system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
295
+ roles=("Human", "Assistant"),
296
+ messages=[
297
+ ["Human", "What are the key differences between renewable and non-renewable energy sources?"],
298
+ [
299
+ "Assistant",
300
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
301
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
302
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
303
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
304
+ "renewable and non-renewable energy sources:\n"
305
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
306
+ "energy sources are finite and will eventually run out.\n"
307
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
308
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
309
+ "and other negative effects.\n"
310
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
311
+ "have lower operational costs than non-renewable sources.\n"
312
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
313
+ "locations than non-renewable sources.\n"
314
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
315
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
316
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
317
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
318
+ ],
319
+ ],
320
+ offset=2,
321
+ sep_style=SeparatorStyle.SINGLE,
322
+ sep="###",
323
+ )
324
+
325
+ conv_vicuna_v1 = Conversation(
326
+ system="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.",
327
+ roles=("USER", "ASSISTANT"),
328
+ version="v1",
329
+ messages=[],
330
+ offset=0,
331
+ sep_style=SeparatorStyle.TWO,
332
+ sep=" ",
333
+ sep2="</s>",
334
+ )
335
+
336
+ conv_qwen_v1 = Conversation(
337
+ system="You are a helpful assistant.",
338
+ roles=("user", "assistant"),
339
+ version="v1",
340
+ messages=(),
341
+ offset=0,
342
+ sep_style=SeparatorStyle.QWEN2,
343
+ )
344
+
345
+ conv_llama_2 = Conversation(
346
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
347
+
348
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
349
+ roles=("USER", "ASSISTANT"),
350
+ version="llama_v2",
351
+ messages=[],
352
+ offset=0,
353
+ sep_style=SeparatorStyle.LLAMA_2,
354
+ sep="<s>",
355
+ sep2="</s>",
356
+ )
357
+
358
+ conv_llava_llama_2 = Conversation(
359
+ system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.",
360
+ roles=("USER", "ASSISTANT"),
361
+ version="llama_v2",
362
+ messages=[],
363
+ offset=0,
364
+ sep_style=SeparatorStyle.LLAMA_2,
365
+ sep="<s>",
366
+ sep2="</s>",
367
+ )
368
+
369
+ conv_llava_llama_3 = Conversation(
370
+ system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.",
371
+ roles=("user", "assistant"),
372
+ version="llama_v3",
373
+ messages=[],
374
+ offset=0,
375
+ sep_style=SeparatorStyle.LLAMA_3,
376
+ tokenizer=AutoTokenizer.from_pretrained("/apdcephfs_jn/share_302244400/peterrao/nj3/models/Llama-3-8B-Instruct"),
377
+ stop_token_ids=[128009],
378
+ )
379
+
380
+ conv_mistral_instruct = Conversation(
381
+ system="",
382
+ roles=("USER", "ASSISTANT"),
383
+ version="llama_v2",
384
+ messages=[],
385
+ offset=0,
386
+ sep_style=SeparatorStyle.LLAMA_2,
387
+ sep="",
388
+ sep2="</s>",
389
+ )
390
+
391
+ conv_llava_llama_2_simple = Conversation(
392
+ system="Answer the questions about the visual content that the user provides.",
393
+ roles=("USER", "ASSISTANT"),
394
+ version="llama_v2",
395
+ messages=[],
396
+ offset=0,
397
+ sep_style=SeparatorStyle.LLAMA_2,
398
+ sep="<s>",
399
+ sep2="</s>",
400
+ )
401
+
402
+ conv_llava_llama_2_mmtag = Conversation(
403
+ system="Answer the questions about the visual content that the user provides." "The visual content will be provided with the following format: <Image>visual content</Image>.",
404
+ roles=("USER", "ASSISTANT"),
405
+ version="llama_v2_mmtag",
406
+ messages=[],
407
+ offset=0,
408
+ sep_style=SeparatorStyle.LLAMA_2,
409
+ sep="<s>",
410
+ sep2="</s>",
411
+ )
412
+
413
+ conv_mpt = Conversation(
414
+ system="""<|im_start|>system
415
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
416
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
417
+ version="mpt",
418
+ messages=[],
419
+ offset=0,
420
+ sep_style=SeparatorStyle.MPT,
421
+ sep="<|im_end|>",
422
+ )
423
+
424
+ conv_qwen = Conversation(
425
+ system="""<|im_start|>system
426
+ You are a helpful assistant.""",
427
+ roles=("<|im_start|>user", "<|im_start|>assistant"),
428
+ version="qwen",
429
+ messages=[],
430
+ offset=0,
431
+ sep_style=SeparatorStyle.CHATML,
432
+ sep="<|im_end|>",
433
+ )
434
+
435
+ conv_llava_plain = Conversation(
436
+ system="",
437
+ roles=("", ""),
438
+ messages=[],
439
+ offset=0,
440
+ sep_style=SeparatorStyle.PLAIN,
441
+ sep="\n",
442
+ )
443
+
444
+ conv_llava_v0 = Conversation(
445
+ system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
446
+ roles=("Human", "Assistant"),
447
+ messages=[],
448
+ offset=0,
449
+ sep_style=SeparatorStyle.SINGLE,
450
+ sep="###",
451
+ )
452
+
453
+ conv_llava_v0_mmtag = Conversation(
454
+ system="A chat between a curious user and an artificial intelligence assistant. "
455
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
456
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
457
+ roles=("Human", "Assistant"),
458
+ messages=[],
459
+ offset=0,
460
+ sep_style=SeparatorStyle.SINGLE,
461
+ sep="###",
462
+ version="v0_mmtag",
463
+ )
464
+
465
+ conv_llava_v1 = Conversation(
466
+ system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
467
+ roles=("USER", "ASSISTANT"),
468
+ version="v1",
469
+ messages=[],
470
+ offset=0,
471
+ sep_style=SeparatorStyle.TWO,
472
+ sep=" ",
473
+ sep2="</s>",
474
+ )
475
+
476
+ conv_llava_v1_mmtag = Conversation(
477
+ system="A chat between a curious user and an artificial intelligence assistant. "
478
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
479
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
480
+ roles=("USER", "ASSISTANT"),
481
+ messages=[],
482
+ offset=0,
483
+ sep_style=SeparatorStyle.TWO,
484
+ sep=" ",
485
+ sep2="</s>",
486
+ version="v1_mmtag",
487
+ )
488
+
489
+ conv_mistral_orca = Conversation(
490
+ system="""<|im_start|>system
491
+ You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!""",
492
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
493
+ version="mpt",
494
+ messages=[],
495
+ offset=0,
496
+ sep_style=SeparatorStyle.MPT,
497
+ sep="<|im_end|>",
498
+ )
499
+
500
+ conv_mistral_zephyr = Conversation(
501
+ system="""<|system|>
502
+ You are a helpful AI assistant.""",
503
+ roles=("<|user|>\n", "<|assistant|>\n"),
504
+ version="mpt",
505
+ messages=[],
506
+ offset=0,
507
+ sep_style=SeparatorStyle.MPT,
508
+ sep="</s>",
509
+ )
510
+
511
+ conv_mistral_direct = Conversation(
512
+ system="""<|im_start|>system
513
+ Answer the questions.""",
514
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
515
+ version="mpt",
516
+ messages=[],
517
+ offset=0,
518
+ sep_style=SeparatorStyle.MPT,
519
+ sep="<|im_end|>",
520
+ )
521
+
522
+ conv_chatml_direct = Conversation(
523
+ system="""<|im_start|>system
524
+ Answer the questions.""",
525
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
526
+ version="mpt",
527
+ messages=[],
528
+ offset=0,
529
+ sep_style=SeparatorStyle.MPT,
530
+ sep="<|im_end|>",
531
+ )
532
+
533
+ default_conversation = conv_vicuna_v0
534
+ conv_templates = {
535
+ "default": conv_vicuna_v0,
536
+ "v0": conv_vicuna_v0,
537
+ "v1": conv_vicuna_v1,
538
+ "vicuna_v1": conv_vicuna_v1,
539
+ 'v1_qwen2': conv_qwen_v1,
540
+ "llama_2": conv_llama_2,
541
+ "mistral_instruct": conv_mistral_instruct,
542
+ "mistral_orca": conv_mistral_orca,
543
+ "mistral_zephyr": conv_mistral_zephyr,
544
+ "mistral_direct": conv_mistral_direct,
545
+ "plain": conv_llava_plain,
546
+ "v0_plain": conv_llava_plain,
547
+ "chatml_direct": conv_chatml_direct,
548
+ "llava_v0": conv_llava_v0,
549
+ "llava_v0_mmtag": conv_llava_v0_mmtag,
550
+ "llava_v1": conv_llava_v1,
551
+ "llava_v1_mmtag": conv_llava_v1_mmtag,
552
+ "llava_llama_2": conv_llava_llama_2,
553
+ "llava_llama_3": conv_llava_llama_3,
554
+ "llava_llama_2_simple": conv_llava_llama_2_simple,
555
+ "llava_llama_2_mmtag": conv_llava_llama_2_mmtag,
556
+ "llava_mistral_instruct": conv_mistral_instruct,
557
+ "mpt": conv_mpt,
558
+ "qwen_1_5": conv_qwen,
559
+ }
oryx/mm_utils.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from io import BytesIO
3
+ import base64
4
+ import math
5
+ import ast
6
+
7
+ import torch
8
+ from transformers import StoppingCriteria
9
+ from oryx.constants import IMAGE_TOKEN_INDEX
10
+ import os
11
+
12
+ video_base = 0
13
+ video_ps = 64
14
+ highres_base = 0
15
+ highres_ps = 32
16
+ MAXRES = 1536
17
+ MINRES = 0
18
+ VIDEO_MAXRES = 480
19
+ VIDEO_MINRES = 288
20
+ LOWRES_RESIZE = (384,32)
21
+ PAD2STRIDE=False
22
+
23
+ def pad_image(image, target_resolution, value=0):
24
+ """
25
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
26
+
27
+ Args:
28
+ image (PIL.Image.Image): The input image.
29
+ target_resolution (tuple): The target resolution (width, height) of the image.
30
+
31
+ Returns:
32
+ PIL.Image.Image: The resized and padded image.
33
+ """
34
+ original_width, original_height = image.size
35
+ target_width, target_height = target_resolution
36
+ # Create a new image with the target size and paste the resized image onto it
37
+ new_image = Image.new('RGB', (target_width, target_height), (value, value, value))
38
+ paste_x = (target_width - original_width) // 2
39
+ paste_y = (target_height - original_height) // 2
40
+ new_image.paste(image, (paste_x, paste_y))
41
+ return new_image
42
+
43
+ def resize_images(image, patch_size=14, base_size=896):
44
+ h, w = image.size
45
+ if base_size == 0:
46
+ if h * w > MAXRES * MAXRES:
47
+ # print(f'{h}x{w} larger than max size {MAXRES}, resize to {MAXRES}')
48
+ scale = MAXRES * MAXRES / (h * w)
49
+ scale = math.sqrt(scale)
50
+ elif h * w < MINRES * MINRES:
51
+ # print(f'{h}x{w} smaller than max size {MINRES}, resize to {MINRES}')
52
+ scale = MINRES * MINRES / (h * w)
53
+ scale = math.sqrt(scale)
54
+ else:
55
+ scale = None
56
+ else:
57
+ scale = base_size * base_size / (h * w)
58
+ scale = math.sqrt(scale)
59
+
60
+
61
+ if scale is not None:
62
+ new_h = int(h * scale / patch_size) * patch_size
63
+ new_w = int(w * scale / patch_size) * patch_size
64
+ image = image.resize((new_h, new_w))
65
+ elif PAD2STRIDE:
66
+ if h % patch_size == 0:
67
+ new_h = h
68
+ else:
69
+ new_h = (h // patch_size + 1) * patch_size
70
+
71
+ if w % patch_size == 0:
72
+ new_w = w
73
+ else:
74
+ new_w = (w // patch_size + 1) * patch_size
75
+ image = pad_image(image, (new_h, new_w), value=127)
76
+ else:
77
+ scale = 1.0
78
+ new_h = int(h * scale / patch_size) * patch_size
79
+ new_w = int(w * scale / patch_size) * patch_size
80
+ image = image.resize((new_h, new_w))
81
+
82
+ return image
83
+
84
+ def resize_video(image, patch_size=14, base_size=896):
85
+ h, w = image.size
86
+ if base_size == 0:
87
+ if h * w > VIDEO_MAXRES * VIDEO_MAXRES:
88
+ # print(f'{h}x{w} larger than max size {MAXRES}, resize to {MAXRES}')
89
+ scale = VIDEO_MAXRES * VIDEO_MAXRES / (h * w)
90
+ scale = math.sqrt(scale)
91
+ elif h * w < VIDEO_MINRES * VIDEO_MINRES:
92
+ # print(f'{h}x{w} smaller than max size {MINRES}, resize to {MINRES}')
93
+ scale = VIDEO_MINRES * VIDEO_MINRES / (h * w)
94
+ scale = math.sqrt(scale)
95
+ else:
96
+ scale = None
97
+ else:
98
+ scale = base_size * base_size / (h * w)
99
+ scale = math.sqrt(scale)
100
+
101
+ if scale is not None:
102
+ new_h = int(h * scale / patch_size) * patch_size
103
+ new_w = int(w * scale / patch_size) * patch_size
104
+ image = image.resize((new_h, new_w))
105
+ elif PAD2STRIDE:
106
+ if h % patch_size == 0:
107
+ new_h = h
108
+ else:
109
+ new_h = (h // patch_size + 1) * patch_size
110
+
111
+ if w % patch_size == 0:
112
+ new_w = w
113
+ else:
114
+ new_w = (w // patch_size + 1) * patch_size
115
+ image = pad_image(image, (new_h, new_w), value=127)
116
+ else:
117
+ scale = 1.0
118
+ new_h = int(h * scale / patch_size) * patch_size
119
+ new_w = int(w * scale / patch_size) * patch_size
120
+ image = image.resize((new_h, new_w))
121
+
122
+ return image
123
+
124
+ def process_anyres_video_genli(image, processor):
125
+ image = resize_video(image, patch_size=video_ps, base_size=video_base)
126
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
127
+ return image.unsqueeze(0)
128
+
129
+ def process_anyres_video_genli_long(image, processor):
130
+ image = resize_video(image, patch_size=video_ps * 2, base_size=video_base)
131
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
132
+ return image.unsqueeze(0)
133
+
134
+ def load_image_from_base64(image):
135
+ return Image.open(BytesIO(base64.b64decode(image)))
136
+
137
+ def process_anyres_highres_image_genli(image, processor):
138
+ h, w = image.size
139
+ if h < 32 and w < 32:
140
+ min_size = min(h, w)
141
+ ratio = 64 / min_size
142
+ image = image.resize((int(h * ratio), int(w * ratio)))
143
+ elif h < 32:
144
+ ratio = 64 / h
145
+ image = image.resize((int(h * ratio), int(w * ratio)))
146
+ elif w < 32:
147
+ ratio = 64 / w
148
+ image = image.resize((int(h * ratio), int(w * ratio)))
149
+
150
+ image = resize_images(image, patch_size=highres_ps, base_size=highres_base)
151
+
152
+ image_original_resize = resize_images(image, patch_size=LOWRES_RESIZE[1], base_size=LOWRES_RESIZE[0])
153
+
154
+ # image_patches = [image_original_resize] + [image_original_resize]
155
+ # image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
156
+ # for image_patch in image_patches]
157
+ image_patches = processor.preprocess(image_original_resize, return_tensors='pt')['pixel_values'][0]
158
+ image_padded = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
159
+ # return torch.stack(image_patches, dim=0), image_padded.unsqueeze(0)
160
+ return image_patches.unsqueeze(0), image_padded.unsqueeze(0)
161
+
162
+
163
+ def read_image_patch(patch_info):
164
+ if 'img_path' in patch_info.keys():
165
+ image = Image.open(patch_info['img_path']).convert('RGB')
166
+ else:
167
+ if 'image_encoing' in patch_info.keys():
168
+ patch_info['image_encoding'] = patch_info['image_encoing']
169
+ image_file_name = patch_info['patch']
170
+ start_bytes = int(patch_info['start_num'])
171
+ file_size = int(patch_info['size'])
172
+
173
+ with open(image_file_name, 'rb') as f:
174
+ f.seek(start_bytes)
175
+ if 'image_encoding' in patch_info.keys() and patch_info['image_encoding'] == 'base64':
176
+ image = Image.open(io.BytesIO(base64.b64decode(f.read(file_size).decode()))).convert("RGB")
177
+ else:
178
+ image = Image.open(io.BytesIO(f.read(file_size))).convert("RGB")
179
+ return image
180
+
181
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
182
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
183
+
184
+ def insert_separator(X, sep):
185
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
186
+
187
+ input_ids = []
188
+ offset = 0
189
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
190
+ offset = 1
191
+ input_ids.append(prompt_chunks[0][0])
192
+
193
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
194
+ input_ids.extend(x[offset:])
195
+
196
+ if return_tensors is not None:
197
+ if return_tensors == 'pt':
198
+ return torch.tensor(input_ids, dtype=torch.long)
199
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
200
+ return input_ids
201
+
202
+
203
+ def get_model_name_from_path(model_path):
204
+ model_path = model_path.strip("/")
205
+ model_paths = model_path.split("/")
206
+ if model_paths[-1].startswith('checkpoint-'):
207
+ return model_paths[-2] + "_" + model_paths[-1]
208
+ else:
209
+ return model_paths[-1]
210
+
211
+
212
+ class KeywordsStoppingCriteria(StoppingCriteria):
213
+ def __init__(self, keywords, tokenizer, input_ids):
214
+ self.keywords = keywords
215
+ self.keyword_ids = []
216
+ for keyword in keywords:
217
+ cur_keyword_ids = tokenizer(keyword).input_ids
218
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
219
+ cur_keyword_ids = cur_keyword_ids[1:]
220
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
221
+ self.tokenizer = tokenizer
222
+ self.start_len = input_ids.shape[1]
223
+
224
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
225
+ assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
226
+ offset = min(output_ids.shape[1] - self.start_len, 3)
227
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
228
+ for keyword_id in self.keyword_ids:
229
+ if output_ids[0, -keyword_id.shape[0]:] == keyword_id:
230
+ return True
231
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
232
+ for keyword in self.keywords:
233
+ if keyword in outputs:
234
+ return True
235
+ return False
oryx/model/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ AVAILABLE_MODELS = {
4
+ "oryx_llama": "OryxLlamaForCausalLM, OryxConfig",
5
+ "oryx_qwen": "OryxQwenForCausalLM, OryxQwenConfig",
6
+ # Add other models as needed
7
+ }
8
+
9
+ for model_name, model_classes in AVAILABLE_MODELS.items():
10
+ try:
11
+ exec(f"from .language_model.{model_name} import {model_classes}")
12
+ except Exception as e:
13
+ raise e
14
+ print(f"Failed to import {model_name} from llava.language_model.{model_name}")
oryx/model/builder.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import warnings
3
+ import shutil
4
+
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
6
+ import torch
7
+ from oryx.model import *
8
+ from oryx.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
+
10
+
11
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", overwrite_config=None):
12
+ kwargs = {"device_map": device_map}
13
+
14
+ if load_8bit:
15
+ kwargs["load_in_8bit"] = True
16
+ elif load_4bit:
17
+ kwargs["load_in_4bit"] = True
18
+ kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
19
+ else:
20
+ kwargs["torch_dtype"] = torch.bfloat16
21
+
22
+ if "oryx" in model_name.lower():
23
+ # Load Oryx model
24
+ if "7b" in model_name.lower():
25
+ from oryx.model.language_model.oryx_qwen import OryxQwenConfig
26
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
27
+ if overwrite_config is not None:
28
+ cfg_pretrained = OryxQwenConfig.from_pretrained(model_path)
29
+ print(f"Overwriting config with {overwrite_config}")
30
+ for k, v in overwrite_config.items():
31
+ setattr(cfg_pretrained, k, v)
32
+ model = OryxQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
33
+ else:
34
+ model = OryxQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
35
+ else:
36
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
37
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
38
+ if overwrite_config is not None:
39
+ print(f"Overwriting config with {overwrite_config}")
40
+ for k, v in overwrite_config.items():
41
+ setattr(cfg_pretrained, k, v)
42
+ model = OryxLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
43
+
44
+ else:
45
+ # Load language model
46
+ if model_base is not None:
47
+ # PEFT model
48
+ from peft import PeftModel
49
+
50
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
51
+ model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="auto")
52
+ print(f"Loading LoRA weights from {model_path}")
53
+ model = PeftModel.from_pretrained(model, model_path)
54
+ print(f"Merging weights")
55
+ model = model.merge_and_unload()
56
+ print("Convert to FP16...")
57
+ model.to(torch.bfloat16)
58
+ else:
59
+ use_fast = False
60
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
61
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
62
+
63
+ image_processor = None
64
+
65
+ assert "oryx" in model_name.lower(), "Only Oryx models are supported for video chatbot."
66
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
67
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
68
+ if mm_use_im_patch_token:
69
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
70
+ if mm_use_im_start_end:
71
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
72
+ model.resize_token_embeddings(len(tokenizer))
73
+
74
+ vision_tower = model.get_vision_tower()
75
+ print("Loading vision tower...")
76
+ if not vision_tower.is_loaded:
77
+ vision_tower.load_model(device_map=device_map)
78
+ if device_map != "auto":
79
+ vision_tower.to(device="cuda", dtype=torch.bfloat16)
80
+ else:
81
+ vision_tower.to(device="cuda:0", dtype=torch.bfloat16)
82
+ image_processor = vision_tower.image_processor
83
+ print("Loading vision tower succeeded.")
84
+ if hasattr(model.config, "max_sequence_length"):
85
+ context_len = model.config.max_sequence_length
86
+ else:
87
+ context_len = 2048
88
+
89
+ return tokenizer, model, image_processor, context_len
oryx/model/language_model/oryx_llama.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from transformers import AutoConfig, AutoModelForCausalLM, \
7
+ LlamaConfig, LlamaModel, LlamaForCausalLM
8
+
9
+ from transformers.modeling_outputs import CausalLMOutputWithPast
10
+ from transformers.generation.utils import GenerateOutput
11
+
12
+ from oryx.model.oryx_arch import OryxMetaModel, OryxMetaForCausalLM
13
+
14
+
15
+ class OryxConfig(LlamaConfig):
16
+ model_type = "oryx_llama"
17
+
18
+
19
+ class OryxLlamaModel(OryxMetaModel, LlamaModel):
20
+ config_class = OryxConfig
21
+
22
+ def __init__(self, config: LlamaConfig):
23
+ super(OryxLlamaModel, self).__init__(config)
24
+
25
+
26
+ class OryxLlamaForCausalLM(LlamaForCausalLM, OryxMetaForCausalLM):
27
+ config_class = OryxConfig
28
+
29
+ def __init__(self, config):
30
+ LlamaForCausalLM.__init__(self, config)
31
+ self.model = OryxLlamaModel(config)
32
+
33
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
34
+
35
+ # Initialize weights and apply final processing
36
+ self.post_init()
37
+
38
+ def get_model(self):
39
+ return self.model
40
+
41
+ def forward(
42
+ self,
43
+ input_ids: torch.LongTensor = None,
44
+ attention_mask: Optional[torch.Tensor] = None,
45
+ position_ids: Optional[torch.LongTensor] = None,
46
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
47
+ inputs_embeds: Optional[torch.FloatTensor] = None,
48
+ labels: Optional[torch.LongTensor] = None,
49
+ use_cache: Optional[bool] = None,
50
+ output_attentions: Optional[bool] = None,
51
+ output_hidden_states: Optional[bool] = None,
52
+ images: Optional[torch.FloatTensor] = None,
53
+ images_highres: Optional[List[torch.FloatTensor]] = None,
54
+ image_sizes: Optional[List[List[int]]] = None,
55
+ return_dict: Optional[bool] = None,
56
+ modalities: Optional[List[str]] = ["image"],
57
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
58
+
59
+
60
+ if inputs_embeds is None:
61
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images,
62
+ modalities, image_sizes, images_highres)
63
+
64
+ if labels is None:
65
+ return super().forward(
66
+ input_ids=input_ids,
67
+ attention_mask=attention_mask,
68
+ position_ids=position_ids,
69
+ past_key_values=past_key_values,
70
+ inputs_embeds=inputs_embeds,
71
+ use_cache=use_cache,
72
+ output_attentions=output_attentions,
73
+ output_hidden_states=output_hidden_states,
74
+ return_dict=return_dict
75
+ )
76
+ else:
77
+ return self.forward_llm_efficient(
78
+ input_ids=input_ids,
79
+ attention_mask=attention_mask,
80
+ position_ids=position_ids,
81
+ past_key_values=past_key_values,
82
+ inputs_embeds=inputs_embeds,
83
+ labels=labels,
84
+ use_cache=use_cache,
85
+ output_attentions=output_attentions,
86
+ output_hidden_states=output_hidden_states,
87
+ return_dict=return_dict
88
+ )
89
+
90
+ def forward_llm_efficient(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict):
91
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
92
+ output_hidden_states = (
93
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
94
+ )
95
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
96
+
97
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
98
+ outputs = self.model(
99
+ input_ids=input_ids,
100
+ attention_mask=attention_mask,
101
+ position_ids=position_ids,
102
+ past_key_values=past_key_values,
103
+ inputs_embeds=inputs_embeds,
104
+ use_cache=use_cache,
105
+ output_attentions=output_attentions,
106
+ output_hidden_states=output_hidden_states,
107
+ return_dict=return_dict,
108
+ )
109
+
110
+ hidden_states = outputs[0]
111
+ hidden_dim = hidden_states.size(-1)
112
+ shift_labels = labels[..., 1:].contiguous().reshape(-1)
113
+ shift_hidden_states = hidden_states[..., :-1, :].contiguous().reshape(-1, hidden_dim)
114
+ assert shift_labels.size(0) == shift_hidden_states.size(0)
115
+ mask = shift_labels > -1
116
+ seen_tokens = mask.float().sum().item()
117
+ if not seen_tokens > 0:
118
+ logits = self.lm_head(shift_hidden_states[0:2])
119
+ loss = logits.sum() * 0
120
+ print("No tokens seen")
121
+ print(shift_labels)
122
+ else:
123
+ shift_labels = shift_labels[mask]
124
+ shift_hidden_states = shift_hidden_states[mask, :]
125
+ logits = self.lm_head(shift_hidden_states)
126
+ logits = logits.float()
127
+ loss_fct = nn.CrossEntropyLoss()
128
+ loss = loss_fct(logits, shift_labels)
129
+
130
+
131
+ if not return_dict:
132
+ output = (logits,) + outputs[1:]
133
+ return (loss,) + output if loss is not None else output
134
+
135
+ return CausalLMOutputWithPast(
136
+ loss=loss,
137
+ logits=logits,
138
+ past_key_values=outputs.past_key_values,
139
+ hidden_states=outputs.hidden_states,
140
+ attentions=outputs.attentions,
141
+ )
142
+
143
+ @torch.no_grad()
144
+ def generate(
145
+ self,
146
+ inputs: Optional[torch.Tensor] = None,
147
+ images: Optional[torch.Tensor] = None,
148
+ image_sizes: Optional[torch.Tensor] = None,
149
+ **kwargs,
150
+ ) -> Union[GenerateOutput, torch.LongTensor]:
151
+ modalities = kwargs.pop("modalities", None)
152
+ position_ids = kwargs.pop("position_ids", None)
153
+ attention_mask = kwargs.pop("attention_mask", None)
154
+ if "inputs_embeds" in kwargs:
155
+ raise NotImplementedError("`inputs_embeds` is not supported")
156
+
157
+ if images is not None:
158
+ (
159
+ inputs,
160
+ position_ids,
161
+ attention_mask,
162
+ _,
163
+ inputs_embeds,
164
+ _
165
+ ) = self.prepare_inputs_labels_for_multimodal(
166
+ inputs,
167
+ position_ids,
168
+ attention_mask,
169
+ None,
170
+ None,
171
+ images,
172
+ modalities,
173
+ image_sizes=image_sizes
174
+ )
175
+ else:
176
+ inputs_embeds = self.get_model().embed_tokens(inputs)
177
+
178
+ return super().generate(
179
+ position_ids=position_ids,
180
+ attention_mask=attention_mask,
181
+ inputs_embeds=inputs_embeds,
182
+ **kwargs
183
+ )
184
+
185
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
186
+ inputs_embeds=None, **kwargs):
187
+ images = kwargs.pop("images", None)
188
+ image_sizes = kwargs.pop("image_sizes", None)
189
+ inputs = super().prepare_inputs_for_generation(
190
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
191
+ )
192
+ if images is not None:
193
+ inputs['images'] = images
194
+ if image_sizes is not None:
195
+ inputs['image_sizes'] = image_sizes
196
+ return inputs
197
+
198
+ if OryxConfig.model_type == "oryx":
199
+ OryxConfig.model_type = "oryx_llama" # directly set to Oryx_dev to avoid conflict with HF's Oryx
200
+
201
+ AutoConfig.register("oryx_llama", OryxConfig)
202
+ AutoModelForCausalLM.register(OryxConfig, OryxLlamaForCausalLM)
oryx/model/language_model/oryx_qwen.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import List, Optional, Tuple, Union, Dict
3
+ import torch
4
+ import os
5
+ import torch.nn as nn
6
+
7
+ import transformers
8
+ from transformers import AutoConfig, AutoModelForCausalLM
9
+
10
+ from transformers.modeling_outputs import CausalLMOutputWithPast
11
+ from transformers.generation.utils import GenerateOutput
12
+
13
+ from oryx.model.oryx_arch import OryxMetaModel, OryxMetaForCausalLM
14
+ from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
15
+
16
+ class OryxQwenConfig(Qwen2Config):
17
+ model_type = "oryx_qwen"
18
+
19
+
20
+ class OryxQwenModel(OryxMetaModel, Qwen2Model):
21
+ config_class = OryxQwenConfig
22
+
23
+ def __init__(self, config: Qwen2Config):
24
+ super(OryxQwenModel, self).__init__(config)
25
+
26
+
27
+ class OryxQwenForCausalLM(Qwen2ForCausalLM, OryxMetaForCausalLM):
28
+ config_class = OryxQwenConfig
29
+
30
+ def __init__(self, config):
31
+ # super(Qwen2ForCausalLM, self).__init__(config)
32
+ Qwen2ForCausalLM.__init__(self, config)
33
+ config.model_type = "oryx_qwen"
34
+ config.rope_scaling = None
35
+
36
+ self.model = OryxQwenModel(config)
37
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
38
+ # Initialize weights and apply final processing
39
+ self.post_init()
40
+
41
+ def get_model(self):
42
+ return self.model
43
+
44
+ def forward(
45
+ self,
46
+ input_ids: torch.LongTensor = None,
47
+ attention_mask: Optional[torch.Tensor] = None,
48
+ position_ids: Optional[torch.LongTensor] = None,
49
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
50
+ inputs_embeds: Optional[torch.FloatTensor] = None,
51
+ labels: Optional[torch.LongTensor] = None,
52
+ use_cache: Optional[bool] = None,
53
+ output_attentions: Optional[bool] = None,
54
+ output_hidden_states: Optional[bool] = None,
55
+ images: Optional[torch.FloatTensor] = None,
56
+ images_highres: Optional[List[torch.FloatTensor]] = None,
57
+ image_sizes: Optional[List[List[int]]] = None,
58
+ return_dict: Optional[bool] = None,
59
+ modalities: Optional[List[str]] = ["image"],
60
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
61
+
62
+ if inputs_embeds is None:
63
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images,
64
+ modalities, image_sizes, images_highres)
65
+ if labels is None:
66
+ return super().forward(
67
+ input_ids=input_ids,
68
+ attention_mask=attention_mask,
69
+ position_ids=position_ids,
70
+ past_key_values=past_key_values,
71
+ inputs_embeds=inputs_embeds,
72
+ use_cache=use_cache,
73
+ output_attentions=output_attentions,
74
+ output_hidden_states=output_hidden_states,
75
+ return_dict=return_dict
76
+ )
77
+ else:
78
+ return self.forward_llm_efficient(
79
+ input_ids=input_ids,
80
+ attention_mask=attention_mask,
81
+ position_ids=position_ids,
82
+ past_key_values=past_key_values,
83
+ inputs_embeds=inputs_embeds,
84
+ labels=labels,
85
+ use_cache=use_cache,
86
+ output_attentions=output_attentions,
87
+ output_hidden_states=output_hidden_states,
88
+ return_dict=return_dict
89
+ )
90
+
91
+ def forward_llm_efficient(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict):
92
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
93
+ output_hidden_states = (
94
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
95
+ )
96
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
97
+
98
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
99
+ outputs = self.model(
100
+ input_ids=input_ids,
101
+ attention_mask=attention_mask,
102
+ position_ids=position_ids,
103
+ past_key_values=past_key_values,
104
+ inputs_embeds=inputs_embeds,
105
+ use_cache=use_cache,
106
+ output_attentions=output_attentions,
107
+ output_hidden_states=output_hidden_states,
108
+ return_dict=return_dict,
109
+ )
110
+
111
+ hidden_states = outputs[0]
112
+ hidden_dim = hidden_states.size(-1)
113
+ shift_labels = labels[..., 1:].contiguous().reshape(-1)
114
+ shift_hidden_states = hidden_states[..., :-1, :].contiguous().reshape(-1, hidden_dim)
115
+ assert shift_labels.size(0) == shift_hidden_states.size(0)
116
+ mask = shift_labels > -1
117
+ assert mask.float().sum() > 0
118
+ shift_labels = shift_labels[mask]
119
+ shift_hidden_states = shift_hidden_states[mask, :]
120
+ logits = self.lm_head(shift_hidden_states)
121
+ logits = logits.float()
122
+ loss_fct = nn.CrossEntropyLoss()
123
+ loss = loss_fct(logits, shift_labels)
124
+
125
+
126
+ if not return_dict:
127
+ output = (logits,) + outputs[1:]
128
+ return (loss,) + output if loss is not None else output
129
+
130
+
131
+ return CausalLMOutputWithPast(
132
+ loss=loss,
133
+ logits=logits,
134
+ past_key_values=outputs.past_key_values,
135
+ hidden_states=outputs.hidden_states,
136
+ attentions=outputs.attentions,
137
+ )
138
+
139
+ @torch.no_grad()
140
+ def generate(
141
+ self,
142
+ inputs: Optional[torch.Tensor] = None,
143
+ images: Optional[torch.Tensor] = None,
144
+ images_highres: Optional[List[torch.FloatTensor]] = None,
145
+ image_sizes: Optional[torch.Tensor] = None,
146
+ modalities: Optional[List[str]] = ["image"],
147
+ **kwargs,
148
+ ) -> Union[GenerateOutput, torch.LongTensor]:
149
+ position_ids = kwargs.pop("position_ids", None)
150
+ attention_mask = kwargs.pop("attention_mask", None)
151
+ if "inputs_embeds" in kwargs:
152
+ raise NotImplementedError("`inputs_embeds` is not supported")
153
+
154
+ if images is not None:
155
+ (inputs,
156
+ position_ids,
157
+ attention_mask,
158
+ _,
159
+ inputs_embeds,
160
+ _) = self.prepare_inputs_labels_for_multimodal(inputs,
161
+ position_ids,
162
+ attention_mask,
163
+ None, None,
164
+ images, modalities, image_sizes=image_sizes, images_highres=images_highres)
165
+ else:
166
+ inputs_embeds = self.get_model().embed_tokens(inputs)
167
+
168
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
169
+
170
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
171
+ images = kwargs.pop("images", None)
172
+ image_sizes = kwargs.pop("image_sizes", None)
173
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
174
+ if images is not None:
175
+ inputs["images"] = images
176
+ if image_sizes is not None:
177
+ inputs["image_sizes"] = image_sizes
178
+ return inputs
179
+
180
+
181
+ AutoConfig.register("oryx_qwen", OryxQwenConfig)
182
+ AutoModelForCausalLM.register(OryxQwenConfig, OryxQwenForCausalLM)
oryx/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from .oryx_vit import OryxViTWrapper
3
+
4
+ def build_vision_tower(vision_tower_cfg, **kwargs):
5
+ vision_tower = getattr(vision_tower_cfg, 'vision_tower', getattr(vision_tower_cfg, 'mm_vision_tower', None))
6
+ is_absolute_path_exists = os.path.exists(vision_tower)
7
+ if "oryx_vit" in vision_tower:
8
+ print(f"Buiding OryxViTWrapper from {vision_tower}...")
9
+ path = vision_tower.split(":")[1]
10
+ return OryxViTWrapper(vision_tower, path=path, args=vision_tower_cfg, **kwargs)
11
+ else:
12
+ raise ValueError(f'Unknown vision tower: {vision_tower}')
oryx/model/multimodal_encoder/oryx_vit.py ADDED
@@ -0,0 +1,844 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import warnings
3
+ from dataclasses import dataclass
4
+ from functools import partial
5
+ from typing import (
6
+ Callable,
7
+ Dict,
8
+ Final,
9
+ List,
10
+ Literal,
11
+ Optional,
12
+ Sequence,
13
+ Set,
14
+ Tuple,
15
+ Type,
16
+ Union,
17
+ )
18
+
19
+ from torch.utils.checkpoint import checkpoint
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ try:
24
+ from timm.layers import (
25
+ AttentionPoolLatent,
26
+ DropPath,
27
+ LayerType,
28
+ Mlp,
29
+ PatchDropout,
30
+ PatchEmbed,
31
+ resample_abs_pos_embed,
32
+ )
33
+ from timm.models._manipulate import checkpoint_seq, named_apply
34
+ except:
35
+ print('Wrong timm version')
36
+
37
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
38
+
39
+ from typing import Optional
40
+
41
+ import logging
42
+ import torch
43
+ import torch.nn as nn
44
+ import torch.nn.functional as F
45
+
46
+ import os
47
+
48
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
49
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
50
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
51
+ def norm_cdf(x):
52
+ # Computes standard normal cumulative distribution function
53
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
54
+
55
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
56
+ warnings.warn(
57
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
58
+ "The distribution of values may be incorrect.",
59
+ stacklevel=2,
60
+ )
61
+
62
+ with torch.no_grad():
63
+ # Values are generated by using a truncated uniform distribution and
64
+ # then using the inverse CDF for the normal distribution.
65
+ # Get upper and lower cdf values
66
+ l = norm_cdf((a - mean) / std) # noqa: E741
67
+ u = norm_cdf((b - mean) / std)
68
+
69
+ # Uniformly fill tensor with values from [l, u], then translate to
70
+ # [2l-1, 2u-1].
71
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
72
+
73
+ # Use inverse cdf transform for normal distribution to get truncated
74
+ # standard normal
75
+ tensor.erfinv_()
76
+
77
+ # Transform to proper mean, std
78
+ tensor.mul_(std * math.sqrt(2.0))
79
+ tensor.add_(mean)
80
+
81
+ # Clamp to ensure it's in the proper range
82
+ tensor.clamp_(min=a, max=b)
83
+ return tensor
84
+
85
+
86
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
87
+ # type: (torch.Tensor, float, float, float, float) -> torch.Tensor
88
+ r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first
89
+ convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its orignal dtype.
90
+ Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn
91
+ from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
92
+ with values outside :math:`[a, b]` redrawn until they are within
93
+ the bounds. The method used for generating the random values works
94
+ best when :math:`a \leq \text{mean} \leq b`.
95
+ Args:
96
+ tensor: an n-dimensional `torch.Tensor`
97
+ mean: the mean of the normal distribution
98
+ std: the standard deviation of the normal distribution
99
+ a: the minimum cutoff value
100
+ b: the maximum cutoff value
101
+ Examples:
102
+ >>> w = torch.empty(3, 5)
103
+ >>> nn.init.trunc_normal_(w)
104
+ """
105
+
106
+ with torch.no_grad():
107
+ dtype = tensor.dtype
108
+ tensor_fp32 = tensor.float()
109
+ tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b)
110
+ tensor_dtype = tensor_fp32.to(dtype=dtype)
111
+ tensor.copy_(tensor_dtype)
112
+
113
+
114
+ def init_weights(self):
115
+ if self.pos_embed is not None:
116
+ trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
117
+ trunc_normal_(self.latent, std=self.latent_dim**-0.5)
118
+
119
+
120
+ def init_weights_vit_timm(module: nn.Module, name: str = "") -> None:
121
+ """ViT weight initialization, original timm impl (for reproducibility)"""
122
+ if isinstance(module, nn.Linear):
123
+ trunc_normal_(module.weight, std=0.02)
124
+ if module.bias is not None:
125
+ nn.init.zeros_(module.bias)
126
+ elif hasattr(module, "init_weights"):
127
+ module.init_weights()
128
+
129
+
130
+ class Attention(nn.Module):
131
+ fused_attn: Final[bool]
132
+
133
+ def __init__(
134
+ self,
135
+ dim: int,
136
+ num_heads: int = 8,
137
+ qkv_bias: bool = False,
138
+ qk_norm: bool = False,
139
+ attn_drop: float = 0.0,
140
+ proj_drop: float = 0.0,
141
+ norm_layer: nn.Module = nn.LayerNorm,
142
+ ) -> None:
143
+ super().__init__()
144
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
145
+ self.num_heads = num_heads
146
+ self.head_dim = dim // num_heads
147
+ self.scale = self.head_dim**-0.5
148
+ # self.fused_attn = use_fused_attn()
149
+ self.fused_attn = True
150
+
151
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
152
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
153
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
154
+ self.attn_drop = nn.Dropout(attn_drop)
155
+ self.proj = nn.Linear(dim, dim)
156
+ self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity()
157
+
158
+ def forward(self, x: torch.Tensor, cu_slens=None) -> torch.Tensor:
159
+ B, N, C = x.shape
160
+ qkv = (
161
+ self.qkv(x)
162
+ .reshape(B, N, 3, self.num_heads, self.head_dim)
163
+ .permute(2, 0, 3, 1, 4)
164
+ )
165
+ q, k, v = qkv.unbind(0)
166
+ q, k = self.q_norm(q), self.k_norm(k)
167
+
168
+ if cu_slens is not None:
169
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
170
+ k = k.permute(0, 2, 1, 3)
171
+ v = v.permute(0, 2, 1, 3)
172
+ max_seqlen = torch.max(cu_slens[1:] - cu_slens[:-1]).item()
173
+ x = flash_attn_varlen_func(
174
+ q.squeeze(0),
175
+ k.squeeze(0),
176
+ v.squeeze(0),
177
+ cu_seqlens_q=cu_slens,
178
+ cu_seqlens_k=cu_slens,
179
+ max_seqlen_q=max_seqlen,
180
+ max_seqlen_k=max_seqlen,
181
+ softmax_scale=self.scale,
182
+ causal=False,
183
+ )
184
+
185
+ x = x.reshape(B, N, -1)
186
+ x = self.proj(x)
187
+ x = self.proj_drop(x)
188
+
189
+ else:
190
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
191
+ k = k.permute(0, 2, 1, 3)
192
+ v = v.permute(0, 2, 1, 3)
193
+ x = flash_attn_func(q, k, v, softmax_scale=self.scale) # -> b, n, h, c
194
+
195
+ x = x.reshape(B, N, -1)
196
+ x = self.proj(x)
197
+ x = self.proj_drop(x)
198
+ return x
199
+
200
+
201
+ class LayerScale(nn.Module):
202
+ def __init__(
203
+ self,
204
+ dim: int,
205
+ init_values: float = 1e-5,
206
+ inplace: bool = False,
207
+ ) -> None:
208
+ super().__init__()
209
+ self.inplace = inplace
210
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
211
+
212
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
213
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
214
+
215
+
216
+ class Block(nn.Module):
217
+ def __init__(
218
+ self,
219
+ dim: int,
220
+ num_heads: int,
221
+ mlp_ratio: float = 4.0,
222
+ qkv_bias: bool = False,
223
+ qk_norm: bool = False,
224
+ proj_drop: float = 0.0,
225
+ attn_drop: float = 0.0,
226
+ init_values: Optional[float] = None,
227
+ drop_path: float = 0.0,
228
+ act_layer: nn.Module = nn.GELU,
229
+ norm_layer: nn.Module = nn.LayerNorm,
230
+ mlp_layer: nn.Module = Mlp,
231
+ ) -> None:
232
+ super().__init__()
233
+ self.norm1 = norm_layer(dim)
234
+ self.attn = Attention(
235
+ dim,
236
+ num_heads=num_heads,
237
+ qkv_bias=qkv_bias,
238
+ qk_norm=qk_norm,
239
+ attn_drop=attn_drop,
240
+ proj_drop=proj_drop,
241
+ norm_layer=norm_layer,
242
+ )
243
+ self.ls1 = (
244
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
245
+ )
246
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
247
+
248
+ self.norm2 = norm_layer(dim)
249
+ self.mlp = mlp_layer(
250
+ in_features=dim,
251
+ hidden_features=int(dim * mlp_ratio),
252
+ act_layer=act_layer,
253
+ drop=proj_drop,
254
+ )
255
+ self.ls2 = (
256
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
257
+ )
258
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
259
+
260
+ def forward(self, x: torch.Tensor, cu_slens=None) -> torch.Tensor:
261
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), cu_slens=cu_slens)))
262
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
263
+ return x
264
+
265
+
266
+ class VisionTransformer(nn.Module):
267
+ """Vision Transformer
268
+
269
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
270
+ - https://arxiv.org/abs/2010.11929
271
+ """
272
+
273
+ dynamic_img_size: Final[bool]
274
+
275
+ def __init__(
276
+ self,
277
+ img_size: Union[int, Tuple[int, int]] = 224,
278
+ patch_size: Union[int, Tuple[int, int]] = 16,
279
+ in_chans: int = 3,
280
+ num_classes: int = 1000,
281
+ global_pool: Literal["", "avg", "token", "map"] = "token",
282
+ embed_dim: int = 768,
283
+ depth: int = 12,
284
+ num_heads: int = 12,
285
+ mlp_ratio: float = 4.0,
286
+ qkv_bias: bool = True,
287
+ qk_norm: bool = False,
288
+ init_values: Optional[float] = None,
289
+ class_token: bool = True,
290
+ no_embed_class: bool = False,
291
+ reg_tokens: int = 0,
292
+ pre_norm: bool = False,
293
+ fc_norm: Optional[bool] = None,
294
+ dynamic_img_size: bool = False,
295
+ dynamic_img_pad: bool = False,
296
+ drop_rate: float = 0.0,
297
+ pos_drop_rate: float = 0.0,
298
+ patch_drop_rate: float = 0.0,
299
+ proj_drop_rate: float = 0.0,
300
+ attn_drop_rate: float = 0.0,
301
+ drop_path_rate: float = 0.0,
302
+ weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
303
+ embed_layer: Callable = PatchEmbed,
304
+ norm_layer: Optional[LayerType] = None,
305
+ act_layer: Optional[LayerType] = None,
306
+ strict_img_size: bool = False,
307
+ block_fn: Type[nn.Module] = Block,
308
+ mlp_layer: Type[nn.Module] = Mlp,
309
+ ignore_head: bool = False,
310
+ ) -> None:
311
+ """
312
+ Args:
313
+ img_size: Input image size.
314
+ patch_size: Patch size.
315
+ in_chans: Number of image input channels.
316
+ num_classes: Mumber of classes for classification head.
317
+ global_pool: Type of global pooling for final sequence (default: 'token').
318
+ embed_dim: Transformer embedding dimension.
319
+ depth: Depth of transformer.
320
+ num_heads: Number of attention heads.
321
+ mlp_ratio: Ratio of mlp hidden dim to embedding dim.
322
+ qkv_bias: Enable bias for qkv projections if True.
323
+ init_values: Layer-scale init values (layer-scale enabled if not None).
324
+ class_token: Use class token.
325
+ no_embed_class: Don't include position embeddings for class (or reg) tokens.
326
+ reg_tokens: Number of register tokens.
327
+ fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
328
+ drop_rate: Head dropout rate.
329
+ pos_drop_rate: Position embedding dropout rate.
330
+ attn_drop_rate: Attention dropout rate.
331
+ drop_path_rate: Stochastic depth rate.
332
+ weight_init: Weight initialization scheme.
333
+ embed_layer: Patch embedding layer.
334
+ norm_layer: Normalization layer.
335
+ act_layer: MLP activation layer.
336
+ block_fn: Transformer block layer.
337
+ """
338
+ super().__init__()
339
+ assert global_pool in ("", "avg", "token", "map")
340
+ assert class_token or global_pool != "token"
341
+ use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
342
+ # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
343
+ # act_layer = get_act_layer(act_layer) or nn.GELU
344
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
345
+ act_layer = nn.GELU
346
+
347
+ self.num_classes = num_classes
348
+ self.global_pool = global_pool
349
+ self.num_features = self.embed_dim = (
350
+ embed_dim # num_features for consistency with other models
351
+ )
352
+ self.num_prefix_tokens = 1 if class_token else 0
353
+ self.num_prefix_tokens += reg_tokens
354
+ self.num_reg_tokens = reg_tokens
355
+ self.has_class_token = class_token
356
+ self.no_embed_class = (
357
+ no_embed_class # don't embed prefix positions (includes reg)
358
+ )
359
+ self.dynamic_img_size = dynamic_img_size
360
+ self.grad_checkpointing = False
361
+ self.ignore_head = ignore_head
362
+
363
+ embed_args = {}
364
+ if dynamic_img_size:
365
+ # flatten deferred until after pos embed
366
+ embed_args.update(dict(strict_img_size=False, output_fmt="NHWC"))
367
+ self.patch_embed = embed_layer(
368
+ img_size=img_size,
369
+ patch_size=patch_size,
370
+ in_chans=in_chans,
371
+ embed_dim=embed_dim,
372
+ bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
373
+ dynamic_img_pad=dynamic_img_pad,
374
+ strict_img_size=strict_img_size,
375
+ **embed_args,
376
+ )
377
+ num_patches = self.patch_embed.num_patches
378
+
379
+ self.cls_token = (
380
+ nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
381
+ )
382
+ self.reg_token = (
383
+ nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
384
+ )
385
+ embed_len = (
386
+ num_patches if no_embed_class else num_patches + self.num_prefix_tokens
387
+ )
388
+ self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
389
+ self.pos_drop = nn.Dropout(p=pos_drop_rate)
390
+ if patch_drop_rate > 0:
391
+ self.patch_drop = PatchDropout(
392
+ patch_drop_rate,
393
+ num_prefix_tokens=self.num_prefix_tokens,
394
+ )
395
+ else:
396
+ self.patch_drop = nn.Identity()
397
+ self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
398
+
399
+ dpr = [
400
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
401
+ ] # stochastic depth decay rule
402
+ self.blocks = nn.Sequential(
403
+ *[
404
+ block_fn(
405
+ dim=embed_dim,
406
+ num_heads=num_heads,
407
+ mlp_ratio=mlp_ratio,
408
+ qkv_bias=qkv_bias,
409
+ qk_norm=qk_norm,
410
+ init_values=init_values,
411
+ proj_drop=proj_drop_rate,
412
+ attn_drop=attn_drop_rate,
413
+ drop_path=dpr[i],
414
+ norm_layer=norm_layer,
415
+ act_layer=act_layer,
416
+ mlp_layer=mlp_layer,
417
+ )
418
+ for i in range(depth)
419
+ ]
420
+ )
421
+
422
+ def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None:
423
+ assert mode in ("jax", "jax_nlhb", "moco", "")
424
+ # head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
425
+ trunc_normal_(self.pos_embed, std=0.02)
426
+ if self.cls_token is not None:
427
+ nn.init.normal_(self.cls_token, std=1e-6)
428
+ named_apply(init_weights_vit_timm, self)
429
+
430
+ @torch.jit.ignore
431
+ def no_weight_decay(self) -> Set:
432
+ return {"pos_embed", "cls_token", "dist_token"}
433
+
434
+ @torch.jit.ignore
435
+ def group_matcher(self, coarse: bool = False) -> Dict:
436
+ return dict(
437
+ stem=r"^cls_token|pos_embed|patch_embed", # stem and embed
438
+ blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
439
+ )
440
+
441
+ @torch.jit.ignore
442
+ def set_grad_checkpointing(self, enable: bool = True) -> None:
443
+ self.grad_checkpointing = enable
444
+
445
+ @torch.jit.ignore
446
+ def get_classifier(self) -> nn.Module:
447
+ return self.head
448
+
449
+ def reset_classifier(self, num_classes: int, global_pool=None) -> None:
450
+ self.num_classes = num_classes
451
+ if global_pool is not None:
452
+ assert global_pool in ("", "avg", "token", "map")
453
+ if global_pool == "map" and self.attn_pool is None:
454
+ assert (
455
+ False
456
+ ), "Cannot currently add attention pooling in reset_classifier()."
457
+ elif global_pool != "map " and self.attn_pool is not None:
458
+ self.attn_pool = None # remove attention pooling
459
+ self.global_pool = global_pool
460
+ self.head = (
461
+ nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
462
+ )
463
+
464
+ def rescale_positional_embedding(self, out_size):
465
+ h, w = out_size
466
+ pos_embed_shape = int((self.pos_embed.shape[1]) ** 0.5)
467
+ if (h, w) == (pos_embed_shape, pos_embed_shape):
468
+ return self.pos_embed
469
+ rescaled_positional_embedding = \
470
+ self.pos_embed.new_zeros(1, h*w, self.pos_embed.shape[2])
471
+ pe_2d = self.pos_embed[0].T.contiguous().view(1, -1, pos_embed_shape, pos_embed_shape)
472
+ pe_2d = F.interpolate(pe_2d, out_size, mode='bilinear', align_corners=False).view(-1, h*w)
473
+ rescaled_positional_embedding[0] = pe_2d.T.contiguous()
474
+ return rescaled_positional_embedding
475
+
476
+ def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
477
+ if self.dynamic_img_size:
478
+ B, H, W, C = x.shape
479
+ pos_embed = resample_abs_pos_embed(
480
+ self.pos_embed,
481
+ (H, W),
482
+ num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
483
+ )
484
+ x = x.view(B, -1, C)
485
+ else:
486
+ pos_embed = self.pos_embed
487
+
488
+ to_cat = []
489
+ if self.cls_token is not None:
490
+ to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
491
+ if self.reg_token is not None:
492
+ to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
493
+
494
+ if self.no_embed_class:
495
+ # deit-3, updated JAX (big vision)
496
+ # position embedding does not overlap with class token, add then concat
497
+ x = x + pos_embed
498
+ if to_cat:
499
+ x = torch.cat(to_cat + [x], dim=1)
500
+ else:
501
+ # original timm, JAX, and deit vit impl
502
+ # pos_embed has entry for class token, concat then add
503
+ if to_cat:
504
+ x = torch.cat(to_cat + [x], dim=1)
505
+ x = x + pos_embed
506
+
507
+ return self.pos_drop(x)
508
+
509
+ def _intermediate_layers(
510
+ self,
511
+ x: torch.Tensor,
512
+ n: Union[int, Sequence] = 1,
513
+ ) -> List[torch.Tensor]:
514
+ outputs, num_blocks = [], len(self.blocks)
515
+ take_indices = set(
516
+ range(num_blocks - n, num_blocks) if isinstance(n, int) else n
517
+ )
518
+
519
+ # forward pass
520
+ x = self.patch_embed(x)
521
+ x = self._pos_embed(x)
522
+ x = self.patch_drop(x)
523
+ x = self.norm_pre(x)
524
+ for i, blk in enumerate(self.blocks):
525
+ x = blk(x)
526
+ if i in take_indices:
527
+ outputs.append(x)
528
+
529
+ return outputs
530
+
531
+ def get_intermediate_layers(
532
+ self,
533
+ x: torch.Tensor,
534
+ n: Union[int, Sequence] = 1,
535
+ reshape: bool = False,
536
+ return_prefix_tokens: bool = False,
537
+ norm: bool = False,
538
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
539
+ """Intermediate layer accessor (NOTE: This is a WIP experiment).
540
+ Inspired by DINO / DINOv2 interface
541
+ """
542
+ # take last n blocks if n is an int, if in is a sequence, select by matching indices
543
+ outputs = self._intermediate_layers(x, n)
544
+ if norm:
545
+ outputs = [self.norm(out) for out in outputs]
546
+ prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs]
547
+ outputs = [out[:, self.num_prefix_tokens :] for out in outputs]
548
+
549
+ if reshape:
550
+ grid_size = self.patch_embed.grid_size
551
+ outputs = [
552
+ out.reshape(x.shape[0], grid_size[0], grid_size[1], -1)
553
+ .permute(0, 3, 1, 2)
554
+ .contiguous()
555
+ for out in outputs
556
+ ]
557
+
558
+ if return_prefix_tokens:
559
+ return tuple(zip(outputs, prefix_tokens))
560
+ return tuple(outputs)
561
+
562
+ def forward_features_list(self, x_list):
563
+ x_all = []
564
+ image_sizes = []
565
+ for x in x_list:
566
+ bs, _, h, w = x.shape
567
+
568
+ # fix patch size=14 in datasets
569
+ pad_h = (self.patch_embed.patch_size[0] - h % self.patch_embed.patch_size[0]) % self.patch_embed.patch_size[0]
570
+ pad_w = (self.patch_embed.patch_size[1] - w % self.patch_embed.patch_size[1]) % self.patch_embed.patch_size[1]
571
+ x = F.pad(x, (0, pad_w, 0, pad_h))
572
+
573
+ bs, _, h, w = x.shape
574
+
575
+ h = h // self.patch_embed.patch_size[0]
576
+ w = w // self.patch_embed.patch_size[1]
577
+
578
+ x = self.patch_embed(x)
579
+ x = x + self.rescale_positional_embedding(out_size=(h, w))
580
+ x = self.patch_drop(x)
581
+ x = self.norm_pre(x)
582
+ x_all.append(x)
583
+ image_sizes.append((h, w))
584
+
585
+ slen = [xi.size(1) for xi in x_all]
586
+ x = torch.cat(x_all, dim=1)
587
+
588
+ cu_indices = [0, ]
589
+ for i in slen:
590
+ cu_indices.append(cu_indices[-1] + i)
591
+
592
+ cu_slens = torch.tensor(cu_indices, dtype=torch.int32).to(x.device)
593
+ for idx, blk in enumerate(self.blocks):
594
+ if self.grad_checkpointing and not torch.jit.is_scripting():
595
+ x = checkpoint(blk, x, cu_slens, use_reentrant=True)
596
+ else:
597
+ x = blk(x, cu_slens=cu_slens)
598
+ feats = x.split(slen, dim=1) #[(1, slen, c)]
599
+ return feats, image_sizes
600
+
601
+ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
602
+ bs, _, h, w = x.shape
603
+ h = h // self.patch_embed.patch_size[0]
604
+ w = w // self.patch_embed.patch_size[1]
605
+
606
+ x = self.patch_embed(x)
607
+ # x = self._pos_embed(x)
608
+ x = x + self.rescale_positional_embedding(out_size=(h, w))
609
+ x = self.patch_drop(x)
610
+ x = self.norm_pre(x)
611
+ if self.grad_checkpointing and not torch.jit.is_scripting():
612
+ x = checkpoint_seq(self.blocks, x)
613
+ else:
614
+ x = self.blocks(x)
615
+ return x, (h, w)
616
+
617
+ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
618
+ x = self.norm(x)
619
+ if self.attn_pool is not None:
620
+ x = self.attn_pool(x)
621
+ elif self.global_pool == "avg":
622
+ x = x[:, self.num_prefix_tokens :].mean(dim=1)
623
+ elif self.global_pool:
624
+ x = x[:, 0] # class token
625
+ x = self.fc_norm(x)
626
+ x = self.head_drop(x)
627
+ return x if pre_logits else self.head(x)
628
+
629
+ def forward(self, x, cal_attn_pool=False):
630
+ if type(x) is list:
631
+ x, image_sizes = self.forward_features_list(x)
632
+ return x, image_sizes, None
633
+ else:
634
+ x, image_sizes = self.forward_features(x)
635
+ return x, image_sizes, None
636
+
637
+ @dataclass
638
+ class SigLIPVisionCfg:
639
+ width: int = 1152
640
+ layers: Union[Tuple[int, int, int, int], int] = 27
641
+ heads: int = 16
642
+ patch_size: int = 14
643
+ image_size: Union[Tuple[int, int], int] = 336
644
+ global_pool: str = "map"
645
+ mlp_ratio: float = 3.7362
646
+ class_token: bool = False
647
+ num_classes: int = 0
648
+ use_checkpoint: bool = False
649
+
650
+
651
+ SigLIP_MODEL_CONFIG = {
652
+ "siglip_so400m_patch14_384": {
653
+ "image_size": 384,
654
+ "patch_size": 14,
655
+ "width": 1152,
656
+ "layers": 27,
657
+ "heads": 16,
658
+ "mlp_ratio": 3.7362,
659
+ "global_pool": "map",
660
+ "use_checkpoint": False,
661
+ },
662
+ "siglip_so400m_patch16_384": {
663
+ "image_size": 384,
664
+ "patch_size": 16,
665
+ "width": 1152,
666
+ "layers": 27,
667
+ "heads": 16,
668
+ "mlp_ratio": 3.7362,
669
+ "global_pool": "map",
670
+ "use_checkpoint": False,
671
+ },
672
+ "siglip_so400m_patch14_224": {
673
+ "image_size": 224,
674
+ "patch_size": 14,
675
+ "width": 1152,
676
+ "layers": 27,
677
+ "heads": 16,
678
+ "mlp_ratio": 3.7362,
679
+ "global_pool": "map",
680
+ "use_checkpoint": False,
681
+ },
682
+ "siglip_large_patch16_384": {
683
+ "image_size": 384,
684
+ "patch_size": 16,
685
+ "width": 1024,
686
+ "layers": 24,
687
+ "heads": 16,
688
+ "mlp_ratio": 4,
689
+ "global_pool": "map",
690
+ "use_checkpoint": False,
691
+ },
692
+ }
693
+
694
+ def resize_evaclip_pos_embed(model: VisionTransformer, interpolation: str = 'bicubic'):
695
+ # interpolate position embedding
696
+ orig_size = 24
697
+ new_size = 128
698
+ pos_tokens = model.pos_embed
699
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, model.embed_dim).permute(0, 3, 1, 2)
700
+ pos_tokens = torch.nn.functional.interpolate(
701
+ pos_tokens, size=(new_size, new_size), mode=interpolation, align_corners=False)
702
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
703
+ model.pos_embed = nn.Parameter(pos_tokens, requires_grad=True)
704
+ return model
705
+
706
+ def create_siglip_vit(
707
+ model_name: str = "siglip_so400m_patch14_384",
708
+ image_size: int = 384,
709
+ select_layer: int = -1,
710
+ path: str = "",
711
+ gradient_checkpointing: bool = False,
712
+ **kwargs,
713
+ ):
714
+ assert (
715
+ model_name in SigLIP_MODEL_CONFIG.keys()
716
+ ), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"
717
+
718
+ vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])
719
+
720
+ if select_layer <= 0:
721
+ layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
722
+ else:
723
+ layers = min(vision_cfg.layers, select_layer)
724
+
725
+ model = VisionTransformer(
726
+ img_size=2048,
727
+ patch_size=16,
728
+ embed_dim=vision_cfg.width,
729
+ depth=layers,
730
+ num_heads=vision_cfg.heads,
731
+ mlp_ratio=vision_cfg.mlp_ratio,
732
+ class_token=vision_cfg.class_token,
733
+ global_pool=vision_cfg.global_pool,
734
+ dynamic_img_pad=False,
735
+ strict_img_size=False,
736
+ ignore_head=kwargs.get("ignore_head", False),
737
+ weight_init=kwargs.get("weight_init", "skip"),
738
+ num_classes=0
739
+ )
740
+
741
+ if path is not None and os.path.exists(path):
742
+ ckpt = path
743
+ else:
744
+ raise ValueError(f"Model checkpoint not found at {path}")
745
+ # state_dict = torch.load(ckpt, map_location="cpu")
746
+ # print('loading vision backbone from', path)
747
+
748
+ # msg = model.load_state_dict(state_dict, strict=False)
749
+ # print(msg)
750
+
751
+ if gradient_checkpointing:
752
+ model.set_grad_checkpointing(True)
753
+ return model
754
+
755
+ import os
756
+
757
+ from transformers import CLIPImageProcessor
758
+ import torch.distributed as dist
759
+
760
+ class OryxViTWrapper(nn.Module):
761
+ def __init__(self, vision_tower, path, args, delay_load=False):
762
+ super().__init__()
763
+
764
+ self.is_loaded = False
765
+
766
+ self.vision_tower_name = vision_tower
767
+ self.args = args
768
+ self.path = path
769
+
770
+ self.select_layer = -1
771
+ if self.select_layer < -1: self.select_layer += 1
772
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
773
+
774
+ self.output_dim = 1152
775
+ self.vision_tower = create_siglip_vit(path=self.path, model_name='siglip_so400m_patch16_384',
776
+ gradient_checkpointing=False)
777
+ if not delay_load:
778
+ self.load_model()
779
+ elif getattr(args, "unfreeze_mm_vision_tower", False):
780
+ # TODO: better detector is needed.
781
+ print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
782
+ self.load_model()
783
+
784
+ def load_model(self, device_map=None):
785
+ if self.is_loaded:
786
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
787
+ return
788
+
789
+ self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
790
+ self.image_processor.image_mean = [0.5, 0.5, 0.5]
791
+ self.image_processor.image_std = [0.5, 0.5, 0.5]
792
+ print("Loading vision model...")
793
+
794
+ # self.vision_tower = create_siglip_vit(path=self.path, model_name='siglip_so400m_patch16_384',
795
+ # gradient_checkpointing=False)
796
+ for p in self.vision_tower.parameters():
797
+ p.requires_grad = False
798
+ self.vision_tower.eval()
799
+ self.is_loaded = True
800
+
801
+ def train(self, mode = True):
802
+ self.training = mode
803
+
804
+ if self.is_loaded:
805
+ self.vision_tower.eval()
806
+
807
+ def forward_func(self, images, force_fix_size=False, cal_attn_pool=False):
808
+ if type(images) is list:
809
+ xs = [x.to(self.dtype) for x in images]
810
+ image_features, img_size, cls_token = self.vision_tower(xs, cal_attn_pool=cal_attn_pool)
811
+ image_features = [x.to(images[0].dtype) for x in image_features]
812
+
813
+ else:
814
+ image_forward_outs, img_size, cls_token = self.vision_tower(images.to(self.dtype), cal_attn_pool=cal_attn_pool)
815
+ image_features = image_forward_outs.to(images.dtype)
816
+
817
+ return image_features, img_size, cls_token
818
+
819
+ def forward(self, images, cal_attn_pool=False):
820
+ with torch.no_grad():
821
+ image_features, img_size, cls_token = self.forward_func(images, cal_attn_pool=cal_attn_pool)
822
+ return image_features, img_size
823
+
824
+ @property
825
+ def dummy_feature(self):
826
+ return torch.zeros(1, 1152, device=self.device, dtype=self.dtype)
827
+
828
+ @property
829
+ def dtype(self):
830
+ return self.vision_tower.pos_embed.dtype
831
+
832
+ @property
833
+ def device(self):
834
+ return self.vision_tower.pos_embed.device
835
+
836
+ @property
837
+ def hidden_size(self):
838
+ return self.output_dim
839
+
840
+ @property
841
+ def config(self):
842
+ return type('OryxConfigWrapper', (), {
843
+ 'patch_size': 16,
844
+ })()
oryx/model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import re
4
+
5
+ import math
6
+
7
+ class IdentityMap(nn.Module):
8
+ def __init__(self):
9
+ super().__init__()
10
+
11
+ def forward(self, x, *args, **kwargs):
12
+ return x
13
+
14
+ @property
15
+ def config(self):
16
+ return {"mm_projector_type": 'identity'}
17
+
18
+
19
+ class SimpleResBlock(nn.Module):
20
+ def __init__(self, channels):
21
+ super().__init__()
22
+ self.pre_norm = nn.LayerNorm(channels)
23
+
24
+ self.proj = nn.Sequential(
25
+ nn.Linear(channels, channels),
26
+ nn.GELU(),
27
+ nn.Linear(channels, channels)
28
+ )
29
+ def forward(self, x):
30
+ x = self.pre_norm(x)
31
+ return x + self.proj(x)
32
+
33
+ class SimpleMlp(nn.Module):
34
+ def __init__(self, in_channels, out_channels, twoview=False):
35
+ super().__init__()
36
+ self.proj = nn.Sequential(
37
+ nn.Linear(in_channels, out_channels),
38
+ nn.GELU(),
39
+ nn.Linear(out_channels, out_channels)
40
+ )
41
+
42
+ embed_std = 1 / math.sqrt(out_channels)
43
+ self.image_newline = nn.Parameter(
44
+ torch.randn(out_channels) * embed_std
45
+ )
46
+ self.image_begin = nn.Parameter(
47
+ torch.randn(out_channels) * embed_std
48
+ )
49
+ self.image_end = nn.Parameter(
50
+ torch.randn(out_channels) * embed_std
51
+ )
52
+
53
+ if twoview:
54
+ self.image_sep = nn.Parameter(
55
+ torch.randn(out_channels) * embed_std
56
+ )
57
+
58
+ def forward(self, x, size=(16,16), x2=None, size2=(16, 16), modalities='image'):
59
+
60
+ if modalities in ['image', 'text']:
61
+ h, w = size
62
+ dtype = x.dtype
63
+ x = x.reshape(x.shape[0], h, w, -1)
64
+ x = self.proj(x) #b,h,w, c
65
+ b, h, w, c = x.shape
66
+ x = torch.cat([
67
+ x,
68
+ self.image_newline.reshape(1, 1, 1, c).expand(b, h, 1, c).to(dtype)
69
+ ], dim=2)
70
+ x = x.reshape(b, -1, c)
71
+
72
+ if x2 is not None:
73
+ h2, w2 = size2
74
+ x2 = x2.reshape(x2.shape[0], h2, w2, -1)
75
+ x2 = self.proj(x2) #b,h,w, c
76
+ b2, h2, w2, c2 = x2.shape
77
+ x2 = torch.cat([
78
+ x2,
79
+ self.image_newline.reshape(1, 1, 1, c).expand(b, h2, 1, c).to(dtype)
80
+ ], dim=2)
81
+ x2 = x2.reshape(b, -1, c)
82
+ sep = self.image_sep.reshape(1, 1, -1).expand(b, 1, c2).to(dtype)
83
+ x = torch.cat([x, sep, x2], dim=1)
84
+
85
+ assert b == 1
86
+ assert b2 == 1 # only support batch size 1
87
+
88
+ begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
89
+ end = self.image_end.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
90
+ x = torch.cat([begin, x, end], dim=1)
91
+ return x
92
+ elif modalities in ['video', 'video_long']:
93
+ # x2 is the true feature, ignore x
94
+ h, w = size
95
+ dtype = x.dtype
96
+ x = x.reshape(x.shape[0], h, w, -1)
97
+ x = self.proj(x).mean() * 0.0
98
+
99
+ h2, w2 = size2
100
+ x2 = x2.reshape(x2.shape[0], h2, w2, -1)
101
+ x2 = self.proj(x2) + x #b, h, w, c
102
+
103
+ b2, h2, w2, c = x2.shape
104
+ x2 = torch.cat([
105
+ x2,
106
+ self.image_newline.reshape(1, 1, 1, c).expand(b2, h2, 1, c).to(dtype)
107
+ ], dim=2)
108
+
109
+ x2 = x2.reshape(b2, -1, c)
110
+
111
+ sep = self.image_sep.reshape(1, 1, -1).expand(b2, 1, c).to(dtype)
112
+ x2 = torch.cat([x2, sep], dim=1)
113
+
114
+ x2 = x2.flatten(0, 1)
115
+
116
+ begin = self.image_begin.reshape(1, -1).expand(1, c).to(dtype)
117
+ end = self.image_end.reshape(1, -1).expand(1, c).to(dtype)
118
+ x2 = torch.cat([begin, x2, end], dim=0)
119
+ x2 = x2.unsqueeze(0)
120
+ return x2
121
+
122
+ def build_vision_projector(config, delay_load=False, **kwargs):
123
+ projector_type = getattr(config, 'mm_projector_type', 'linear')
124
+
125
+ if projector_type == 'linear':
126
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
127
+
128
+ elif projector_type == 'simple_mlp_twoview':
129
+ return SimpleMlp(config.mm_hidden_size, config.hidden_size, twoview=True)
130
+
131
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
132
+ if mlp_gelu_match:
133
+ mlp_depth = int(mlp_gelu_match.group(1))
134
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
135
+ for _ in range(1, mlp_depth):
136
+ modules.append(nn.GELU())
137
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
138
+ return nn.Sequential(*modules)
139
+
140
+ mlp_gelu_resnet_match = re.match(r'^mlp(\d+)x_res(\d+)x_gelu$', projector_type)
141
+ if mlp_gelu_resnet_match:
142
+ mlp_depth = int(mlp_gelu_resnet_match.group(1))
143
+ res_depth = int(mlp_gelu_resnet_match.group(2))
144
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
145
+ for _ in range(1, mlp_depth):
146
+ modules.append(nn.GELU())
147
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
148
+ for _ in range(res_depth):
149
+ modules.append(SimpleResBlock(config.hidden_size))
150
+ return nn.Sequential(*modules)
151
+
152
+ if projector_type == 'identity':
153
+ return IdentityMap()
154
+
155
+ raise ValueError(f'Unknown projector type: {projector_type}')
oryx/model/multimodal_resampler/builder.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from .masked_drop import MaskedDrop
4
+ from .spatial_pool import SpatialPool
5
+ from .qformer import Qformer
6
+ from .vlm_attention import VlmAttention
7
+ from .perceiver import DynamicCompressor
8
+
9
+ class IdentityMap(torch.nn.Module):
10
+ def __init__(self):
11
+ super().__init__()
12
+
13
+ def forward(self, x, *args, **kwargs):
14
+ return x
15
+
16
+ @property
17
+ def config(self):
18
+ return {"mm_resampler_type": None}
19
+
20
+ def build_vision_resampler(model_args, delay_load=False, **kwargs):
21
+ # import pdb;pdb.set_trace()
22
+ resampler_type = getattr(model_args, 'mm_resampler_type', None)
23
+ if resampler_type == 'masked_drop':
24
+ return MaskedDrop(model_args)
25
+ elif resampler_type == 'spatial_pool':
26
+ return SpatialPool(model_args, **kwargs)
27
+ elif resampler_type == 'qformer':
28
+ return Qformer(model_args, **kwargs)
29
+ elif resampler_type == 'vlm_attention':
30
+ return VlmAttention(model_args,**kwargs)
31
+ elif resampler_type == 'dynamic_compressor':
32
+ return DynamicCompressor(model_args, **kwargs)
33
+ elif resampler_type is None:
34
+ return IdentityMap()
35
+ else:
36
+ raise ValueError(f'Unknown resampler type: {resampler_type}')
oryx/model/multimodal_resampler/masked_drop.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ import random
5
+
6
+
7
+ class MaskedDrop(nn.Module):
8
+ def __init__(self, model_args):
9
+ super().__init__()
10
+
11
+ self.mode = model_args.mm_mask_drop_mode
12
+ self.skip_percentage = model_args.mm_mask_drop_skip_percentage
13
+ self.ratio = model_args.mm_mask_drop_ratio
14
+ self.ratio_upper = model_args.mm_mask_drop_ratio_upper
15
+ self.ratio_lower = model_args.mm_mask_drop_ratio_lower
16
+
17
+ def forward(self, image_features, *args, **kwargs):
18
+
19
+ if not self.training:
20
+ return image_features
21
+
22
+ if self.skip_percentage > random.random():
23
+ return image_features
24
+
25
+ masked_features = []
26
+
27
+ for image_feature in image_features:
28
+ num_tokens = image_feature.shape[0]
29
+ if self.mode == 'fixed':
30
+ num_keep = int(num_tokens * self.ratio)
31
+ masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0][0])
32
+ elif self.mode == 'range':
33
+ num_keep = int(num_tokens * random.uniform(self.ratio_lower, self.ratio_upper))
34
+ masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0])
35
+ elif self.mode == 'cls_only':
36
+ masked_features.append(image_feature[0:1])
37
+ else:
38
+ raise ValueError(f'Unexpected masked drop mode: {self.mode}')
39
+
40
+ if self.mode not in ['range'] and \
41
+ (type(image_features) is not list or self.mode in ['cls_only']):
42
+ masked_features = torch.stack(masked_features, dim=0)
43
+
44
+ return masked_features
45
+
46
+ @property
47
+ def config(self):
48
+ return {
49
+ 'mm_resampler_type': 'masked_drop',
50
+ 'mm_mask_drop_mode': self.mode,
51
+ 'mm_mask_drop_skip_percentage': self.skip_percentage,
52
+ 'mm_mask_drop_ratio': self.ratio,
53
+ 'mm_mask_drop_ratio_upper': self.ratio_upper,
54
+ 'mm_mask_drop_ratio_lower': self.ratio_lower,
55
+ }
56
+
57
+ def random_masking(self, x, len_keep):
58
+ """
59
+ Perform per-sample random masking by per-sample shuffling.
60
+ Per-sample shuffling is done by argsort random noise.
61
+ x: [N, L, D], sequence
62
+ """
63
+ N, L, D = x.shape # batch, length, dim
64
+
65
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
66
+
67
+ # sort noise for each sample
68
+ ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
69
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
70
+
71
+ # keep the first subset
72
+ ids_keep = ids_shuffle[:, :len_keep]
73
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
74
+
75
+ # generate the binary mask: 0 is keep, 1 is remove
76
+ mask = torch.ones([N, L], device=x.device)
77
+ mask[:, :len_keep] = 0
78
+ # unshuffle to get the binary mask
79
+ mask = torch.gather(mask, dim=1, index=ids_restore)
80
+
81
+ return x_masked, mask, ids_restore
82
+
oryx/model/multimodal_resampler/perceiver.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ import torch.nn.functional as F
5
+
6
+ class DynamicCompressor(nn.Module):
7
+ def __init__(self, model_args, vision_tower):
8
+ super().__init__()
9
+
10
+ self.out_channels = vision_tower.hidden_size
11
+ self.mid_channel = 256
12
+
13
+ self.vlm_query_projector = nn.Linear(self.out_channels, self.mid_channel)
14
+ self.vlm_key_projector = nn.Linear(self.out_channels, self.mid_channel)
15
+
16
+ def downsample(self, x):
17
+ return F.avg_pool2d(x, 2, 2)
18
+
19
+ def downsample_4(self, x):
20
+ return F.avg_pool2d(x, 4, 4)
21
+
22
+ def forward(self, image_features, forward_type, image_size=None):
23
+ if image_size is None:
24
+ ori_W = int(math.sqrt(image_features.shape[1]))
25
+ ori_H = int(ori_W)
26
+ else:
27
+ ori_H, ori_W = image_size
28
+ T, N, C = image_features.shape
29
+ image_features = image_features.view(T, ori_H, ori_W, C).permute(0, 3, 1, 2) # T, C, H, W
30
+
31
+ if forward_type == 'video':
32
+ image_features_pool = self.downsample(image_features)
33
+ image_feature_attn = image_features.reshape(T, C, ori_H // 2, 2, ori_W // 2, 2).permute(0, 2, 4, 3, 5, 1).reshape(T, ori_H // 2 * ori_W // 2, 4, C)
34
+ new_image_size = (ori_H // 2, ori_W // 2)
35
+ elif forward_type == 'image' or forward_type == 'text':
36
+ image_features_pool = image_features
37
+ image_feature_attn = image_features.reshape(T, C, ori_H, 1, ori_W, 1).permute(0, 2, 4, 3, 5, 1).reshape(T, ori_H * ori_W, 1, C)
38
+ new_image_size = (ori_H, ori_W)
39
+ elif forward_type == 'video_long':
40
+ image_features_pool = self.downsample_4(image_features)
41
+ image_feature_attn = image_features.reshape(T, C, ori_H // 4, 4, ori_W // 4, 4).permute(0, 2, 4, 3, 5, 1).reshape(T, ori_H // 4 * ori_W // 4, 16, C)
42
+ new_image_size = (ori_H // 4, ori_W // 4)
43
+ else:
44
+ raise NotImplementedError
45
+
46
+ image_features_pool = image_features_pool.flatten(2).permute(0, 2, 1) # T, H*W, C
47
+ new_t, new_p, _ = image_features_pool.shape
48
+
49
+ image_query = self.vlm_query_projector(image_features_pool).reshape(new_t*new_p, self.mid_channel)
50
+ image_key = self.vlm_key_projector(image_feature_attn).reshape(new_t*new_p, -1, self.mid_channel)
51
+
52
+ image_value = image_feature_attn.reshape(new_t*new_p, -1, self.out_channels)
53
+ image_attn = image_query[:,None] @ (image_key.transpose(-1,-2) / (image_key.shape[-1]**0.5))
54
+ image_attn = image_attn.nan_to_num()
55
+ attn_feat = (image_attn.softmax(-1) @ image_value).mean(1).reshape(new_t, new_p, C)
56
+
57
+ image_features_pool = image_features_pool + attn_feat
58
+
59
+ return image_features_pool, new_image_size
60
+
61
+ @property
62
+ def config(self):
63
+ return {
64
+ 'mm_resampler_type': 'dynamic_compressor',
65
+ 'mm_out_channels': self.out_channels,
66
+ }
67
+
68
+ @property
69
+ def hidden_size(self):
70
+ return self.out_channels
oryx/model/multimodal_resampler/qformer.py ADDED
@@ -0,0 +1,1287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ * Copyright (c) 2023, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ """
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Dict, Any
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ import torch.nn.functional as F
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.file_utils import (
26
+ ModelOutput,
27
+ )
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ NextSentencePredictorOutput,
35
+ QuestionAnsweringModelOutput,
36
+ SequenceClassifierOutput,
37
+ TokenClassifierOutput,
38
+ )
39
+ from transformers.modeling_utils import (
40
+ PreTrainedModel,
41
+ apply_chunking_to_forward,
42
+ find_pruneable_heads_and_indices,
43
+ prune_linear_layer,
44
+ )
45
+ from transformers.utils import logging
46
+ from transformers.models.bert.configuration_bert import BertConfig
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+
51
+ def disabled_train(self, mode=True):
52
+ """Overwrite model.train with this function to make sure train/eval mode
53
+ does not change anymore."""
54
+ return self
55
+
56
+
57
+ class BertEmbeddings(nn.Module):
58
+ """Construct the embeddings from word and position embeddings."""
59
+
60
+ def __init__(self, config):
61
+ super().__init__()
62
+ self.word_embeddings = nn.Embedding(
63
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
64
+ )
65
+ self.position_embeddings = nn.Embedding(
66
+ config.max_position_embeddings, config.hidden_size
67
+ )
68
+
69
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
70
+ # any TensorFlow checkpoint file
71
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
72
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
73
+
74
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
75
+ self.register_buffer(
76
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
77
+ )
78
+ self.position_embedding_type = getattr(
79
+ config, "position_embedding_type", "absolute"
80
+ )
81
+
82
+ self.config = config
83
+
84
+ def forward(
85
+ self,
86
+ input_ids=None,
87
+ position_ids=None,
88
+ query_embeds=None,
89
+ past_key_values_length=0,
90
+ ):
91
+ if input_ids is not None:
92
+ seq_length = input_ids.size()[1]
93
+ else:
94
+ seq_length = 0
95
+
96
+ if position_ids is None:
97
+ position_ids = self.position_ids[
98
+ :, past_key_values_length : seq_length + past_key_values_length
99
+ ].clone()
100
+
101
+ if input_ids is not None:
102
+ embeddings = self.word_embeddings(input_ids)
103
+ if self.position_embedding_type == "absolute":
104
+ position_embeddings = self.position_embeddings(position_ids)
105
+ embeddings = embeddings + position_embeddings
106
+
107
+ if query_embeds is not None:
108
+ embeddings = torch.cat((query_embeds, embeddings), dim=1)
109
+ else:
110
+ embeddings = query_embeds
111
+
112
+ embeddings = self.LayerNorm(embeddings)
113
+ embeddings = self.dropout(embeddings)
114
+ return embeddings
115
+
116
+
117
+ class BertSelfAttention(nn.Module):
118
+ def __init__(self, config, is_cross_attention):
119
+ super().__init__()
120
+ self.config = config
121
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
122
+ config, "embedding_size"
123
+ ):
124
+ raise ValueError(
125
+ "The hidden size (%d) is not a multiple of the number of attention "
126
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
127
+ )
128
+
129
+ self.num_attention_heads = config.num_attention_heads
130
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
131
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
132
+
133
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
134
+ if is_cross_attention:
135
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
136
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
137
+ else:
138
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
139
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
140
+
141
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
142
+ self.position_embedding_type = getattr(
143
+ config, "position_embedding_type", "absolute"
144
+ )
145
+ if (
146
+ self.position_embedding_type == "relative_key"
147
+ or self.position_embedding_type == "relative_key_query"
148
+ ):
149
+ self.max_position_embeddings = config.max_position_embeddings
150
+ self.distance_embedding = nn.Embedding(
151
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
152
+ )
153
+ self.save_attention = False
154
+
155
+ def save_attn_gradients(self, attn_gradients):
156
+ self.attn_gradients = attn_gradients
157
+
158
+ def get_attn_gradients(self):
159
+ return self.attn_gradients
160
+
161
+ def save_attention_map(self, attention_map):
162
+ self.attention_map = attention_map
163
+
164
+ def get_attention_map(self):
165
+ return self.attention_map
166
+
167
+ def transpose_for_scores(self, x):
168
+ new_x_shape = x.size()[:-1] + (
169
+ self.num_attention_heads,
170
+ self.attention_head_size,
171
+ )
172
+ x = x.view(*new_x_shape)
173
+ return x.permute(0, 2, 1, 3)
174
+
175
+ def forward(
176
+ self,
177
+ hidden_states,
178
+ attention_mask=None,
179
+ head_mask=None,
180
+ encoder_hidden_states=None,
181
+ encoder_attention_mask=None,
182
+ past_key_value=None,
183
+ output_attentions=False,
184
+ ):
185
+
186
+ # If this is instantiated as a cross-attention module, the keys
187
+ # and values come from an encoder; the attention mask needs to be
188
+ # such that the encoder's padding tokens are not attended to.
189
+ is_cross_attention = encoder_hidden_states is not None
190
+
191
+ if is_cross_attention:
192
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
193
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
194
+ attention_mask = encoder_attention_mask
195
+ elif past_key_value is not None:
196
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
197
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
198
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
199
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
200
+ else:
201
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
202
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
203
+
204
+ mixed_query_layer = self.query(hidden_states)
205
+
206
+ query_layer = self.transpose_for_scores(mixed_query_layer)
207
+
208
+ past_key_value = (key_layer, value_layer)
209
+
210
+ # Take the dot product between "query" and "key" to get the raw attention scores.
211
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
212
+
213
+ if (
214
+ self.position_embedding_type == "relative_key"
215
+ or self.position_embedding_type == "relative_key_query"
216
+ ):
217
+ seq_length = hidden_states.size()[1]
218
+ position_ids_l = torch.arange(
219
+ seq_length, dtype=torch.long, device=hidden_states.device
220
+ ).view(-1, 1)
221
+ position_ids_r = torch.arange(
222
+ seq_length, dtype=torch.long, device=hidden_states.device
223
+ ).view(1, -1)
224
+ distance = position_ids_l - position_ids_r
225
+ positional_embedding = self.distance_embedding(
226
+ distance + self.max_position_embeddings - 1
227
+ )
228
+ positional_embedding = positional_embedding.to(
229
+ dtype=query_layer.dtype
230
+ ) # fp16 compatibility
231
+
232
+ if self.position_embedding_type == "relative_key":
233
+ relative_position_scores = torch.einsum(
234
+ "bhld,lrd->bhlr", query_layer, positional_embedding
235
+ )
236
+ attention_scores = attention_scores + relative_position_scores
237
+ elif self.position_embedding_type == "relative_key_query":
238
+ relative_position_scores_query = torch.einsum(
239
+ "bhld,lrd->bhlr", query_layer, positional_embedding
240
+ )
241
+ relative_position_scores_key = torch.einsum(
242
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
243
+ )
244
+ attention_scores = (
245
+ attention_scores
246
+ + relative_position_scores_query
247
+ + relative_position_scores_key
248
+ )
249
+
250
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
251
+ if attention_mask is not None:
252
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
253
+ attention_scores = attention_scores + attention_mask
254
+
255
+ # Normalize the attention scores to probabilities.
256
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
257
+
258
+ if is_cross_attention and self.save_attention:
259
+ self.save_attention_map(attention_probs)
260
+ attention_probs.register_hook(self.save_attn_gradients)
261
+
262
+ # This is actually dropping out entire tokens to attend to, which might
263
+ # seem a bit unusual, but is taken from the original Transformer paper.
264
+ attention_probs_dropped = self.dropout(attention_probs)
265
+
266
+ # Mask heads if we want to
267
+ if head_mask is not None:
268
+ attention_probs_dropped = attention_probs_dropped * head_mask
269
+
270
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
271
+
272
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
273
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
274
+ context_layer = context_layer.view(*new_context_layer_shape)
275
+
276
+ outputs = (
277
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
278
+ )
279
+
280
+ outputs = outputs + (past_key_value,)
281
+ return outputs
282
+
283
+
284
+ class BertSelfOutput(nn.Module):
285
+ def __init__(self, config):
286
+ super().__init__()
287
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
288
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
289
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
290
+
291
+ def forward(self, hidden_states, input_tensor):
292
+ hidden_states = self.dense(hidden_states)
293
+ hidden_states = self.dropout(hidden_states)
294
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
295
+ return hidden_states
296
+
297
+
298
+ class BertAttention(nn.Module):
299
+ def __init__(self, config, is_cross_attention=False):
300
+ super().__init__()
301
+ self.self = BertSelfAttention(config, is_cross_attention)
302
+ self.output = BertSelfOutput(config)
303
+ self.pruned_heads = set()
304
+
305
+ def prune_heads(self, heads):
306
+ if len(heads) == 0:
307
+ return
308
+ heads, index = find_pruneable_heads_and_indices(
309
+ heads,
310
+ self.self.num_attention_heads,
311
+ self.self.attention_head_size,
312
+ self.pruned_heads,
313
+ )
314
+
315
+ # Prune linear layers
316
+ self.self.query = prune_linear_layer(self.self.query, index)
317
+ self.self.key = prune_linear_layer(self.self.key, index)
318
+ self.self.value = prune_linear_layer(self.self.value, index)
319
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
320
+
321
+ # Update hyper params and store pruned heads
322
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
323
+ self.self.all_head_size = (
324
+ self.self.attention_head_size * self.self.num_attention_heads
325
+ )
326
+ self.pruned_heads = self.pruned_heads.union(heads)
327
+
328
+ def forward(
329
+ self,
330
+ hidden_states,
331
+ attention_mask=None,
332
+ head_mask=None,
333
+ encoder_hidden_states=None,
334
+ encoder_attention_mask=None,
335
+ past_key_value=None,
336
+ output_attentions=False,
337
+ ):
338
+ self_outputs = self.self(
339
+ hidden_states,
340
+ attention_mask,
341
+ head_mask,
342
+ encoder_hidden_states,
343
+ encoder_attention_mask,
344
+ past_key_value,
345
+ output_attentions,
346
+ )
347
+ attention_output = self.output(self_outputs[0], hidden_states)
348
+
349
+ outputs = (attention_output,) + self_outputs[
350
+ 1:
351
+ ] # add attentions if we output them
352
+ return outputs
353
+
354
+
355
+ class BertIntermediate(nn.Module):
356
+ def __init__(self, config):
357
+ super().__init__()
358
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
359
+ if isinstance(config.hidden_act, str):
360
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
361
+ else:
362
+ self.intermediate_act_fn = config.hidden_act
363
+
364
+ def forward(self, hidden_states):
365
+ hidden_states = self.dense(hidden_states)
366
+ hidden_states = self.intermediate_act_fn(hidden_states)
367
+ return hidden_states
368
+
369
+
370
+ class BertOutput(nn.Module):
371
+ def __init__(self, config):
372
+ super().__init__()
373
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
374
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
375
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
376
+
377
+ def forward(self, hidden_states, input_tensor):
378
+ hidden_states = self.dense(hidden_states)
379
+ hidden_states = self.dropout(hidden_states)
380
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
381
+ return hidden_states
382
+
383
+
384
+ class BertLayer(nn.Module):
385
+ def __init__(self, config, layer_num):
386
+ super().__init__()
387
+ self.config = config
388
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
389
+ self.seq_len_dim = 1
390
+ self.attention = BertAttention(config)
391
+ self.layer_num = layer_num
392
+ if (
393
+ self.config.add_cross_attention
394
+ and layer_num % self.config.cross_attention_freq == 0
395
+ ):
396
+ self.crossattention = BertAttention(
397
+ config, is_cross_attention=self.config.add_cross_attention
398
+ )
399
+ self.has_cross_attention = True
400
+ else:
401
+ self.has_cross_attention = False
402
+ self.intermediate = BertIntermediate(config)
403
+ self.output = BertOutput(config)
404
+
405
+ self.intermediate_query = BertIntermediate(config)
406
+ self.output_query = BertOutput(config)
407
+
408
+ def forward(
409
+ self,
410
+ hidden_states,
411
+ attention_mask=None,
412
+ head_mask=None,
413
+ encoder_hidden_states=None,
414
+ encoder_attention_mask=None,
415
+ past_key_value=None,
416
+ output_attentions=False,
417
+ query_length=0,
418
+ ):
419
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
420
+ self_attn_past_key_value = (
421
+ past_key_value[:2] if past_key_value is not None else None
422
+ )
423
+ self_attention_outputs = self.attention(
424
+ hidden_states,
425
+ attention_mask,
426
+ head_mask,
427
+ output_attentions=output_attentions,
428
+ past_key_value=self_attn_past_key_value,
429
+ )
430
+ attention_output = self_attention_outputs[0]
431
+ outputs = self_attention_outputs[1:-1]
432
+
433
+ present_key_value = self_attention_outputs[-1]
434
+
435
+ if query_length > 0:
436
+ query_attention_output = attention_output[:, :query_length, :]
437
+
438
+ if self.has_cross_attention:
439
+ assert (
440
+ encoder_hidden_states is not None
441
+ ), "encoder_hidden_states must be given for cross-attention layers"
442
+ cross_attention_outputs = self.crossattention(
443
+ query_attention_output,
444
+ attention_mask,
445
+ head_mask,
446
+ encoder_hidden_states,
447
+ encoder_attention_mask,
448
+ output_attentions=output_attentions,
449
+ )
450
+ query_attention_output = cross_attention_outputs[0]
451
+ outputs = (
452
+ outputs + cross_attention_outputs[1:-1]
453
+ ) # add cross attentions if we output attention weights
454
+
455
+ layer_output = apply_chunking_to_forward(
456
+ self.feed_forward_chunk_query,
457
+ self.chunk_size_feed_forward,
458
+ self.seq_len_dim,
459
+ query_attention_output,
460
+ )
461
+ if attention_output.shape[1] > query_length:
462
+ layer_output_text = apply_chunking_to_forward(
463
+ self.feed_forward_chunk,
464
+ self.chunk_size_feed_forward,
465
+ self.seq_len_dim,
466
+ attention_output[:, query_length:, :],
467
+ )
468
+ layer_output = torch.cat([layer_output, layer_output_text], dim=1)
469
+ else:
470
+ layer_output = apply_chunking_to_forward(
471
+ self.feed_forward_chunk,
472
+ self.chunk_size_feed_forward,
473
+ self.seq_len_dim,
474
+ attention_output,
475
+ )
476
+ outputs = (layer_output,) + outputs
477
+
478
+ outputs = outputs + (present_key_value,)
479
+
480
+ return outputs
481
+
482
+ def feed_forward_chunk(self, attention_output):
483
+ intermediate_output = self.intermediate(attention_output)
484
+ layer_output = self.output(intermediate_output, attention_output)
485
+ return layer_output
486
+
487
+ def feed_forward_chunk_query(self, attention_output):
488
+ intermediate_output = self.intermediate_query(attention_output)
489
+ layer_output = self.output_query(intermediate_output, attention_output)
490
+ return layer_output
491
+
492
+
493
+ class BertEncoder(nn.Module):
494
+ def __init__(self, config):
495
+ super().__init__()
496
+ self.config = config
497
+ self.layer = nn.ModuleList(
498
+ [BertLayer(config, i) for i in range(config.num_hidden_layers)]
499
+ )
500
+
501
+ def forward(
502
+ self,
503
+ hidden_states,
504
+ attention_mask=None,
505
+ head_mask=None,
506
+ encoder_hidden_states=None,
507
+ encoder_attention_mask=None,
508
+ past_key_values=None,
509
+ use_cache=None,
510
+ output_attentions=False,
511
+ output_hidden_states=False,
512
+ return_dict=True,
513
+ query_length=0,
514
+ ):
515
+ all_hidden_states = () if output_hidden_states else None
516
+ all_self_attentions = () if output_attentions else None
517
+ all_cross_attentions = (
518
+ () if output_attentions and self.config.add_cross_attention else None
519
+ )
520
+
521
+ next_decoder_cache = () if use_cache else None
522
+
523
+ for i in range(self.config.num_hidden_layers):
524
+ layer_module = self.layer[i]
525
+ if output_hidden_states:
526
+ all_hidden_states = all_hidden_states + (hidden_states,)
527
+
528
+ layer_head_mask = head_mask[i] if head_mask is not None else None
529
+ past_key_value = past_key_values[i] if past_key_values is not None else None
530
+
531
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
532
+
533
+ if use_cache:
534
+ logger.warn(
535
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
536
+ )
537
+ use_cache = False
538
+
539
+ def create_custom_forward(module):
540
+ def custom_forward(*inputs):
541
+ return module(
542
+ *inputs, past_key_value, output_attentions, query_length
543
+ )
544
+
545
+ return custom_forward
546
+
547
+ layer_outputs = torch.utils.checkpoint.checkpoint(
548
+ create_custom_forward(layer_module),
549
+ hidden_states,
550
+ attention_mask,
551
+ layer_head_mask,
552
+ encoder_hidden_states,
553
+ encoder_attention_mask,
554
+ )
555
+ else:
556
+ layer_outputs = layer_module(
557
+ hidden_states,
558
+ attention_mask,
559
+ layer_head_mask,
560
+ encoder_hidden_states,
561
+ encoder_attention_mask,
562
+ past_key_value,
563
+ output_attentions,
564
+ query_length,
565
+ )
566
+
567
+ hidden_states = layer_outputs[0]
568
+ if use_cache:
569
+ next_decoder_cache += (layer_outputs[-1],)
570
+ if output_attentions:
571
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
572
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
573
+
574
+ if output_hidden_states:
575
+ all_hidden_states = all_hidden_states + (hidden_states,)
576
+
577
+ if not return_dict:
578
+ return tuple(
579
+ v
580
+ for v in [
581
+ hidden_states,
582
+ next_decoder_cache,
583
+ all_hidden_states,
584
+ all_self_attentions,
585
+ all_cross_attentions,
586
+ ]
587
+ if v is not None
588
+ )
589
+ return BaseModelOutputWithPastAndCrossAttentions(
590
+ last_hidden_state=hidden_states,
591
+ past_key_values=next_decoder_cache,
592
+ hidden_states=all_hidden_states,
593
+ attentions=all_self_attentions,
594
+ cross_attentions=all_cross_attentions,
595
+ )
596
+
597
+
598
+ class BertPooler(nn.Module):
599
+ def __init__(self, config):
600
+ super().__init__()
601
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
602
+ self.activation = nn.Tanh()
603
+
604
+ def forward(self, hidden_states):
605
+ # We "pool" the model by simply taking the hidden state corresponding
606
+ # to the first token.
607
+ first_token_tensor = hidden_states[:, 0]
608
+ pooled_output = self.dense(first_token_tensor)
609
+ pooled_output = self.activation(pooled_output)
610
+ return pooled_output
611
+
612
+
613
+ class BertPredictionHeadTransform(nn.Module):
614
+ def __init__(self, config):
615
+ super().__init__()
616
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
617
+ if isinstance(config.hidden_act, str):
618
+ self.transform_act_fn = ACT2FN[config.hidden_act]
619
+ else:
620
+ self.transform_act_fn = config.hidden_act
621
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
622
+
623
+ def forward(self, hidden_states):
624
+ hidden_states = self.dense(hidden_states)
625
+ hidden_states = self.transform_act_fn(hidden_states)
626
+ hidden_states = self.LayerNorm(hidden_states)
627
+ return hidden_states
628
+
629
+
630
+ class BertLMPredictionHead(nn.Module):
631
+ def __init__(self, config):
632
+ super().__init__()
633
+ self.transform = BertPredictionHeadTransform(config)
634
+
635
+ # The output weights are the same as the input embeddings, but there is
636
+ # an output-only bias for each token.
637
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
638
+
639
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
640
+
641
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
642
+ self.decoder.bias = self.bias
643
+
644
+ def forward(self, hidden_states):
645
+ hidden_states = self.transform(hidden_states)
646
+ hidden_states = self.decoder(hidden_states)
647
+ return hidden_states
648
+
649
+
650
+ class BertOnlyMLMHead(nn.Module):
651
+ def __init__(self, config):
652
+ super().__init__()
653
+ self.predictions = BertLMPredictionHead(config)
654
+
655
+ def forward(self, sequence_output):
656
+ prediction_scores = self.predictions(sequence_output)
657
+ return prediction_scores
658
+
659
+
660
+ class BertPreTrainedModel(PreTrainedModel):
661
+ """
662
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
663
+ models.
664
+ """
665
+
666
+ config_class = BertConfig
667
+ base_model_prefix = "bert"
668
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
669
+
670
+ def _init_weights(self, module):
671
+ """Initialize the weights"""
672
+ if isinstance(module, (nn.Linear, nn.Embedding)):
673
+ # Slightly different from the TF version which uses truncated_normal for initialization
674
+ # cf https://github.com/pytorch/pytorch/pull/5617
675
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
676
+ elif isinstance(module, nn.LayerNorm):
677
+ module.bias.data.zero_()
678
+ module.weight.data.fill_(1.0)
679
+ if isinstance(module, nn.Linear) and module.bias is not None:
680
+ module.bias.data.zero_()
681
+
682
+
683
+ class BertModel(BertPreTrainedModel):
684
+ """
685
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
686
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
687
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
688
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
689
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
690
+ input to the forward pass.
691
+ """
692
+
693
+ def __init__(self, config, add_pooling_layer=False):
694
+ super().__init__(config)
695
+ self.config = config
696
+
697
+ self.embeddings = BertEmbeddings(config)
698
+
699
+ self.encoder = BertEncoder(config)
700
+
701
+ self.pooler = BertPooler(config) if add_pooling_layer else None
702
+
703
+ self.init_weights()
704
+
705
+ def get_input_embeddings(self):
706
+ return self.embeddings.word_embeddings
707
+
708
+ def set_input_embeddings(self, value):
709
+ self.embeddings.word_embeddings = value
710
+
711
+ def _prune_heads(self, heads_to_prune):
712
+ """
713
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
714
+ class PreTrainedModel
715
+ """
716
+ for layer, heads in heads_to_prune.items():
717
+ self.encoder.layer[layer].attention.prune_heads(heads)
718
+
719
+ def get_extended_attention_mask(
720
+ self,
721
+ attention_mask: Tensor,
722
+ input_shape: Tuple[int],
723
+ device: device,
724
+ is_decoder: bool,
725
+ has_query: bool = False,
726
+ ) -> Tensor:
727
+ """
728
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
729
+
730
+ Arguments:
731
+ attention_mask (:obj:`torch.Tensor`):
732
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
733
+ input_shape (:obj:`Tuple[int]`):
734
+ The shape of the input to the model.
735
+ device: (:obj:`torch.device`):
736
+ The device of the input to the model.
737
+
738
+ Returns:
739
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
740
+ """
741
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
742
+ # ourselves in which case we just need to make it broadcastable to all heads.
743
+ if attention_mask.dim() == 3:
744
+ extended_attention_mask = attention_mask[:, None, :, :]
745
+ elif attention_mask.dim() == 2:
746
+ # Provided a padding mask of dimensions [batch_size, seq_length]
747
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
748
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
749
+ if is_decoder:
750
+ batch_size, seq_length = input_shape
751
+
752
+ seq_ids = torch.arange(seq_length, device=device)
753
+ causal_mask = (
754
+ seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
755
+ <= seq_ids[None, :, None]
756
+ )
757
+
758
+ # add a prefix ones mask to the causal mask
759
+ # causal and attention masks must have same type with pytorch version < 1.3
760
+ causal_mask = causal_mask.to(attention_mask.dtype)
761
+
762
+ if causal_mask.shape[1] < attention_mask.shape[1]:
763
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
764
+ if has_query: # UniLM style attention mask
765
+ causal_mask = torch.cat(
766
+ [
767
+ torch.zeros(
768
+ (batch_size, prefix_seq_len, seq_length),
769
+ device=device,
770
+ dtype=causal_mask.dtype,
771
+ ),
772
+ causal_mask,
773
+ ],
774
+ axis=1,
775
+ )
776
+ causal_mask = torch.cat(
777
+ [
778
+ torch.ones(
779
+ (batch_size, causal_mask.shape[1], prefix_seq_len),
780
+ device=device,
781
+ dtype=causal_mask.dtype,
782
+ ),
783
+ causal_mask,
784
+ ],
785
+ axis=-1,
786
+ )
787
+ extended_attention_mask = (
788
+ causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
789
+ )
790
+ else:
791
+ extended_attention_mask = attention_mask[:, None, None, :]
792
+ else:
793
+ raise ValueError(
794
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
795
+ input_shape, attention_mask.shape
796
+ )
797
+ )
798
+
799
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
800
+ # masked positions, this operation will create a tensor which is 0.0 for
801
+ # positions we want to attend and -10000.0 for masked positions.
802
+ # Since we are adding it to the raw scores before the softmax, this is
803
+ # effectively the same as removing these entirely.
804
+ extended_attention_mask = extended_attention_mask.to(
805
+ dtype=self.dtype
806
+ ) # fp16 compatibility
807
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
808
+ return extended_attention_mask
809
+
810
+ def forward(
811
+ self,
812
+ input_ids=None,
813
+ attention_mask=None,
814
+ position_ids=None,
815
+ head_mask=None,
816
+ query_embeds=None,
817
+ encoder_hidden_states=None,
818
+ encoder_attention_mask=None,
819
+ past_key_values=None,
820
+ use_cache=None,
821
+ output_attentions=None,
822
+ output_hidden_states=None,
823
+ return_dict=None,
824
+ is_decoder=False,
825
+ ):
826
+ r"""
827
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
828
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
829
+ the model is configured as a decoder.
830
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
831
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
832
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
833
+ - 1 for tokens that are **not masked**,
834
+ - 0 for tokens that are **masked**.
835
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
836
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
837
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
838
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
839
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
840
+ use_cache (:obj:`bool`, `optional`):
841
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
842
+ decoding (see :obj:`past_key_values`).
843
+ """
844
+ output_attentions = (
845
+ output_attentions
846
+ if output_attentions is not None
847
+ else self.config.output_attentions
848
+ )
849
+ output_hidden_states = (
850
+ output_hidden_states
851
+ if output_hidden_states is not None
852
+ else self.config.output_hidden_states
853
+ )
854
+ return_dict = (
855
+ return_dict if return_dict is not None else self.config.use_return_dict
856
+ )
857
+
858
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
859
+
860
+ if input_ids is None:
861
+ assert (
862
+ query_embeds is not None
863
+ ), "You have to specify query_embeds when input_ids is None"
864
+
865
+ # past_key_values_length
866
+ past_key_values_length = (
867
+ past_key_values[0][0].shape[2] - self.config.query_length
868
+ if past_key_values is not None
869
+ else 0
870
+ )
871
+
872
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
873
+
874
+ embedding_output = self.embeddings(
875
+ input_ids=input_ids,
876
+ position_ids=position_ids,
877
+ query_embeds=query_embeds,
878
+ past_key_values_length=past_key_values_length,
879
+ )
880
+
881
+ input_shape = embedding_output.size()[:-1]
882
+ batch_size, seq_length = input_shape
883
+ device = embedding_output.device
884
+
885
+ if attention_mask is None:
886
+ attention_mask = torch.ones(
887
+ ((batch_size, seq_length + past_key_values_length)), device=device
888
+ )
889
+
890
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
891
+ # ourselves in which case we just need to make it broadcastable to all heads.
892
+ if is_decoder:
893
+ extended_attention_mask = self.get_extended_attention_mask(
894
+ attention_mask,
895
+ input_ids.shape,
896
+ device,
897
+ is_decoder,
898
+ has_query=(query_embeds is not None),
899
+ )
900
+ else:
901
+ extended_attention_mask = self.get_extended_attention_mask(
902
+ attention_mask, input_shape, device, is_decoder
903
+ )
904
+
905
+ # If a 2D or 3D attention mask is provided for the cross-attention
906
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
907
+ if encoder_hidden_states is not None:
908
+ if type(encoder_hidden_states) == list:
909
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
910
+ 0
911
+ ].size()
912
+ else:
913
+ (
914
+ encoder_batch_size,
915
+ encoder_sequence_length,
916
+ _,
917
+ ) = encoder_hidden_states.size()
918
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
919
+
920
+ if type(encoder_attention_mask) == list:
921
+ encoder_extended_attention_mask = [
922
+ self.invert_attention_mask(mask) for mask in encoder_attention_mask
923
+ ]
924
+ elif encoder_attention_mask is None:
925
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
926
+ encoder_extended_attention_mask = self.invert_attention_mask(
927
+ encoder_attention_mask
928
+ )
929
+ else:
930
+ encoder_extended_attention_mask = self.invert_attention_mask(
931
+ encoder_attention_mask
932
+ )
933
+ else:
934
+ encoder_extended_attention_mask = None
935
+
936
+ # Prepare head mask if needed
937
+ # 1.0 in head_mask indicate we keep the head
938
+ # attention_probs has shape bsz x n_heads x N x N
939
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
940
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
941
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
942
+
943
+ encoder_outputs = self.encoder(
944
+ embedding_output,
945
+ attention_mask=extended_attention_mask,
946
+ head_mask=head_mask,
947
+ encoder_hidden_states=encoder_hidden_states,
948
+ encoder_attention_mask=encoder_extended_attention_mask,
949
+ past_key_values=past_key_values,
950
+ use_cache=use_cache,
951
+ output_attentions=output_attentions,
952
+ output_hidden_states=output_hidden_states,
953
+ return_dict=return_dict,
954
+ query_length=query_length,
955
+ )
956
+ sequence_output = encoder_outputs[0]
957
+ pooled_output = (
958
+ self.pooler(sequence_output) if self.pooler is not None else None
959
+ )
960
+
961
+ if not return_dict:
962
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
963
+
964
+ return BaseModelOutputWithPoolingAndCrossAttentions(
965
+ last_hidden_state=sequence_output,
966
+ pooler_output=pooled_output,
967
+ past_key_values=encoder_outputs.past_key_values,
968
+ hidden_states=encoder_outputs.hidden_states,
969
+ attentions=encoder_outputs.attentions,
970
+ cross_attentions=encoder_outputs.cross_attentions,
971
+ )
972
+
973
+
974
+ class BertLMHeadModel(BertPreTrainedModel):
975
+
976
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
977
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
978
+
979
+ def __init__(self, config):
980
+ super().__init__(config)
981
+
982
+ self.bert = BertModel(config, add_pooling_layer=False)
983
+ self.cls = BertOnlyMLMHead(config)
984
+
985
+ self.init_weights()
986
+
987
+ def get_output_embeddings(self):
988
+ return self.cls.predictions.decoder
989
+
990
+ def set_output_embeddings(self, new_embeddings):
991
+ self.cls.predictions.decoder = new_embeddings
992
+
993
+ def forward(
994
+ self,
995
+ input_ids=None,
996
+ attention_mask=None,
997
+ position_ids=None,
998
+ head_mask=None,
999
+ query_embeds=None,
1000
+ encoder_hidden_states=None,
1001
+ encoder_attention_mask=None,
1002
+ labels=None,
1003
+ past_key_values=None,
1004
+ use_cache=True,
1005
+ output_attentions=None,
1006
+ output_hidden_states=None,
1007
+ return_dict=None,
1008
+ return_logits=False,
1009
+ is_decoder=True,
1010
+ reduction="mean",
1011
+ ):
1012
+ r"""
1013
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
1014
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1015
+ the model is configured as a decoder.
1016
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1017
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1018
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
1019
+ - 1 for tokens that are **not masked**,
1020
+ - 0 for tokens that are **masked**.
1021
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1022
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1023
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
1024
+ ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
1025
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1026
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1027
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
1028
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
1029
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
1030
+ use_cache (:obj:`bool`, `optional`):
1031
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
1032
+ decoding (see :obj:`past_key_values`).
1033
+ Returns:
1034
+ Example::
1035
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
1036
+ >>> import torch
1037
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
1038
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
1039
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
1040
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1041
+ >>> outputs = model(**inputs)
1042
+ >>> prediction_logits = outputs.logits
1043
+ """
1044
+ return_dict = (
1045
+ return_dict if return_dict is not None else self.config.use_return_dict
1046
+ )
1047
+ if labels is not None:
1048
+ use_cache = False
1049
+ if past_key_values is not None:
1050
+ query_embeds = None
1051
+
1052
+ outputs = self.bert(
1053
+ input_ids,
1054
+ attention_mask=attention_mask,
1055
+ position_ids=position_ids,
1056
+ head_mask=head_mask,
1057
+ query_embeds=query_embeds,
1058
+ encoder_hidden_states=encoder_hidden_states,
1059
+ encoder_attention_mask=encoder_attention_mask,
1060
+ past_key_values=past_key_values,
1061
+ use_cache=use_cache,
1062
+ output_attentions=output_attentions,
1063
+ output_hidden_states=output_hidden_states,
1064
+ return_dict=return_dict,
1065
+ is_decoder=is_decoder,
1066
+ )
1067
+
1068
+ sequence_output = outputs[0]
1069
+ if query_embeds is not None:
1070
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1071
+
1072
+ prediction_scores = self.cls(sequence_output)
1073
+
1074
+ if return_logits:
1075
+ return prediction_scores[:, :-1, :].contiguous()
1076
+
1077
+ lm_loss = None
1078
+ if labels is not None:
1079
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1080
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1081
+ labels = labels[:, 1:].contiguous()
1082
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
1083
+ lm_loss = loss_fct(
1084
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
1085
+ labels.view(-1),
1086
+ )
1087
+ if reduction == "none":
1088
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
1089
+
1090
+ if not return_dict:
1091
+ output = (prediction_scores,) + outputs[2:]
1092
+ return ((lm_loss,) + output) if lm_loss is not None else output
1093
+
1094
+ return CausalLMOutputWithCrossAttentions(
1095
+ loss=lm_loss,
1096
+ logits=prediction_scores,
1097
+ past_key_values=outputs.past_key_values,
1098
+ hidden_states=outputs.hidden_states,
1099
+ attentions=outputs.attentions,
1100
+ cross_attentions=outputs.cross_attentions,
1101
+ )
1102
+
1103
+ def prepare_inputs_for_generation(
1104
+ self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
1105
+ ):
1106
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1107
+ if attention_mask is None:
1108
+ attention_mask = input_ids.new_ones(input_ids.shape)
1109
+ query_mask = input_ids.new_ones(query_embeds.shape[:-1])
1110
+ attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
1111
+
1112
+ # cut decoder_input_ids if past is used
1113
+ if past is not None:
1114
+ input_ids = input_ids[:, -1:]
1115
+
1116
+ return {
1117
+ "input_ids": input_ids,
1118
+ "query_embeds": query_embeds,
1119
+ "attention_mask": attention_mask,
1120
+ "past_key_values": past,
1121
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1122
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1123
+ "is_decoder": True,
1124
+ }
1125
+
1126
+ def _reorder_cache(self, past, beam_idx):
1127
+ reordered_past = ()
1128
+ for layer_past in past:
1129
+ reordered_past += (
1130
+ tuple(
1131
+ past_state.index_select(0, beam_idx) for past_state in layer_past
1132
+ ),
1133
+ )
1134
+ return reordered_past
1135
+
1136
+
1137
+ class BertForMaskedLM(BertPreTrainedModel):
1138
+
1139
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1140
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1141
+
1142
+ def __init__(self, config):
1143
+ super().__init__(config)
1144
+
1145
+ self.bert = BertModel(config, add_pooling_layer=False)
1146
+ self.cls = BertOnlyMLMHead(config)
1147
+
1148
+ self.init_weights()
1149
+
1150
+ def get_output_embeddings(self):
1151
+ return self.cls.predictions.decoder
1152
+
1153
+ def set_output_embeddings(self, new_embeddings):
1154
+ self.cls.predictions.decoder = new_embeddings
1155
+
1156
+ def forward(
1157
+ self,
1158
+ input_ids=None,
1159
+ attention_mask=None,
1160
+ position_ids=None,
1161
+ head_mask=None,
1162
+ query_embeds=None,
1163
+ encoder_hidden_states=None,
1164
+ encoder_attention_mask=None,
1165
+ labels=None,
1166
+ output_attentions=None,
1167
+ output_hidden_states=None,
1168
+ return_dict=None,
1169
+ return_logits=False,
1170
+ is_decoder=False,
1171
+ ):
1172
+ r"""
1173
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1174
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1175
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1176
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1177
+ """
1178
+
1179
+ return_dict = (
1180
+ return_dict if return_dict is not None else self.config.use_return_dict
1181
+ )
1182
+
1183
+ outputs = self.bert(
1184
+ input_ids,
1185
+ attention_mask=attention_mask,
1186
+ position_ids=position_ids,
1187
+ head_mask=head_mask,
1188
+ query_embeds=query_embeds,
1189
+ encoder_hidden_states=encoder_hidden_states,
1190
+ encoder_attention_mask=encoder_attention_mask,
1191
+ output_attentions=output_attentions,
1192
+ output_hidden_states=output_hidden_states,
1193
+ return_dict=return_dict,
1194
+ is_decoder=is_decoder,
1195
+ )
1196
+
1197
+ if query_embeds is not None:
1198
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1199
+ prediction_scores = self.cls(sequence_output)
1200
+
1201
+ if return_logits:
1202
+ return prediction_scores
1203
+
1204
+ masked_lm_loss = None
1205
+ if labels is not None:
1206
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1207
+ masked_lm_loss = loss_fct(
1208
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
1209
+ )
1210
+
1211
+ if not return_dict:
1212
+ output = (prediction_scores,) + outputs[2:]
1213
+ return (
1214
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1215
+ )
1216
+
1217
+ return MaskedLMOutput(
1218
+ loss=masked_lm_loss,
1219
+ logits=prediction_scores,
1220
+ hidden_states=outputs.hidden_states,
1221
+ attentions=outputs.attentions,
1222
+ )
1223
+
1224
+
1225
+ class Qformer(nn.Module):
1226
+ def __init__(self, model_args, vision_tower):
1227
+ super().__init__()
1228
+
1229
+ self.depth = model_args.mm_qformer_depth
1230
+ self.num_latents = model_args.mm_qformer_latents
1231
+ self.pretrained = model_args.mm_qformer_pretrained
1232
+
1233
+ self.Qformer, self.query_tokens, self.ln_vision = self.build_Qformer(vision_tower.hidden_size, self.depth, self.num_latents)
1234
+
1235
+ if self.pretrained is not None:
1236
+ pretrained_dict = torch.load(self.pretrained, map_location='cpu')['model']
1237
+ pretrained_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith('t5_proj')}
1238
+ # import pdb;pdb.set_trace()
1239
+ _ = self.load_state_dict(pretrained_dict,strict=False)
1240
+ print(_)
1241
+
1242
+ def build_Qformer(self, vision_width, cross_attention_freq, num_query_token):
1243
+ encoder_config = BertConfig.from_pretrained("bert-base-uncased")
1244
+ encoder_config.encoder_width = vision_width
1245
+ # insert cross-attention layer every other block
1246
+ encoder_config.add_cross_attention = True
1247
+ encoder_config.cross_attention_freq = cross_attention_freq
1248
+ encoder_config.query_length = num_query_token
1249
+ Qformer = BertLMHeadModel(config=encoder_config)
1250
+ query_tokens = nn.Parameter(
1251
+ torch.zeros(1, num_query_token, encoder_config.hidden_size)
1252
+ )
1253
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
1254
+ Qformer.cls = None
1255
+ Qformer.bert.embeddings.word_embeddings = None
1256
+ Qformer.bert.embeddings.position_embeddings = None
1257
+ for layer in Qformer.bert.encoder.layer:
1258
+ layer.output = None
1259
+ layer.intermediate = None
1260
+ return Qformer, query_tokens, nn.LayerNorm(vision_width)
1261
+
1262
+ def forward(self, image_features, *args, **kwargs):
1263
+ x = self.ln_vision(image_features)
1264
+ image_atts = torch.ones(x.size()[:-1], dtype=torch.long).to(x.device)
1265
+
1266
+ query_tokens = self.query_tokens.expand(x.shape[0], -1, -1)
1267
+ query_output = self.Qformer.bert(
1268
+ query_embeds=query_tokens,
1269
+ encoder_hidden_states=x,
1270
+ encoder_attention_mask=image_atts,
1271
+ return_dict=True,
1272
+ )
1273
+
1274
+ return query_output.last_hidden_state
1275
+
1276
+ @property
1277
+ def hidden_size(self):
1278
+ return 768
1279
+
1280
+ @property
1281
+ def config(self):
1282
+ return {
1283
+ 'mm_resampler_type': 'qformer',
1284
+ 'mm_qformer_depth': self.depth,
1285
+ 'mm_qformer_latents': self.num_latents,
1286
+ 'mm_qformer_pretrained': self.pretrained,
1287
+ }
oryx/model/multimodal_resampler/spatial_pool.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
+
6
+ class SpatialPool(nn.Module):
7
+ def __init__(self, model_args, vision_tower):
8
+ super().__init__()
9
+
10
+ self.mode = model_args.mm_spatial_pool_mode
11
+ self.stride = model_args.mm_spatial_pool_stride
12
+ # import pdb; pdb.set_trace()
13
+ self.out_channels = getattr(model_args, 'mm_spatial_pool_out_channels', vision_tower.hidden_size)
14
+
15
+ if self.mode == 'average':
16
+ self.pool = nn.AvgPool2d(kernel_size=self.stride, stride=self.stride)
17
+ elif self.mode == 'max':
18
+ self.pool = nn.MaxPool2d(kernel_size=self.stride, stride=self.stride)
19
+ elif self.mode == 'conv':
20
+ self.pool = nn.Conv2d(in_channels=vision_tower.hidden_size, out_channels=self.out_channels, kernel_size=self.stride, stride=self.stride)
21
+ else:
22
+ raise ValueError(f'Unknown pooling mode: {self.pool}.')
23
+
24
+ def forward(self, image_features, images, *args, **kwargs):
25
+ ori_W = int(math.sqrt(image_features.shape[1] * images.shape[3] // images.shape[2]))
26
+ ori_H = int(ori_W * images.shape[2] // images.shape[3])
27
+
28
+ B, _, F = image_features.shape
29
+
30
+ image_features_spatial = image_features.view(B, ori_H, ori_H, F).permute(0, 3, 1, 2)
31
+ image_features_spatial_pool = self.pool(image_features_spatial)
32
+
33
+ return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()
34
+
35
+ @property
36
+ def config(self):
37
+ return {
38
+ 'mm_resampler_type': 'spatial_pool',
39
+ 'mm_spatial_pool_stride': self.stride,
40
+ 'mm_spatial_pool_mode': self.mode,
41
+ 'mm_spatial_pool_out_channels': self.out_channels,
42
+ }
oryx/model/multimodal_resampler/vlm_attention.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from transformers import BertTokenizer
8
+ from transformers.models.bert.modeling_bert import BertLMHeadModel as BertLMHeadModelRaw
9
+
10
+ from .qformer import BertConfig
11
+ from .qformer import BertLMHeadModel as BertLMHeadModelQF
12
+
13
+ class VlmAttention(nn.Module):
14
+ def __init__(self, model_args, vision_tower):
15
+ super().__init__()
16
+
17
+ pretrain_mm_mlp_adapter = getattr(model_args, "pretrain_mm_mlp_adapter", None)
18
+ pretrain_qformer = getattr(model_args, "mm_vlmattention_pretrained", None)
19
+ self.bert_type = getattr(model_args, "mm_vlmattention_bert_type", "qformer")
20
+ self.num_query = getattr(model_args, "mm_vlmattention_num_query", 32)
21
+ self.compress_type = getattr(model_args, "mm_vlmattention_compress_type", None)
22
+ self.mm_hidden_size = self.hidden_size = vision_tower.hidden_size
23
+ self.mm_vision_select_feature = model_args.mm_vision_select_feature
24
+ self.language_hidden_size = 4096
25
+ for_eval = True
26
+
27
+ if 'pretrain' in self.bert_type:
28
+ # for qformer that use evaclip for prtrain
29
+ att_feat_size = 1408
30
+ else:
31
+ att_feat_size = self.mm_hidden_size
32
+ self.vlm_att_tokenlizer, self.vlm_att_encoder, self.vlm_att_query = self.init_bert(att_feat_size, truncation_side="left")
33
+ self.vlm_att_projector = torch.nn.Linear(self.vlm_att_encoder.config.hidden_size, self.mm_hidden_size)
34
+ self.vlm_att_key_projector = torch.nn.Linear(self.mm_hidden_size, self.mm_hidden_size)
35
+ self.vlm_att_val_projector = torch.nn.Linear(self.mm_hidden_size, self.language_hidden_size)
36
+
37
+ if "raw" in self.bert_type:
38
+ self.vlm_att_bert_proj = torch.nn.Linear(att_feat_size, self.vlm_att_encoder.config.hidden_size)
39
+ elif "pretrain" in self.bert_type and self.mm_hidden_size!=att_feat_size:
40
+ self.vlm_att_bert_proj = torch.nn.Linear(self.mm_hidden_size, att_feat_size)
41
+ else:
42
+ self.vlm_att_bert_proj = None
43
+
44
+ def get_w(weights, keyword):
45
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
46
+
47
+ if 'qformer_pretrain' in self.bert_type:
48
+ self.vlm_att_ln = torch.nn.LayerNorm(att_feat_size)
49
+
50
+ if pretrain_qformer is not None:
51
+ print("Loading pretrained qformer weights...")
52
+ qformer_weight = torch.load(pretrain_qformer, map_location='cpu')['model']
53
+ bert_weight = {_key: qformer_weight[_key] for _key in qformer_weight if 'bert' in _key}
54
+ self.vlm_att_encoder.load_state_dict(get_w(bert_weight, 'Qformer'))
55
+ self.vlm_att_ln.load_state_dict(get_w(qformer_weight, 'ln_vision'))
56
+ self.vlm_att_query.data = qformer_weight['query_tokens']
57
+
58
+ if 'freeze_all' in self.bert_type:
59
+ print("Freezing all qformer weights...")
60
+ self.vlm_att_encoder.requires_grad_(False)
61
+ self.vlm_att_ln.requires_grad_(False)
62
+ self.vlm_att_query.requires_grad_(False)
63
+ self.vlm_att_projector.requires_grad_(False)
64
+ self.vlm_att_key_projector.requires_grad_(False)
65
+ self.vlm_att_val_projector.requires_grad_(False)
66
+ elif 'freeze' in self.bert_type:
67
+ print("Freezing pretrained qformer weights...")
68
+ self.vlm_att_encoder.requires_grad_(False)
69
+ self.vlm_att_ln.requires_grad_(False)
70
+ self.vlm_att_query.requires_grad_(False)
71
+
72
+
73
+ if pretrain_mm_mlp_adapter is not None:
74
+ att_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
75
+ else:
76
+ trainable_module = ['vlm_att_encoder', 'vlm_att_projector', 'vlm_att_key_projector',
77
+ 'vlm_att_val_projector', 'vlm_att_query', 'vlm_att_visual_proj',
78
+ 'vlm_att_ln']
79
+ if hasattr(model_args, 'model_name_or_path'):
80
+ model_save_path = model_args.model_name_or_path
81
+ else:
82
+ model_save_path = model_args.model_path
83
+ model_idx_path = getattr(model_args, 'model_path', model_save_path)
84
+ weight_file = json.load(open(os.path.join(model_idx_path, 'pytorch_model.bin.index.json'), 'r'))['weight_map']
85
+ model_path = set([weight_file[_key] for _key in weight_file if any([_module in _key for _module in trainable_module])])
86
+ att_projector_weights = {}
87
+ for _model in model_path:
88
+ att_projector_weights.update(torch.load(os.path.join(model_idx_path, _model), map_location='cpu'))
89
+ if len(att_projector_weights) == 0:
90
+ return
91
+
92
+ bert_dict = get_w(att_projector_weights, 'vlm_att_encoder')
93
+ if "bert.embeddings.position_ids" not in bert_dict and "raw_bert" not in self.bert_type:
94
+ bert_dict["bert.embeddings.position_ids"] = self.vlm_att_encoder.bert.embeddings.position_ids
95
+ print('Loading pretrained weights...')
96
+ # import pdb;pdb.set_trace()
97
+
98
+ self.vlm_att_encoder.load_state_dict(bert_dict)
99
+ self.vlm_att_projector.load_state_dict(get_w(att_projector_weights, 'vlm_att_projector'))
100
+ self.vlm_att_key_projector.load_state_dict(get_w(att_projector_weights, 'vlm_att_key_projector'))
101
+ self.vlm_att_val_projector.load_state_dict(get_w(att_projector_weights, 'vlm_att_val_projector'))
102
+
103
+ if "qformer" in self.bert_type:
104
+ print('Loading vlm_att_query weights...')
105
+ self.vlm_att_query.data = att_projector_weights['model.vlm_att_query']
106
+ if "pretrain" in self.bert_type:
107
+ print('Loading vlm_att_ln weights...')
108
+ self.vlm_att_ln.load_state_dict(get_w(att_projector_weights, 'vlm_att_ln'))
109
+
110
+ if self.vlm_att_bert_proj is not None:
111
+ print('Loading vlm_att_bert_proj weights...')
112
+ self.vlm_att_bert_proj.load_state_dict(get_w(att_projector_weights, 'vlm_att_bert_proj'))
113
+
114
+ if for_eval:
115
+ weight_type = torch.float16
116
+ # import pdb;pdb.set_trace()
117
+ # device_type = self.mm_projector[0].weight.device
118
+ device_type = vision_tower.vision_tower.patch_embed.proj.weight.device
119
+ self.vlm_att_encoder = self.vlm_att_encoder.to(device=device_type, dtype=weight_type)
120
+ self.vlm_att_projector = self.vlm_att_projector.to(device=device_type, dtype=weight_type)
121
+ self.vlm_att_key_projector = self.vlm_att_key_projector.to(device=device_type, dtype=weight_type)
122
+ self.vlm_att_val_projector = self.vlm_att_val_projector.to(device=device_type, dtype=weight_type)
123
+
124
+ if "qformer" in self.bert_type:
125
+ self.vlm_att_query.data = self.vlm_att_query.data.to(device=device_type, dtype=weight_type)
126
+ if "pretrain" in self.bert_type:
127
+ self.vlm_att_ln = self.vlm_att_ln.to(device=device_type, dtype=weight_type)
128
+
129
+ if self.vlm_att_bert_proj is not None:
130
+ self.vlm_att_bert_proj = self.vlm_att_bert_proj.to(device=device_type, dtype=weight_type)
131
+
132
+ def forward(self, image_features, prompts=None, image_counts=None, long_video=False):
133
+ img_feat_lst = []
134
+ # import pdb;pdb.set_trace()
135
+ if image_counts is None:
136
+ assert len(image_features) == len(prompts), f"Size mismatch! image_features: {len(image_features)}, prompts: {len(prompts)}"
137
+ else:
138
+ assert len(prompts) == len(image_counts), f"Size mismatch! prompts: {len(prompts)}, image_counts: {len(image_counts)}"
139
+ image_atts = torch.ones(image_features.size()[:-1], dtype=torch.long).to(image_features.device)
140
+
141
+ total_count = 0
142
+ # calculate each image feat according to the prompt
143
+ # import pdb;pdb.set_trace()
144
+ for _idx in range(len(prompts)):
145
+ assert isinstance(prompts[_idx], list), f"Prompt should be a list, but got {type(prompts[_idx])}"
146
+ input_token = self.vlm_att_tokenlizer(
147
+ prompts[_idx],
148
+ padding='longest',
149
+ truncation=True,
150
+ max_length=256,
151
+ return_tensors="pt"
152
+ ).to(image_features.device)
153
+
154
+ input_ids = input_token.input_ids
155
+ attention_masks = input_token.attention_mask
156
+
157
+ if image_counts is None:
158
+ img_feat_prompt = image_features[_idx, None].expand(len(prompts[_idx]), -1, -1)
159
+ img_att_prompt = image_atts[_idx, None].expand(len(prompts[_idx]), -1)
160
+ else:
161
+ # shape: [prompt_num*frame_num, image_shape, feat_dim]
162
+ img_feat_prompt = image_features[total_count:total_count+image_counts[_idx]]
163
+ img_feat_prompt = img_feat_prompt[None].expand(len(prompts[_idx]), -1, -1, -1).flatten(0,1)
164
+ img_att_prompt = image_atts[total_count:total_count+image_counts[_idx]]
165
+ img_att_prompt = img_att_prompt[None].expand(len(prompts[_idx]), -1, -1).flatten(0,1)
166
+ input_ids = input_ids[:,None].expand(-1, image_counts[_idx], -1).flatten(0,1)
167
+ attention_masks = attention_masks[:,None].expand(-1, image_counts[_idx], -1).flatten(0,1)
168
+ total_count += image_counts[_idx]
169
+
170
+ if "pretrain" in self.bert_type and self.vlm_att_bert_proj is not None:
171
+ bert_feat = self.vlm_att_bert_proj(img_feat_prompt)
172
+ else:
173
+ bert_feat = img_feat_prompt.clone()
174
+
175
+ # remove cls embedding
176
+ if self.mm_vision_select_feature == 'patch':
177
+ if img_feat_prompt.shape[1]%2 == 1:
178
+ img_feat_prompt = img_feat_prompt[:, 1:]
179
+
180
+ if "qformer" in self.bert_type:
181
+ query_tokens = self.vlm_att_query.expand(bert_feat.shape[0], -1, -1)
182
+ query_atts = torch.cat([torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(bert_feat.device),
183
+ attention_masks],dim=1)
184
+
185
+ if 'pretrain' in self.bert_type:
186
+ mm_img_in = self.vlm_att_ln(bert_feat)
187
+ else:
188
+ mm_img_in = bert_feat
189
+
190
+ if long_video:
191
+ outputs = []
192
+ block_size = 64
193
+ for L in range(0, len(input_ids), block_size):
194
+ R = L + block_size
195
+ mm_output = self.vlm_att_encoder.bert(
196
+ input_ids[L:R],
197
+ query_embeds=query_tokens[L:R],
198
+ attention_mask=query_atts[L:R],
199
+ encoder_hidden_states=mm_img_in[L:R],
200
+ encoder_attention_mask=img_att_prompt[L:R],
201
+ return_dict=True,
202
+ )
203
+ mm_output = mm_output.last_hidden_state[:,:query_tokens.shape[1]]
204
+ outputs.append(mm_output)
205
+ mm_output = torch.cat(outputs)
206
+ torch.cuda.empty_cache()
207
+ else:
208
+ mm_output = self.vlm_att_encoder.bert(
209
+ input_ids,
210
+ query_embeds=query_tokens,
211
+ attention_mask=query_atts,
212
+ encoder_hidden_states=mm_img_in,
213
+ encoder_attention_mask=img_att_prompt,
214
+ return_dict=True,
215
+ )
216
+ mm_output = mm_output.last_hidden_state[:,:query_tokens.shape[1]]
217
+
218
+ elif "raw" in self.bert_type:
219
+ if self.mm_vision_select_feature == 'patch' and bert_feat.shape[1]%2 == 1:
220
+ bert_feat = bert_feat[:, 1:]
221
+ img_att_prompt = img_att_prompt[:, 1:]
222
+
223
+ mm_output = self.vlm_att_encoder.bert(
224
+ input_ids,
225
+ attention_mask=attention_masks,
226
+ encoder_hidden_states=self.vlm_att_bert_proj(bert_feat),
227
+ encoder_attention_mask=img_att_prompt,
228
+ return_dict=True,
229
+ )
230
+ mm_output = mm_output.last_hidden_state
231
+ else:
232
+ raise ValueError(f'Unexpected bert type: {self.bert_type}')
233
+
234
+ text_q = self.vlm_att_projector(mm_output)
235
+ # shape: [prompt_num*frame_num, feat_dim]
236
+ # ctx_embed,vis_embed = self.token_generation(text_q, img_feat_prompt, long_video=long_video)
237
+ final_token = self.token_generation(text_q, img_feat_prompt, long_video=long_video)
238
+
239
+ if image_counts is not None:
240
+ # shape: [prompt_num, frame_num*image_shape, feat_dim]
241
+ final_token = final_token.reshape(len(prompts[_idx]), image_counts[_idx], *final_token.shape[-2:])
242
+ final_token = final_token.flatten(1,2)
243
+ img_feat_lst.append(final_token)
244
+
245
+ return img_feat_lst
246
+
247
+ def init_bert(self, vision_width, cross_attention_freq=2, truncation_side="right"):
248
+ # initialize BERT tokenizer
249
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side=truncation_side)
250
+ tokenizer.add_special_tokens({"bos_token": "[DEC]"})
251
+ # initialize BERT
252
+ encoder_config = BertConfig.from_pretrained("bert-base-uncased")
253
+ encoder_config.encoder_width = vision_width
254
+ # insert cross-attention layer every other block
255
+ encoder_config.add_cross_attention = True
256
+ encoder_config.cross_attention_freq = cross_attention_freq
257
+ query_tokens = None
258
+
259
+ if "qformer" in self.bert_type:
260
+ mm_model = BertLMHeadModelQF.from_pretrained(
261
+ "bert-base-uncased", config=encoder_config
262
+ )
263
+ query_tokens = nn.Parameter(
264
+ torch.zeros(1, self.num_query, encoder_config.hidden_size)
265
+ )
266
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
267
+ elif "raw" in self.bert_type:
268
+ encoder_config.is_decoder = True
269
+ mm_model = BertLMHeadModelRaw.from_pretrained(
270
+ "bert-base-uncased", config=encoder_config
271
+ )
272
+ else:
273
+ raise NotImplementedError("BERT type not implemented...")
274
+
275
+ mm_model.resize_token_embeddings(len(tokenizer))
276
+ mm_model.cls = None
277
+
278
+ if "layer" in self.bert_type:
279
+ layer_num = int(self.bert_type.split(':')[-1])
280
+ mm_model.bert.encoder.layer = mm_model.bert.encoder.layer[:layer_num]
281
+ print(f"Only use {layer_num} layers in BERT...")
282
+
283
+ return tokenizer, mm_model, query_tokens
284
+
285
+
286
+ def token_generation(self, text_q, vis_embed, long_video=False):
287
+ ctx_embed = self.vlm_att_key_projector(vis_embed)
288
+ # Key part 1: calculate context-related embedding
289
+ ctx_embed = text_q @ ctx_embed.transpose(-1,-2)
290
+ ctx_embed = ctx_embed / (vis_embed.shape[-1] ** 0.5)
291
+ if not long_video:
292
+ ctx_embed = (ctx_embed.softmax(-1) @ vis_embed).mean(1)
293
+ else:
294
+ block_size = 64
295
+ outputs = []
296
+ ctx_score = ctx_embed.softmax(-1)
297
+ for L in range(0, len(ctx_score), block_size):
298
+ R = L + block_size
299
+ sub_embed = (ctx_score[L:R] @ vis_embed[L:R]).mean(1)
300
+ outputs.append(sub_embed)
301
+ ctx_embed = torch.cat(outputs)
302
+ torch.cuda.empty_cache()
303
+ ctx_embed = self.vlm_att_val_projector(ctx_embed[:,None])
304
+
305
+ # Key part 2: calculate visual embedding
306
+ if self.compress_type is not None:
307
+ if 'grid' in self.compress_type:
308
+ grid_size = int(self.compress_type.split('grid:')[-1])
309
+ cur_shape = int(vis_embed.shape[1]**0.5)
310
+ assert grid_size > 1, f'Grid size should be larger than 1, but got {grid_size}'
311
+ vis_embed = vis_embed.reshape(vis_embed.shape[0], cur_shape, cur_shape, -1)
312
+ grid_stride = cur_shape // grid_size
313
+ vis_embed = F.avg_pool2d(vis_embed.permute(0, 3, 1, 2),
314
+ padding=0,
315
+ kernel_size=grid_stride,
316
+ stride=grid_stride)
317
+
318
+ vis_embed = vis_embed.permute(0, 2, 3, 1).flatten(1,2)
319
+ elif 'mean' in self.compress_type:
320
+ # import pdb;pdb.set_trace()
321
+ vis_embed = vis_embed.mean(dim=1, keepdim=True)
322
+
323
+ # import pdb ; pdb.set_trace()
324
+ # concat token in shape (B, n+1, C)
325
+ vis_embed = self.mm_projector(vis_embed)
326
+ final_token = torch.cat([ctx_embed, vis_embed], dim=1)
327
+ return final_token
328
+
329
+ @property
330
+ def config(self):
331
+ return {
332
+ 'mm_resampler_type': 'vlm_attention',
333
+ 'mm_vlmattention_bert_type': self.bert_type,
334
+ 'mm_vlmattention_num_query': self.num_query,
335
+ 'mm_vlmattention_compress_type': self.compress_type,
336
+ }
337
+
oryx/model/oryx_arch.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .multimodal_encoder.builder import build_vision_tower
7
+ from .multimodal_resampler.builder import build_vision_resampler
8
+ from .multimodal_projector.builder import build_vision_projector
9
+
10
+ from oryx.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
11
+
12
+ import ast
13
+ import torch.distributed as dist
14
+
15
+ class OryxMetaModel:
16
+
17
+ def __init__(self, config):
18
+ super(OryxMetaModel, self).__init__(config)
19
+
20
+ if hasattr(config, "mm_vision_tower"):
21
+ self.vision_tower = build_vision_tower(config, delay_load=True)
22
+ self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower)
23
+ self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)
24
+ def get_vision_tower(self):
25
+ vision_tower = getattr(self, 'vision_tower', None)
26
+ if type(vision_tower) is list:
27
+ vision_tower = vision_tower[0]
28
+ return vision_tower
29
+
30
+ def initialize_vision_modules(self, model_args, fsdp=None):
31
+ vision_tower = model_args.vision_tower
32
+ mm_vision_select_layer = model_args.mm_vision_select_layer
33
+ mm_vision_select_feature = model_args.mm_vision_select_feature
34
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
35
+
36
+ self.config.mm_vision_tower = vision_tower
37
+
38
+ if self.get_vision_tower() is None:
39
+ vision_tower = build_vision_tower(model_args)
40
+ vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
41
+ ## Get the mm_spatial_pool_mode and mm_spatial_pool_stride
42
+ for k, v in vision_resampler.config.items():
43
+ setattr(self.config, k, v)
44
+
45
+ if fsdp is not None and len(fsdp) > 0:
46
+ self.vision_tower = [vision_tower]
47
+ self.vision_resampler = [vision_resampler]
48
+ else:
49
+ self.vision_tower = vision_tower
50
+ self.vision_resampler = vision_resampler
51
+ else:
52
+ if fsdp is not None and len(fsdp) > 0:
53
+ vision_resampler = self.vision_resampler[0]
54
+ vision_tower = self.vision_tower[0]
55
+ else:
56
+ vision_resampler = self.vision_resampler
57
+ vision_tower = self.vision_tower
58
+ vision_tower.load_model()
59
+
60
+ # In case it is frozen by LoRA
61
+ for p in self.vision_resampler.parameters():
62
+ p.requires_grad = True
63
+
64
+ self.config.use_mm_proj = True
65
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
66
+ self.config.mm_hidden_size = getattr(vision_resampler, 'hidden_size', vision_tower.hidden_size)
67
+
68
+ self.config.mm_vision_select_layer = mm_vision_select_layer
69
+ self.config.mm_vision_select_feature = mm_vision_select_feature
70
+
71
+ if getattr(self, 'mm_projector', None) is None:
72
+ self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)
73
+ else:
74
+ for p in self.mm_projector.parameters():
75
+ p.requires_grad = True
76
+
77
+ if pretrain_mm_mlp_adapter is not None:
78
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
79
+ def get_w(weights, keyword):
80
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
81
+
82
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
83
+ incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, 'vision_resampler'), strict=False)
84
+ print(incompatible_keys)
85
+
86
+
87
+ class OryxMetaForCausalLM(ABC):
88
+
89
+ @abstractmethod
90
+ def get_model(self):
91
+ pass
92
+
93
+ def get_vision_tower(self):
94
+ return self.get_model().get_vision_tower()
95
+
96
+ def prepare_inputs_labels_for_multimodal(
97
+ self, input_ids, position_ids, attention_mask, past_key_values, labels,
98
+ images, modalities, image_sizes=None, images_highres=None):
99
+ # print(modalities, len(images), len(images_highres), len(input_ids))
100
+ vision_tower = self.get_vision_tower()
101
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
102
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
103
+
104
+ if isinstance(modalities, str):
105
+ modalities = [modalities]
106
+
107
+ video_idx_in_batch = []
108
+ for modal in range(len(modalities)):
109
+ if 'video' in modalities[modal]:
110
+ video_idx_in_batch.append(modal)
111
+
112
+ # Fix training with deepspeed zero3
113
+ num_modality = len(modalities)
114
+ # try:
115
+ # world_size = dist.get_world_size()
116
+ # tensor_in = torch.zeros(1, dtype=torch.int64, device=images[0].device).fill_(num_modality)
117
+ # tensor_out = torch.zeros(world_size, dtype=torch.int64, device=images[0].device)
118
+ # dist.all_gather_into_tensor(tensor_out, tensor_in)
119
+ # max_num_modality = tensor_out.max().item()
120
+ # except:
121
+ max_num_modality = num_modality
122
+
123
+ aimg = images[-1]
124
+ lowres_img = []
125
+ for idx, img_feat in enumerate(images):
126
+ if idx in video_idx_in_batch:
127
+ img_feat = aimg.new(1, 3, 128, 128).fill_(0)
128
+ lowres_img.append(img_feat)
129
+
130
+ # Fix training with deepspeed zero3
131
+ if max_num_modality > num_modality:
132
+ for _ in range(max_num_modality - num_modality):
133
+ lowres_img.append(aimg.new(1, 3, 64, 64).fill_(0))
134
+ images_highres.append(aimg.new(1, 3, 64, 64).fill_(0))
135
+ modalities.append('image')
136
+
137
+ lowres_img_features, lowres_img_sizes = self.get_model().get_vision_tower()(lowres_img)
138
+ highres_img_features = []
139
+ highres_img_sizes = []
140
+ for idx, img_feat in enumerate(images_highres):
141
+ if img_feat.ndim == 5:
142
+ img_feat = img_feat.squeeze(1)
143
+ highres_img_feature, highres_img_size = self.get_model().get_vision_tower()(img_feat)
144
+ highres_img_features.append(highres_img_feature)
145
+ highres_img_sizes.append(highres_img_size)
146
+ image_features = []
147
+ for idx in range(len(modalities)):
148
+ img_feat_highres, img_size_highres = self.get_model().vision_resampler(highres_img_features[idx],
149
+ modalities[idx],
150
+ highres_img_sizes[idx])
151
+ img_feat_lowres, img_size_lowres = self.get_model().vision_resampler(lowres_img_features[idx],
152
+ modalities[idx],
153
+ lowres_img_sizes[idx])
154
+ img_feat = self.get_model().mm_projector(img_feat_lowres,
155
+ img_size_lowres,
156
+ img_feat_highres,
157
+ img_size_highres,
158
+ modalities[idx])
159
+ image_features.append(img_feat.flatten(0, 1))
160
+
161
+ if max_num_modality > num_modality:
162
+ image_features = image_features[:num_modality]
163
+ modalities = modalities[:num_modality]
164
+
165
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
166
+ raise NotImplementedError
167
+
168
+ # Let's just add dummy tensors if they do not exist,
169
+ # it is a headache to deal with None all the time.
170
+ # But it is not ideal, and if you have a better idea,
171
+ # please open an issue / submit a PR, thanks.
172
+ _labels = labels
173
+ _position_ids = position_ids
174
+ _attention_mask = attention_mask
175
+ if attention_mask is None:
176
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
177
+ else:
178
+ attention_mask = attention_mask.bool()
179
+ if position_ids is None:
180
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
181
+ if labels is None:
182
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
183
+
184
+ # remove the padding using attention_mask -- FIXME
185
+ _input_ids = input_ids
186
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
187
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
188
+
189
+ new_input_embeds = []
190
+ new_labels = []
191
+ cur_image_idx = 0
192
+ for batch_idx, cur_input_ids in enumerate(input_ids):
193
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
194
+ if num_images == 0:
195
+ cur_image_features = image_features[cur_image_idx]
196
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
197
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
198
+ new_input_embeds.append(cur_input_embeds)
199
+ new_labels.append(labels[batch_idx])
200
+ cur_image_idx += 1
201
+ continue
202
+
203
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
204
+ cur_input_ids_noim = []
205
+ cur_labels = labels[batch_idx]
206
+ cur_labels_noim = []
207
+ for i in range(len(image_token_indices) - 1):
208
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
209
+ cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
210
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
211
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
212
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
213
+ cur_new_input_embeds = []
214
+ cur_new_labels = []
215
+
216
+ for i in range(num_images + 1):
217
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
218
+ cur_new_labels.append(cur_labels_noim[i])
219
+ if i < num_images:
220
+ cur_image_features = image_features[cur_image_idx]
221
+ cur_image_idx += 1
222
+ cur_new_input_embeds.append(cur_image_features)
223
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
224
+
225
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
226
+
227
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
228
+ cur_new_labels = torch.cat(cur_new_labels)
229
+
230
+ new_input_embeds.append(cur_new_input_embeds)
231
+ new_labels.append(cur_new_labels)
232
+
233
+ # Truncate sequences to max length as image embeddings can make the sequence longer
234
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
235
+ modality_max_length = getattr(self.config, 'modality_max_length', None)
236
+
237
+ if modality_max_length is None or modality_max_length == "None":
238
+ if tokenizer_model_max_length is not None:
239
+ # if new_input_embeds[0] > tokenizer_model_max_length:
240
+ # print(f"Embeds length ({new_input_embeds.shape[0]}) larger than max length")
241
+ new_input_embeds =[x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
242
+ new_labels = [x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]
243
+ else:
244
+ modality_max_length = ast.literal_eval(modality_max_length)
245
+ modality_max_length_dict = {"image": modality_max_length[0], "text": modality_max_length[1], "video": modality_max_length[2]}
246
+ new_input_embeds =[x[: modality_max_length_dict[modality]] for x, modality in zip(new_input_embeds, modalities)]
247
+ new_labels = [x[: modality_max_length_dict[modality]] for x, modality in zip(new_labels, modalities)]
248
+
249
+ # Combine them
250
+ max_len = max(x.shape[0] for x in new_input_embeds)
251
+ batch_size = len(new_input_embeds)
252
+
253
+ new_input_embeds_padded = []
254
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
255
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
256
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
257
+
258
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
259
+ cur_len = cur_new_embed.shape[0]
260
+ if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
261
+ new_input_embeds_padded.append(torch.cat((
262
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
263
+ cur_new_embed
264
+ ), dim=0))
265
+ if cur_len > 0:
266
+ new_labels_padded[i, -cur_len:] = cur_new_labels
267
+ attention_mask[i, -cur_len:] = True
268
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
269
+ else:
270
+ new_input_embeds_padded.append(torch.cat((
271
+ cur_new_embed,
272
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
273
+ ), dim=0))
274
+ if cur_len > 0:
275
+ new_labels_padded[i, :cur_len] = cur_new_labels
276
+ attention_mask[i, :cur_len] = True
277
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
278
+
279
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
280
+
281
+ if _labels is None:
282
+ new_labels = None
283
+ else:
284
+ new_labels = new_labels_padded
285
+
286
+ if _attention_mask is None:
287
+ attention_mask = None
288
+ else:
289
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
290
+
291
+ if _position_ids is None:
292
+ position_ids = None
293
+
294
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
295
+
296
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
297
+ if model_args.mm_use_im_patch_token:
298
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
299
+ self.resize_token_embeddings(len(tokenizer))
300
+
301
+ if model_args.mm_use_im_start_end:
302
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
303
+ self.resize_token_embeddings(len(tokenizer))
304
+
305
+ if num_new_tokens > 0:
306
+ input_embeddings = self.get_input_embeddings().weight.data
307
+ output_embeddings = self.get_output_embeddings().weight.data
308
+
309
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
310
+ dim=0, keepdim=True)
311
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
312
+ dim=0, keepdim=True)
313
+
314
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
315
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
316
+
317
+ if model_args.tune_mm_mlp_adapter:
318
+ for p in self.get_input_embeddings().parameters():
319
+ p.requires_grad = True
320
+ for p in self.get_output_embeddings().parameters():
321
+ p.requires_grad = False
322
+
323
+ if model_args.pretrain_mm_mlp_adapter:
324
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
325
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
326
+ assert num_new_tokens == 2
327
+ if input_embeddings.shape == embed_tokens_weight.shape:
328
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
329
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
330
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
331
+ else:
332
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
333
+ elif model_args.mm_use_im_patch_token:
334
+ if model_args.tune_mm_mlp_adapter:
335
+ for p in self.get_input_embeddings().parameters():
336
+ p.requires_grad = False
337
+ for p in self.get_output_embeddings().parameters():
338
+ p.requires_grad = False
oryx/train/llama_flash_attn_monkey_patch.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+ import warnings
3
+
4
+ import torch
5
+
6
+ import transformers
7
+ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
8
+
9
+ try:
10
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
11
+ except ImportError:
12
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
13
+ from flash_attn.bert_padding import unpad_input, pad_input
14
+
15
+
16
+ def forward(
17
+ self,
18
+ hidden_states: torch.Tensor,
19
+ attention_mask: Optional[torch.Tensor] = None,
20
+ position_ids: Optional[torch.Tensor] = None,
21
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
22
+ output_attentions: bool = False,
23
+ use_cache: bool = False,
24
+ padding_mask: Optional[torch.Tensor] = None,
25
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
26
+ if output_attentions:
27
+ warnings.warn(
28
+ "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
29
+ )
30
+
31
+ bsz, q_len, _ = hidden_states.size()
32
+
33
+ query_states = (
34
+ self.q_proj(hidden_states)
35
+ .view(bsz, q_len, self.num_heads, self.head_dim)
36
+ .transpose(1, 2)
37
+ )
38
+ key_states = (
39
+ self.k_proj(hidden_states)
40
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
41
+ .transpose(1, 2)
42
+ )
43
+ value_states = (
44
+ self.v_proj(hidden_states)
45
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
46
+ .transpose(1, 2)
47
+ ) # shape: (b, num_heads, s, head_dim)
48
+
49
+ kv_seq_len = key_states.shape[-2]
50
+ if past_key_value is not None:
51
+ kv_seq_len += past_key_value[0].shape[-2]
52
+
53
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
54
+ query_states, key_states = apply_rotary_pos_emb(
55
+ query_states, key_states, cos, sin, position_ids
56
+ )
57
+
58
+ if past_key_value is not None:
59
+ # reuse k, v
60
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
61
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
62
+
63
+ past_key_value = (key_states, value_states) if use_cache else None
64
+
65
+ # repeat k/v heads if n_kv_heads < n_heads
66
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
67
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
68
+
69
+ # Transform the data into the format required by flash attention
70
+ qkv = torch.stack([query_states, key_states, value_states], dim=2)
71
+ qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
72
+ key_padding_mask = attention_mask
73
+
74
+ if key_padding_mask is None:
75
+ qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
76
+ cu_q_lens = torch.arange(
77
+ 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
78
+ )
79
+ max_s = q_len
80
+ output = flash_attn_unpadded_qkvpacked_func(
81
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
82
+ )
83
+ output = output.view(bsz, q_len, -1)
84
+ else:
85
+ qkv = qkv.reshape(bsz, q_len, -1)
86
+ qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
87
+ qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
88
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
89
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
90
+ )
91
+ output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
92
+ output = pad_input(output_unpad, indices, bsz, q_len)
93
+
94
+ return self.o_proj(output), None, past_key_value
95
+
96
+
97
+ # Disable the transformation of the attention mask in LlamaModel as the flash attention
98
+ # requires the attention mask to be the same as the key_padding_mask
99
+ def _prepare_decoder_attention_mask(
100
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
101
+ ):
102
+ # [bsz, seq_len]
103
+ return attention_mask
104
+
105
+
106
+ def replace_llama_attn_with_flash_attn():
107
+ cuda_major, cuda_minor = torch.cuda.get_device_capability()
108
+ if cuda_major < 8:
109
+ warnings.warn(
110
+ "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
111
+ "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
112
+ )
113
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
114
+ _prepare_decoder_attention_mask
115
+ )
116
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
oryx/train/oryx_trainer.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ import importlib.metadata
6
+
7
+ from torch.utils.data import Sampler
8
+
9
+
10
+ from transformers import Trainer
11
+ from transformers.trainer import (
12
+ is_sagemaker_mp_enabled,
13
+ get_parameter_names,
14
+ has_length,
15
+ ALL_LAYERNORM_LAYERS,
16
+ logger,
17
+ )
18
+ from transformers.trainer_pt_utils import get_length_grouped_indices as get_length_grouped_indices_hf
19
+ from typing import List, Optional
20
+
21
+ from transformers.trainer_pt_utils import (
22
+ get_dataloader_sampler,
23
+ get_model_param_count,
24
+ get_parameter_names,
25
+ )
26
+
27
+ from transformers.training_args import ParallelMode
28
+ from transformers.utils import (
29
+ is_peft_available,
30
+ is_accelerate_available,
31
+ is_sagemaker_mp_enabled,
32
+ is_torch_xla_available,
33
+ )
34
+
35
+ from transformers.trainer_utils import (
36
+ HPSearchBackend,
37
+ TrainOutput,
38
+ has_length,
39
+ speed_metrics,
40
+ )
41
+
42
+ from packaging import version
43
+
44
+ from peft import PeftModel
45
+
46
+ TIME_STAMP = os.environ.get('TIME_STAMP', 'default_value')
47
+ BYTENAS = os.environ.get('BYTENAS', 'vl-research')
48
+
49
+ def maybe_zero_3(param, ignore_status=False, name=None):
50
+ from deepspeed import zero
51
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
52
+ if hasattr(param, "ds_id"):
53
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
54
+ if not ignore_status:
55
+ print(name, 'no ignore status')
56
+ with zero.GatheredParameters([param]):
57
+ param = param.data.detach().cpu().clone()
58
+ else:
59
+ param = param.detach().cpu().clone()
60
+ return param
61
+
62
+
63
+ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
64
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
65
+ to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
66
+ return to_return
67
+
68
+
69
+ def split_to_even_chunks(indices, lengths, num_chunks):
70
+ """
71
+ Split a list of indices into `chunks` chunks of roughly equal lengths.
72
+ """
73
+
74
+ if len(indices) % num_chunks != 0:
75
+ return [indices[i::num_chunks] for i in range(num_chunks)]
76
+
77
+ num_indices_per_chunk = len(indices) // num_chunks
78
+
79
+ chunks = [[] for _ in range(num_chunks)]
80
+ chunks_lengths = [0 for _ in range(num_chunks)]
81
+ for index in indices:
82
+ shortest_chunk = chunks_lengths.index(min(chunks_lengths))
83
+ chunks[shortest_chunk].append(index)
84
+ chunks_lengths[shortest_chunk] += lengths[index]
85
+ if len(chunks[shortest_chunk]) == num_indices_per_chunk:
86
+ chunks_lengths[shortest_chunk] = float("inf")
87
+
88
+ return chunks
89
+
90
+
91
+ def get_variable_length_grouped_indices(lengths, batch_size, world_size, megabatch_mult = 8, generator=None):
92
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
93
+ indices = torch.randperm(len(lengths), generator=generator)
94
+ sorted_indices = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
95
+ megabatch_size = world_size * batch_size * megabatch_mult
96
+ megabatches = [sorted_indices[i : i + megabatch_size] for i in range(0, len(lengths), megabatch_size)]
97
+ megabatches = [sorted(megabatch, key=lambda i: indices[i], reverse=True) for megabatch in megabatches]
98
+ shuffled_indices = [i for megabatch in megabatches for i in megabatch]
99
+ world_batch_size = world_size * batch_size
100
+ batches = [shuffled_indices[i : i + world_batch_size] for i in range(0, len(lengths), world_batch_size)]
101
+ batch_indices = torch.randperm(len(batches), generator=generator)
102
+ batches = [batches[i] for i in batch_indices]
103
+
104
+ return [i for batch in batches for i in batch]
105
+
106
+
107
+ def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
108
+ """
109
+ Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar
110
+ lengths. To do this, the indices are:
111
+
112
+ - randomly permuted
113
+ - grouped in mega-batches of size `mega_batch_mult * batch_size`
114
+ - reorder by length in each mega-batch
115
+
116
+ The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
117
+ maximum length placed first, so that an OOM happens sooner rather than later.
118
+ """
119
+
120
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
121
+ assert all(l != 0 for l in lengths), "Should not have zero length."
122
+ if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
123
+ # all samples are in the same modality
124
+ return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
125
+ mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
126
+ lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
127
+
128
+ mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
129
+ lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
130
+ megabatch_size = world_size * batch_size
131
+ mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
132
+ lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
133
+
134
+ last_mm = mm_megabatches[-1]
135
+ last_lang = lang_megabatches[-1]
136
+ additional_batch = last_mm + last_lang
137
+ megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
138
+ megabatch_indices = torch.randperm(len(megabatches), generator=generator)
139
+ megabatches = [megabatches[i] for i in megabatch_indices]
140
+
141
+ if len(additional_batch) > 0:
142
+ megabatches.append(sorted(additional_batch))
143
+
144
+ return [i for megabatch in megabatches for i in megabatch]
145
+
146
+
147
+ def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
148
+ """
149
+ Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar
150
+ lengths. To do this, the indices are:
151
+
152
+ - randomly permuted
153
+ - grouped in mega-batches of size `mega_batch_mult * batch_size`
154
+ - reorder by length in each mega-batch
155
+
156
+ The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
157
+ maximum length placed first, so that an OOM happens sooner rather than later.
158
+ """
159
+
160
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
161
+ indices = torch.randperm(len(lengths), generator=generator)
162
+ megabatch_size = world_size * batch_size
163
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
164
+ megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
165
+ megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
166
+
167
+ return [i for megabatch in megabatches for batch in megabatch for i in batch]
168
+
169
+
170
+ def get_length_grouped_indices_auto_single(lengths, batch_size, world_size, generator=None):
171
+ indices = get_length_grouped_indices_hf(lengths, batch_size * world_size, generator=generator)
172
+
173
+ megabatch_size = world_size * batch_size
174
+ megabatches = [indices[i : i + megabatch_size] for i in range(0, len(lengths), megabatch_size)]
175
+ megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
176
+ megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
177
+
178
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
179
+ batch_indices = torch.randperm(len(megabatches), generator=generator)
180
+ megabatches = [megabatches[i] for i in batch_indices]
181
+
182
+ return [i for megabatch in megabatches for batch in megabatch for i in batch]
183
+
184
+
185
+ def get_modality_length_grouped_indices_auto(lengths, batch_size, world_size, generator=None):
186
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
187
+ assert all(l != 0 for l in lengths), "Should not have zero length."
188
+ if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
189
+ # all samples are in the same modality
190
+ return get_length_grouped_indices_auto_single(lengths, batch_size, world_size, generator=generator)
191
+ mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
192
+ lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
193
+
194
+ mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices_auto_single(mm_lengths, batch_size, world_size, generator=None)]
195
+ lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices_auto_single(lang_lengths, batch_size, world_size, generator=None)]
196
+ megabatch_size = world_size * batch_size
197
+ mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
198
+ lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
199
+
200
+ last_mm = mm_megabatches[-1]
201
+ last_lang = lang_megabatches[-1]
202
+ additional_batch = last_mm + last_lang
203
+ megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
204
+ megabatch_indices = torch.randperm(len(megabatches), generator=generator)
205
+ megabatches = [megabatches[i] for i in megabatch_indices]
206
+
207
+ if len(additional_batch) > 0:
208
+ megabatches.append(sorted(additional_batch))
209
+
210
+ return [i for megabatch in megabatches for i in megabatch]
211
+
212
+
213
+ class LengthGroupedSampler(Sampler):
214
+ r"""
215
+ Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
216
+ keeping a bit of randomness.
217
+ """
218
+
219
+ def __init__(
220
+ self,
221
+ batch_size: int,
222
+ world_size: int,
223
+ lengths: Optional[List[int]] = None,
224
+ generator=None,
225
+ variable_length: bool = False,
226
+ group_by_modality: bool = False,
227
+ group_by_modality_auto: bool = False,
228
+ ):
229
+ if lengths is None:
230
+ raise ValueError("Lengths must be provided.")
231
+
232
+ self.batch_size = batch_size
233
+ self.world_size = world_size
234
+ self.lengths = lengths
235
+ self.generator = generator
236
+ self.variable_length = variable_length
237
+ self.group_by_modality = group_by_modality
238
+ self.group_by_modality_auto = group_by_modality_auto
239
+
240
+ def __len__(self):
241
+ return len(self.lengths)
242
+
243
+ def __iter__(self):
244
+ if self.variable_length:
245
+ assert not self.group_by_modality, "Variable length grouping is not supported with modality grouping."
246
+ indices = get_variable_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
247
+ else:
248
+ if self.group_by_modality:
249
+ indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
250
+ elif self.group_by_modality_auto:
251
+ indices = get_modality_length_grouped_indices_auto(self.lengths, self.batch_size, self.world_size, generator=self.generator)
252
+ else:
253
+ indices = get_length_grouped_indices_auto_single(self.lengths, self.batch_size, self.world_size, generator=self.generator)
254
+ return iter(indices)
255
+
256
+
257
+
258
+ def _is_peft_model(model):
259
+ if is_peft_available():
260
+ classes_to_check = (PeftModel,) if is_peft_available() else ()
261
+ # Here we also check if the model is an instance of `PeftMixedModel` introduced in peft>=0.7.0: https://github.com/huggingface/transformers/pull/28321
262
+ if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"):
263
+ from peft import PeftMixedModel
264
+
265
+ classes_to_check = (*classes_to_check, PeftMixedModel)
266
+ return isinstance(model, classes_to_check)
267
+ return False
268
+
269
+
270
+ TRAINER_STATE_NAME = "trainer_state.json"
271
+
272
+ class OryxTrainer(Trainer):
273
+
274
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
275
+ if self.train_dataset is None or not has_length(self.train_dataset):
276
+ return None
277
+
278
+ if self.args.group_by_length:
279
+ lengths = self.train_dataset.lengths
280
+ return LengthGroupedSampler(
281
+ # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
282
+ self.args.train_batch_size,
283
+ # world_size=self.args.world_size,
284
+ world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work?
285
+ lengths=lengths,
286
+ )
287
+ elif self.args.group_by_modality_length:
288
+ lengths = self.train_dataset.modality_lengths
289
+ return LengthGroupedSampler(
290
+ # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
291
+ self.args.train_batch_size,
292
+ # world_size=self.args.world_size,
293
+ world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work?
294
+ lengths=lengths,
295
+ group_by_modality=True,
296
+ )
297
+ elif self.args.group_by_modality_length_auto:
298
+ lengths = self.train_dataset.modality_lengths
299
+ return LengthGroupedSampler(
300
+ # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
301
+ self.args.train_batch_size,
302
+ # world_size=self.args.world_size,
303
+ world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work?
304
+ lengths=lengths,
305
+ group_by_modality_auto=True,
306
+ )
307
+ elif self.args.group_by_varlen:
308
+ lengths = self.train_dataset.lengths
309
+ return LengthGroupedSampler(
310
+ self.args.train_batch_size * self.args.gradient_accumulation_steps,
311
+ # self.args.train_batch_size, # TODO: seems that we should have gradient_accumulation_steps
312
+ # world_size=self.args.world_size,
313
+ world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work?
314
+ lengths=lengths,
315
+ variable_length=True,
316
+ )
317
+ else:
318
+ return super()._get_train_sampler()
319
+
320
+ def create_optimizer(self):
321
+ """
322
+ Setup the optimizer.
323
+
324
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
325
+ Trainer's init through `optimizers`, or subclass and override this method in a subclass.
326
+ """
327
+ if is_sagemaker_mp_enabled():
328
+ return super().create_optimizer()
329
+
330
+ opt_model = self.model
331
+
332
+ if self.optimizer is None:
333
+ decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
334
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
335
+ lr_mapper = {}
336
+ if self.args.mm_projector_lr is not None:
337
+ lr_mapper['mm_projector'] = self.args.mm_projector_lr
338
+ if self.args.mm_vision_tower_lr is not None:
339
+ lr_mapper['vision_tower'] = self.args.mm_vision_tower_lr
340
+ if len(lr_mapper) > 0:
341
+ special_lr_parameters = [name for name, _ in opt_model.named_parameters() if any(module_keyword in name for module_keyword in lr_mapper)]
342
+ optimizer_grouped_parameters = [
343
+ {
344
+ "params": [
345
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in special_lr_parameters and p.requires_grad)
346
+ ],
347
+ "weight_decay": self.args.weight_decay,
348
+ },
349
+ {
350
+ "params": [
351
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in special_lr_parameters and p.requires_grad)
352
+ ],
353
+ "weight_decay": 0.0,
354
+ },
355
+ ]
356
+ for module_keyword, lr in lr_mapper.items():
357
+ module_parameters = [name for name, _ in opt_model.named_parameters() if module_keyword in name]
358
+ optimizer_grouped_parameters.extend([
359
+ {
360
+ "params": [
361
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in module_parameters and p.requires_grad)
362
+ ],
363
+ "weight_decay": self.args.weight_decay,
364
+ "lr": lr,
365
+ },
366
+ {
367
+ "params": [
368
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in module_parameters and p.requires_grad)
369
+ ],
370
+ "weight_decay": 0.0,
371
+ "lr": lr,
372
+ },
373
+ ])
374
+ else:
375
+ optimizer_grouped_parameters = [
376
+ {
377
+ "params": [
378
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
379
+ ],
380
+ "weight_decay": self.args.weight_decay,
381
+ },
382
+ {
383
+ "params": [
384
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
385
+ ],
386
+ "weight_decay": 0.0,
387
+ },
388
+ ]
389
+
390
+ optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
391
+
392
+ self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
393
+ if optimizer_cls.__name__ == "Adam8bit":
394
+ import bitsandbytes
395
+
396
+ manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
397
+
398
+ skipped = 0
399
+ for module in opt_model.modules():
400
+ if isinstance(module, nn.Embedding):
401
+ skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
402
+ logger.info(f"skipped {module}: {skipped/2**20}M params")
403
+ manager.register_module_override(module, "weight", {"optim_bits": 32})
404
+ logger.debug(f"bitsandbytes: will optimize {module} in fp32")
405
+ logger.info(f"skipped: {skipped/2**20}M params")
406
+
407
+ return self.optimizer
408
+
409
+ def _save_checkpoint(self, model, trial, metrics=None):
410
+ if getattr(self.args, 'tune_mm_mlp_adapter', False):
411
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
412
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
413
+
414
+ run_dir = self._get_output_dir(trial=trial)
415
+ output_dir = os.path.join(run_dir, checkpoint_folder)
416
+
417
+ # Only save Adapter
418
+ keys_to_match = ['mm_projector', 'vision_resampler']
419
+ if getattr(self.args, "use_im_start_end", False):
420
+ keys_to_match.extend(['embed_tokens', 'embed_in'])
421
+
422
+ weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
423
+
424
+ if self.args.local_rank == 0 or self.args.local_rank == -1:
425
+ self.model.config.save_pretrained(output_dir)
426
+ torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
427
+ else:
428
+ print("self.is_local_process_zero()",self.is_local_process_zero())
429
+ super(OryxTrainer, self)._save_checkpoint(model, trial, metrics)
430
+
431
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
432
+ if getattr(self.args, 'tune_mm_mlp_adapter', False):
433
+ pass
434
+ else:
435
+ super(OryxTrainer, self)._save(output_dir, state_dict)
oryx/train/train.py ADDED
@@ -0,0 +1,1686 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import copy
4
+ from dataclasses import dataclass, field
5
+ import json
6
+ import logging
7
+ import pathlib
8
+ from typing import Dict, Optional, Sequence, List
9
+ import ast
10
+
11
+ import torch
12
+ import time
13
+ import random
14
+ import cv2
15
+
16
+ import transformers
17
+ import tokenizers
18
+
19
+ from oryx.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_INDEX
20
+ from torch.utils.data import Dataset
21
+ from oryx.train.oryx_trainer import OryxTrainer
22
+
23
+ from oryx import conversation as conversation_lib
24
+ from oryx.model import *
25
+ from oryx.mm_utils import tokenizer_image_token, process_anyres_highres_image_genli, process_anyres_video_genli, process_anyres_video_genli_long
26
+
27
+ from PIL import Image
28
+ import io
29
+ import base64
30
+
31
+ from packaging import version
32
+
33
+ import numpy as np
34
+
35
+ from transformers import AutoConfig
36
+
37
+ import math
38
+ import copy
39
+
40
+
41
+ local_rank = None
42
+
43
+
44
+ def rank0_print(*args):
45
+ if local_rank == 0:
46
+ print(*args)
47
+
48
+ IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')
49
+
50
+ @dataclass
51
+ class ModelArguments:
52
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
53
+ version: Optional[str] = field(default="v0")
54
+ freeze_backbone: bool = field(default=False)
55
+ tune_mm_mlp_adapter: bool = field(default=False)
56
+ tune_mm_vision_resampler: bool = field(default=False)
57
+ vision_tower: Optional[str] = field(default=None)
58
+ image_processor: Optional[str] = field(default=None)
59
+ unfreeze_mm_vision_tower: bool = field(default=False)
60
+ mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
61
+ pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
62
+ mm_projector_type: Optional[str] = field(default='linear')
63
+ mm_use_im_start_end: bool = field(default=False)
64
+ mm_use_im_patch_token: bool = field(default=True)
65
+ mm_vision_select_feature: Optional[str] = field(default="patch")
66
+ mm_resampler_type: Optional[str] = field(default=None)
67
+ mm_mask_drop_mode: str = field(default="fixed")
68
+ mm_mask_drop_skip_percentage: float = field(default=0.)
69
+ mm_mask_drop_ratio: float = field(default=0.25)
70
+ mm_mask_drop_ratio_upper: Optional[float] = field(default=None)
71
+ mm_mask_drop_ratio_lower: Optional[float] = field(default=None)
72
+
73
+ @dataclass
74
+ class DataArguments:
75
+ data_path: str = field(default=None,
76
+ metadata={"help": "Path to the training data."})
77
+ lazy_preprocess: bool = False
78
+ is_multimodal: bool = False
79
+ video_fps: Optional[int] = field(default=1)
80
+ frames_upbound: Optional[int] = field(default=0)
81
+
82
+ @dataclass
83
+ class TrainingArguments(transformers.TrainingArguments):
84
+ cache_dir: Optional[str] = field(default=None)
85
+ optim: str = field(default="adamw_torch")
86
+ remove_unused_columns: bool = field(default=False)
87
+ freeze_mm_mlp_adapter: bool = field(default=False)
88
+ freeze_mm_vision_resampler: bool = field(default=False)
89
+ mpt_attn_impl: Optional[str] = field(default="triton")
90
+ model_max_length: int = field(
91
+ default=512,
92
+ metadata={
93
+ "help":
94
+ "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
95
+ },
96
+ )
97
+ double_quant: bool = field(
98
+ default=True,
99
+ metadata={"help": "Compress the quantization statistics through double quantization."}
100
+ )
101
+ quant_type: str = field(
102
+ default="nf4",
103
+ metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
104
+ )
105
+ bits: int = field(
106
+ default=16,
107
+ metadata={"help": "How many bits to use."}
108
+ )
109
+ lora_enable: bool = False
110
+ lora_r: int = 64
111
+ lora_alpha: int = 16
112
+ lora_dropout: float = 0.05
113
+ lora_weight_path: str = ""
114
+ lora_bias: str = "none"
115
+ mm_projector_lr: Optional[float] = None
116
+ mm_vision_tower_lr: Optional[float] = None
117
+ group_by_varlen: bool = field(default=False)
118
+ group_by_modality_length: bool = field(default=False)
119
+ group_by_modality_length_auto: bool = field(default=False)
120
+ do_resize: bool = field(default=False)
121
+ do_center_crop: bool = field(default=False)
122
+
123
+
124
+ def maybe_zero_3(param, ignore_status=False, name=None):
125
+ from deepspeed import zero
126
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
127
+ if hasattr(param, "ds_id"):
128
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
129
+ if not ignore_status:
130
+ logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
131
+ with zero.GatheredParameters([param]):
132
+ param = param.data.detach().cpu().clone()
133
+ else:
134
+ param = param.detach().cpu().clone()
135
+ return param
136
+
137
+
138
+ # Borrowed from peft.utils.get_peft_model_state_dict
139
+ def get_peft_state_maybe_zero_3(named_params, bias):
140
+ if bias == "none":
141
+ to_return = {k: t for k, t in named_params if "lora_" in k}
142
+ elif bias == "all":
143
+ to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
144
+ elif bias == "lora_only":
145
+ to_return = {}
146
+ maybe_lora_bias = {}
147
+ lora_bias_names = set()
148
+ for k, t in named_params:
149
+ if "lora_" in k:
150
+ to_return[k] = t
151
+ bias_name = k.split("lora_")[0] + "bias"
152
+ lora_bias_names.add(bias_name)
153
+ elif "bias" in k:
154
+ maybe_lora_bias[k] = t
155
+ for k, t in maybe_lora_bias:
156
+ if bias_name in lora_bias_names:
157
+ to_return[bias_name] = t
158
+ else:
159
+ raise NotImplementedError
160
+ to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
161
+ return to_return
162
+
163
+
164
+ def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
165
+ to_return = {k: t for k, t in named_params if "lora_" not in k}
166
+ if require_grad_only:
167
+ to_return = {k: t for k, t in to_return.items() if t.requires_grad}
168
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
169
+ return to_return
170
+
171
+
172
+ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
173
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
174
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
175
+ return to_return
176
+
177
+
178
+ def find_all_linear_names(model):
179
+ cls = torch.nn.Linear
180
+ lora_module_names = set()
181
+ multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
182
+ for name, module in model.named_modules():
183
+ if any(mm_keyword in name for mm_keyword in multimodal_keywords):
184
+ continue
185
+ if isinstance(module, cls):
186
+ names = name.split('.')
187
+ lora_module_names.add(names[0] if len(names) == 1 else names[-1])
188
+
189
+
190
+ if 'lm_head' in lora_module_names: # needed for 16-bit
191
+ lora_module_names.remove('lm_head')
192
+ return list(lora_module_names)
193
+
194
+
195
+ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
196
+ output_dir: str):
197
+ """Collects the state dict and dump to disk."""
198
+
199
+ if getattr(trainer.args, "tune_mm_mlp_adapter", False):
200
+ # Only save Adapter
201
+ keys_to_match = ['mm_projector', 'vision_resampler']
202
+ if getattr(trainer.args, "use_im_start_end", False):
203
+ keys_to_match.extend(['embed_tokens', 'embed_in'])
204
+
205
+ weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
206
+ trainer.model.config.save_pretrained(output_dir)
207
+
208
+ current_folder = output_dir.split('/')[-1]
209
+ parent_folder = os.path.dirname(output_dir)
210
+ if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
211
+ if current_folder.startswith('checkpoint-'):
212
+ mm_projector_folder = os.path.join(parent_folder, "mm_projector")
213
+ os.makedirs(mm_projector_folder, exist_ok=True)
214
+ torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
215
+ else:
216
+ torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
217
+ return
218
+
219
+ if trainer.deepspeed:
220
+ torch.cuda.synchronize()
221
+ trainer.save_model(output_dir)
222
+ return
223
+
224
+ state_dict = trainer.model.state_dict()
225
+ if trainer.args.should_save:
226
+ cpu_state_dict = {
227
+ key: value.cpu()
228
+ for key, value in state_dict.items()
229
+ }
230
+ del state_dict
231
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
232
+
233
+
234
+ def smart_tokenizer_and_embedding_resize(
235
+ special_tokens_dict: Dict,
236
+ tokenizer: transformers.PreTrainedTokenizer,
237
+ model: transformers.PreTrainedModel,
238
+ ):
239
+ """Resize tokenizer and embedding.
240
+
241
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
242
+ """
243
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
244
+ model.resize_token_embeddings(len(tokenizer))
245
+
246
+ if num_new_tokens > 0:
247
+ input_embeddings = model.get_input_embeddings().weight.data
248
+ output_embeddings = model.get_output_embeddings().weight.data
249
+
250
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
251
+ dim=0, keepdim=True)
252
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
253
+ dim=0, keepdim=True)
254
+
255
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
256
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
257
+
258
+
259
+ def _tokenize_fn(strings: Sequence[str],
260
+ tokenizer: transformers.PreTrainedTokenizer) -> Dict:
261
+ """Tokenize a list of strings."""
262
+ tokenized_list = [
263
+ tokenizer(
264
+ text,
265
+ return_tensors="pt",
266
+ padding="longest",
267
+ max_length=tokenizer.model_max_length,
268
+ truncation=True,
269
+ ) for text in strings
270
+ ]
271
+ input_ids = labels = [
272
+ tokenized.input_ids[0] for tokenized in tokenized_list
273
+ ]
274
+ input_ids_lens = labels_lens = [
275
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
276
+ for tokenized in tokenized_list
277
+ ]
278
+ return dict(
279
+ input_ids=input_ids,
280
+ labels=labels,
281
+ input_ids_lens=input_ids_lens,
282
+ labels_lens=labels_lens,
283
+ )
284
+
285
+
286
+ def _mask_targets(target, tokenized_lens, speakers):
287
+ # cur_idx = 0
288
+ cur_idx = tokenized_lens[0]
289
+ tokenized_lens = tokenized_lens[1:]
290
+ target[:cur_idx] = IGNORE_INDEX
291
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
292
+ if speaker == "human":
293
+ target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
294
+ cur_idx += tokenized_len
295
+
296
+
297
+ def _add_speaker_and_signal(header, source, get_conversation=True):
298
+ """Add speaker and start/end signal on each round."""
299
+ BEGIN_SIGNAL = "### "
300
+ END_SIGNAL = "\n"
301
+ conversation = header
302
+ for sentence in source:
303
+ from_str = sentence["from"]
304
+ if from_str.lower() == "human":
305
+ from_str = conversation_lib.default_conversation.roles[0]
306
+ elif from_str.lower() == "gpt":
307
+ from_str = conversation_lib.default_conversation.roles[1]
308
+ else:
309
+ from_str = 'unknown'
310
+ sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
311
+ sentence["value"] + END_SIGNAL)
312
+ if get_conversation:
313
+ conversation += sentence["value"]
314
+ conversation += BEGIN_SIGNAL
315
+ return conversation
316
+
317
+
318
+ def preprocess_multimodal(
319
+ sources: Sequence[str],
320
+ data_args: DataArguments,
321
+ ) -> Dict:
322
+ is_multimodal = data_args.is_multimodal
323
+ if not is_multimodal:
324
+ return sources
325
+
326
+ for source in sources:
327
+ for sentence in source:
328
+ if DEFAULT_IMAGE_TOKEN in sentence['value'] and not sentence['value'].startswith(DEFAULT_IMAGE_TOKEN):
329
+ sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
330
+ sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
331
+ sentence['value'] = sentence['value'].strip()
332
+ if "mmtag" in conversation_lib.default_conversation.version:
333
+ sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')
334
+ replace_token = DEFAULT_IMAGE_TOKEN
335
+ if data_args.mm_use_im_start_end:
336
+ replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
337
+ sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
338
+
339
+ return sources
340
+
341
+ def preprocess_multimodal_movie(
342
+ sources: Sequence[str],
343
+ data_args: DataArguments,
344
+ video_inputs: str
345
+ ) -> Dict:
346
+ is_multimodal = data_args.is_multimodal
347
+ if not is_multimodal:
348
+ return sources
349
+
350
+ for source in sources:
351
+ for sentence in source:
352
+ if DEFAULT_IMAGE_TOKEN in sentence['value']:
353
+ prompt = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
354
+ replace_token = video_inputs
355
+ if data_args.mm_use_im_start_end:
356
+ replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
357
+ sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
358
+
359
+ return sources, prompt
360
+
361
+
362
+ def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict:
363
+ roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
364
+
365
+ # im_start, im_end = tokenizer.additional_special_tokens_ids
366
+
367
+ im_start = tokenizer("<|im_start|>").input_ids[0]
368
+ im_end = tokenizer("<|im_end|>").input_ids[0]
369
+ nl_tokens = tokenizer("\n").input_ids
370
+ _system = tokenizer("system").input_ids + nl_tokens
371
+
372
+ # Apply prompt templates
373
+ input_ids, targets = [], []
374
+ for i, source in enumerate(sources):
375
+ if roles[source[0]["from"]] != roles["human"]:
376
+ source = source[1:]
377
+
378
+ input_id, target = [], []
379
+ system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens
380
+ input_id += system
381
+ target += [im_start] + [IGNORE_INDEX] * (len(system) - 3) + [im_end] + nl_tokens
382
+ assert len(input_id) == len(target)
383
+ for j, sentence in enumerate(source):
384
+ role = roles[sentence["from"]]
385
+ if has_image and "<image>" in sentence["value"]:
386
+ # assert sentence["value"].startswith("<image>"), print(sentence["value"])
387
+ if sentence["value"].startswith("<image>"):
388
+ _input_id = tokenizer(role).input_ids + nl_tokens + [IMAGE_TOKEN_INDEX] + nl_tokens + tokenizer(sentence["value"][len("<image>") :]).input_ids + [im_end] + nl_tokens
389
+ else:
390
+ _input_id = []
391
+ split_value = sentence["value"].split('<image>\n')
392
+ _input_id += tokenizer(role).input_ids + nl_tokens
393
+ for idx, cur_value in enumerate(split_value):
394
+ if idx == len(split_value) - 1:
395
+ _input_id = _input_id + tokenizer(cur_value).input_ids + [im_end] + nl_tokens
396
+ else:
397
+ _input_id = _input_id + tokenizer(cur_value).input_ids + [IMAGE_TOKEN_INDEX] + nl_tokens
398
+ # # add end of text token
399
+ # if PACK_SEQ > 0:
400
+ # if j > 0:
401
+ # _input_id = _end_of_text + _input_id
402
+ else:
403
+ _input_id = tokenizer(role).input_ids + nl_tokens + tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens
404
+ # # add end of text token for pure text data
405
+ # if PACK_SEQ > 0:
406
+ # if sentence['from'] == 'human' and j > 0:
407
+ # _input_id = _end_of_text + _input_id
408
+ input_id += _input_id
409
+ if role == "<|im_start|>user":
410
+ _target = [im_start] + [IGNORE_INDEX] * (len(_input_id) - 3) + [im_end] + nl_tokens
411
+ elif role == "<|im_start|>assistant":
412
+ _target = [im_start] + [IGNORE_INDEX] * len(tokenizer(role).input_ids) + _input_id[len(tokenizer(role).input_ids) + 1 : -2] + [im_end] + nl_tokens
413
+ else:
414
+ raise NotImplementedError
415
+ target += _target
416
+ assert len(input_id) == len(target)
417
+ # input_id += [tokenizer.pad_token_id] * (max_len - len(input_id))
418
+ # target += [IGNORE_INDEX] * (max_len - len(target))
419
+ input_ids.append(input_id)
420
+ targets.append(target)
421
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
422
+ targets = torch.tensor(targets, dtype=torch.long)
423
+
424
+ return dict(
425
+ input_ids=input_ids, # tensor(bs x seq_len)
426
+ labels=targets, # tensor(bs x seq_len)
427
+ # attention_mask=input_ids.ne(tokenizer.pad_token_id), # tensor(bs x seq_len)
428
+ )
429
+
430
+ def preprocess_llama_2(
431
+ sources,
432
+ tokenizer: transformers.PreTrainedTokenizer,
433
+ has_image: bool = False
434
+ ) -> Dict:
435
+ conv = conversation_lib.default_conversation.copy()
436
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
437
+
438
+ # Apply prompt templates
439
+ conversations = []
440
+ for i, source in enumerate(sources):
441
+ if roles[source[0]["from"]] != conv.roles[0]:
442
+ # Skip the first one if it is not from human
443
+ source = source[1:]
444
+
445
+ conv.messages = []
446
+ for j, sentence in enumerate(source):
447
+ role = roles[sentence["from"]]
448
+ assert role == conv.roles[j % 2], f"{i}"
449
+ conv.append_message(role, sentence["value"])
450
+ conversations.append(conv.get_prompt())
451
+
452
+ # Tokenize conversations
453
+
454
+ if has_image:
455
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
456
+ else:
457
+ input_ids = tokenizer(
458
+ conversations,
459
+ return_tensors="pt",
460
+ padding="longest",
461
+ max_length=tokenizer.model_max_length,
462
+ truncation=True,
463
+ ).input_ids
464
+
465
+ targets = input_ids.clone()
466
+
467
+ assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
468
+
469
+ # Mask targets
470
+ sep = "[/INST] "
471
+ for conversation, target in zip(conversations, targets):
472
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
473
+
474
+ rounds = conversation.split(conv.sep2)
475
+ cur_len = 1
476
+ target[:cur_len] = IGNORE_INDEX
477
+ for i, rou in enumerate(rounds):
478
+ if rou == "":
479
+ break
480
+
481
+ parts = rou.split(sep)
482
+ if len(parts) != 2:
483
+ break
484
+ parts[0] += sep
485
+
486
+ if has_image:
487
+ round_len = len(tokenizer_image_token(rou, tokenizer))
488
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
489
+ else:
490
+ round_len = len(tokenizer(rou).input_ids)
491
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
492
+
493
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
494
+
495
+ cur_len += round_len
496
+ target[cur_len:] = IGNORE_INDEX
497
+
498
+ if cur_len < tokenizer.model_max_length:
499
+ if cur_len != total_len:
500
+ target[:] = IGNORE_INDEX
501
+ print(
502
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
503
+ f" (ignored)"
504
+ )
505
+
506
+ return dict(
507
+ input_ids=input_ids,
508
+ labels=targets,
509
+ )
510
+
511
+ def preprocess_llama_3(
512
+ sources,
513
+ tokenizer: transformers.PreTrainedTokenizer,
514
+ has_image: bool = False
515
+ ) -> Dict:
516
+ conv = copy.deepcopy(conversation_lib.conv_llava_llama_3)
517
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
518
+
519
+ # Apply prompt templates
520
+ conversations = []
521
+ for i, source in enumerate(sources):
522
+ if roles[source[0]["from"]] != conv.roles[0]:
523
+ # Skip the first one if it is not from human
524
+ source = source[1:]
525
+
526
+ conv.messages = []
527
+ for j, sentence in enumerate(source):
528
+ role = roles[sentence["from"]]
529
+ assert role == conv.roles[j % 2], f"{i}"
530
+ conv.append_message(role, sentence["value"])
531
+ conversations.append(conv.get_prompt())
532
+
533
+ # Tokenize conversations
534
+
535
+ if has_image:
536
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
537
+ else:
538
+ input_ids = tokenizer(
539
+ conversations,
540
+ return_tensors="pt",
541
+ padding="longest",
542
+ max_length=tokenizer.model_max_length,
543
+ truncation=True,
544
+ ).input_ids
545
+
546
+ targets = input_ids.clone()
547
+
548
+ offset = 0 if input_ids[0][0] != tokenizer.bos_token_id else 1
549
+ assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_3
550
+ # Mask targets
551
+ # sep = conv.sep + conv.roles[1] + ":"
552
+ sep = '<|start_header_id|>assistant<|end_header_id|>\n\n'
553
+ sep2 = '<|start_header_id|>user<|end_header_id|>\n\n'
554
+ # Llama3 tokenizer has the token for whitespace
555
+ # Typically, the token after whitespace will be naturally encoded as one token with whitespace
556
+ # some special cases like ": 3" will be encoded as :, whitespace, 3; 3 tokens. Only in this case, the loss on whitespace will be calculated
557
+
558
+ for conversation, target in zip(conversations, targets):
559
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
560
+
561
+ rounds = conversation.split(sep2)
562
+ cur_len = 1
563
+ target[:cur_len] = IGNORE_INDEX
564
+
565
+ # process system prompt
566
+ try:
567
+ rounds[1] = rounds[0] + sep2 + rounds[1]
568
+ del rounds[0]
569
+ except:
570
+ print('no user found')
571
+ raise ValueError
572
+
573
+ # add user
574
+ for i, rou in enumerate(rounds):
575
+ if i != 0:
576
+ rounds[i] = sep2 + rou
577
+
578
+ for i, rou in enumerate(rounds):
579
+ if rou == "":
580
+ break
581
+
582
+ parts = rou.split(sep)
583
+ if len(parts) != 2:
584
+ break
585
+ # parts[0] += sep
586
+
587
+ # supervise assistant: from pp's report
588
+ parts[1] = sep + parts[1]
589
+ # parts[0] = parts[0] + sep
590
+
591
+ if has_image:
592
+ round_len = len(tokenizer_image_token(rou, tokenizer)) - offset
593
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
594
+ else:
595
+ round_len = len(tokenizer(rou).input_ids) - offset
596
+ instruction_len = len(tokenizer(parts[0]).input_ids)
597
+
598
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
599
+
600
+ cur_len += round_len + (1 - offset) #starting from index 0, then cur_len will not cover eos token
601
+
602
+ if cur_len < tokenizer.model_max_length:
603
+ if cur_len != total_len:
604
+ target[:] = IGNORE_INDEX
605
+ print(
606
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
607
+ f" (ignored)"
608
+ )
609
+
610
+ if input_ids[0][0] != tokenizer.bos_token_id:
611
+ input_ids = [torch.cat([torch.LongTensor([tokenizer.bos_token_id]), i]) for i in input_ids]
612
+ targets = [torch.cat([torch.LongTensor([IGNORE_INDEX]), i]) for i in targets]
613
+
614
+ return dict(
615
+ input_ids=input_ids,
616
+ labels=targets,
617
+ )
618
+
619
+
620
+ def preprocess_v1(
621
+ sources,
622
+ tokenizer: transformers.PreTrainedTokenizer,
623
+ has_image: bool = False
624
+ ) -> Dict:
625
+ conv = conversation_lib.default_conversation.copy()
626
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
627
+
628
+ # Apply prompt templates
629
+ conversations = []
630
+ for i, source in enumerate(sources):
631
+ if roles[source[0]["from"]] != conv.roles[0]:
632
+ # Skip the first one if it is not from human
633
+ source = source[1:]
634
+
635
+ conv.messages = []
636
+ for j, sentence in enumerate(source):
637
+ role = roles[sentence["from"]]
638
+ assert role == conv.roles[j % 2], f"{i}"
639
+ conv.append_message(role, sentence["value"])
640
+ conversations.append(conv.get_prompt())
641
+
642
+ # Tokenize conversations
643
+
644
+ if has_image:
645
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
646
+ else:
647
+ input_ids = tokenizer(
648
+ conversations,
649
+ return_tensors="pt",
650
+ padding="longest",
651
+ max_length=tokenizer.model_max_length,
652
+ truncation=True,
653
+ ).input_ids
654
+
655
+ targets = input_ids.clone()
656
+
657
+ if conv.sep_style == conversation_lib.SeparatorStyle.TWO:
658
+
659
+ # Mask targets
660
+ sep = conv.sep + conv.roles[1] + ": "
661
+ for conversation, target in zip(conversations, targets):
662
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
663
+
664
+ rounds = conversation.split(conv.sep2)
665
+ cur_len = 1
666
+ target[:cur_len] = IGNORE_INDEX
667
+ for i, rou in enumerate(rounds):
668
+ if rou == "":
669
+ break
670
+
671
+ parts = rou.split(sep)
672
+ if len(parts) != 2:
673
+ break
674
+ parts[0] += sep
675
+
676
+ if has_image:
677
+ round_len = len(tokenizer_image_token(rou, tokenizer))
678
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
679
+ else:
680
+ round_len = len(tokenizer(rou).input_ids)
681
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
682
+
683
+ if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
684
+ round_len -= 1
685
+ instruction_len -= 1
686
+
687
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
688
+
689
+ cur_len += round_len
690
+ target[cur_len:] = IGNORE_INDEX
691
+
692
+ if cur_len < tokenizer.model_max_length:
693
+ if cur_len != total_len:
694
+ target[:] = IGNORE_INDEX
695
+ print(
696
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
697
+ f" (ignored)"
698
+ )
699
+
700
+ elif conv.sep_style == conversation_lib.SeparatorStyle.QWEN2:
701
+ # Mask targets
702
+ sep = '<|im_start|>assistant\n'
703
+ for conversation, target in zip(conversations, targets):
704
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
705
+
706
+ raw_rounds = conversation.split('<|im_end|>\n')
707
+ cur_len = 0
708
+ rounds = []
709
+ now_str = ''
710
+ for rou in raw_rounds:
711
+ if len(rou) > 0:
712
+ rou = rou + '<|im_end|>\n'
713
+ if rou.startswith('<|endoftext|>'):
714
+ rounds[-1] = rounds[-1] + '<|endoftext|>'
715
+ rou = rou.replace('<|endoftext|>', '')
716
+ if len(rou.strip()) == 0:
717
+ continue
718
+ if '<|im_start|>assistant\n' in rou:
719
+ now_str += rou
720
+ rounds.append(now_str)
721
+ now_str = ''
722
+ else:
723
+ now_str += rou
724
+
725
+ for i, rou in enumerate(rounds):
726
+ if rou == "":
727
+ break
728
+
729
+ parts = rou.split(sep)
730
+ if len(parts) != 2:
731
+ break
732
+ parts[0] += sep
733
+
734
+ if has_image:
735
+ round_len = len(tokenizer_image_token(rou, tokenizer))
736
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
737
+ else:
738
+ round_len = len(tokenizer(rou).input_ids)
739
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
740
+
741
+ try:
742
+ is_legacy = tokenizer.legacy
743
+ except:
744
+ is_legacy = True
745
+
746
+ if i != 0 and not is_legacy and IS_TOKENIZER_GREATER_THAN_0_14:
747
+ round_len -= 1
748
+ instruction_len -= 1
749
+
750
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
751
+
752
+ cur_len += round_len
753
+ target[cur_len:] = IGNORE_INDEX
754
+
755
+ if cur_len < tokenizer.model_max_length:
756
+ if cur_len != total_len:
757
+ target[:] = IGNORE_INDEX
758
+ print(
759
+ f"WARNING: tokenization mismatch for QWEN2: {cur_len} vs. {total_len}."
760
+ f" (ignored)"
761
+ )
762
+
763
+ return dict(
764
+ input_ids=input_ids,
765
+ labels=targets,
766
+ )
767
+
768
+ def preprocess_imgsp_v1(
769
+ sources,
770
+ tokenizer: transformers.PreTrainedTokenizer,
771
+ has_image: bool = False,
772
+ img_token: str = '<image>',
773
+ refine_prompt: bool = False,
774
+ ) -> Dict:
775
+ conv = conversation_lib.default_conversation.copy()
776
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
777
+
778
+ # Apply prompt templates
779
+ conversations = []
780
+ guided_prompt = []
781
+ for i, source in enumerate(sources):
782
+ if roles[source[0]["from"]] != conv.roles[0]:
783
+ # Skip the first one if it is not from human
784
+ source = source[1:]
785
+
786
+ conv.messages = []
787
+ img_in_text = False
788
+ for j, sentence in enumerate(source):
789
+ role = roles[sentence["from"]]
790
+ assert role == conv.roles[j % 2], f"{i}"
791
+
792
+ # add guided prompt
793
+ if role==conv.roles[0]:
794
+ guided_sent = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, '').replace('\n', '')
795
+ if refine_prompt:
796
+ # only keep the useful part of the prompt
797
+ if '\n' in guided_sent:
798
+ for _sent in guided_sent.split('\n'):
799
+ if '?' in _sent:
800
+ guided_sent = _sent
801
+ break
802
+ guided_prompt.append(guided_sent)
803
+ # check if image token in text
804
+ if img_token in sentence["value"]:
805
+ img_in_text = True
806
+ # add image token to all sentence if multimoal input
807
+ if role==conv.roles[0] and img_in_text and img_token not in sentence["value"]:
808
+ # randomly add image token to the beginning or end of the sentence
809
+ if random.randint(0,1)==0:
810
+ img_conv = img_token + '\n' + sentence["value"]
811
+ else:
812
+ img_conv = sentence["value"] + '\n' + img_token
813
+
814
+ conv.append_message(role, img_conv)
815
+ else:
816
+ conv.append_message(role, sentence["value"])
817
+ conversations.append(conv.get_prompt())
818
+
819
+ # Tokenize conversations
820
+ if has_image:
821
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
822
+ else:
823
+ input_ids = tokenizer(
824
+ conversations,
825
+ return_tensors="pt",
826
+ padding="longest",
827
+ max_length=tokenizer.model_max_length,
828
+ truncation=True,
829
+ ).input_ids
830
+
831
+ targets = input_ids.clone()
832
+
833
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
834
+
835
+ # Mask targets
836
+ sep = conv.sep + conv.roles[1] + ": "
837
+ for conversation, target in zip(conversations, targets):
838
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
839
+
840
+ rounds = conversation.split(conv.sep2)
841
+ cur_len = 1
842
+ target[:cur_len] = IGNORE_INDEX
843
+ for i, rou in enumerate(rounds):
844
+ if rou == "":
845
+ break
846
+
847
+ parts = rou.split(sep)
848
+ if len(parts) != 2:
849
+ break
850
+ parts[0] += sep
851
+
852
+ if has_image:
853
+ round_len = len(tokenizer_image_token(rou, tokenizer))
854
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
855
+ else:
856
+ round_len = len(tokenizer(rou).input_ids)
857
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
858
+
859
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
860
+
861
+ cur_len += round_len
862
+ target[cur_len:] = IGNORE_INDEX
863
+
864
+ if cur_len < tokenizer.model_max_length:
865
+ if cur_len != total_len:
866
+ target[:] = IGNORE_INDEX
867
+ print(
868
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
869
+ f" (ignored)"
870
+ )
871
+ return dict(
872
+ input_ids=input_ids,
873
+ labels=targets,
874
+ prompt=guided_prompt,
875
+ )
876
+
877
+
878
+ def preprocess_mpt(
879
+ sources,
880
+ tokenizer: transformers.PreTrainedTokenizer,
881
+ has_image: bool = False
882
+ ) -> Dict:
883
+ conv = conversation_lib.default_conversation.copy()
884
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
885
+
886
+ # Apply prompt templates
887
+ conversations = []
888
+ for i, source in enumerate(sources):
889
+ if roles[source[0]["from"]] != conv.roles[0]:
890
+ # Skip the first one if it is not from human
891
+ source = source[1:]
892
+
893
+ conv.messages = []
894
+ for j, sentence in enumerate(source):
895
+ role = roles[sentence["from"]]
896
+ assert role == conv.roles[j % 2], f"{i}"
897
+ conv.append_message(role, sentence["value"])
898
+ conversations.append(conv.get_prompt())
899
+
900
+ # Tokenize conversations
901
+
902
+ if has_image:
903
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
904
+ else:
905
+ input_ids = tokenizer(
906
+ conversations,
907
+ return_tensors="pt",
908
+ padding="longest",
909
+ max_length=tokenizer.model_max_length,
910
+ truncation=True,
911
+ ).input_ids
912
+
913
+ targets = input_ids.clone()
914
+ assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
915
+
916
+ # Mask targets
917
+ sep = conv.sep + conv.roles[1]
918
+ for conversation, target in zip(conversations, targets):
919
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
920
+
921
+ rounds = conversation.split(conv.sep)
922
+ re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
923
+ for conv_idx in range(3, len(rounds), 2):
924
+ re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt
925
+ cur_len = 1
926
+ target[:cur_len] = IGNORE_INDEX
927
+ for i, rou in enumerate(re_rounds):
928
+ if rou == "":
929
+ break
930
+
931
+ parts = rou.split(sep)
932
+ if len(parts) != 2:
933
+ break
934
+ parts[0] += sep
935
+
936
+ if has_image:
937
+ round_len = len(tokenizer_image_token(rou, tokenizer))
938
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
939
+ else:
940
+ round_len = len(tokenizer(rou).input_ids)
941
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 1
942
+
943
+ if i != 0 and getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
944
+ round_len += 1
945
+ instruction_len += 1
946
+
947
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
948
+
949
+ cur_len += round_len
950
+ target[cur_len:] = IGNORE_INDEX
951
+
952
+ if cur_len < tokenizer.model_max_length:
953
+ if cur_len != total_len:
954
+ target[:] = IGNORE_INDEX
955
+ print(
956
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
957
+ f"(#turns={len(re_rounds)} ignored)"
958
+ )
959
+
960
+ return dict(
961
+ input_ids=input_ids,
962
+ labels=targets,
963
+ )
964
+
965
+
966
+ def preprocess_plain(
967
+ sources: Sequence[str],
968
+ tokenizer: transformers.PreTrainedTokenizer,
969
+ ) -> Dict:
970
+ # add end signal and concatenate together
971
+ conversations = []
972
+ for source in sources:
973
+ assert len(source) == 2
974
+ assert DEFAULT_IMAGE_TOKEN in source[0]['value']
975
+ source[0]['value'] = DEFAULT_IMAGE_TOKEN
976
+ conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
977
+ conversations.append(conversation)
978
+ # tokenize conversations
979
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
980
+ targets = copy.deepcopy(input_ids)
981
+ for target, source in zip(targets, sources):
982
+ tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
983
+ target[:tokenized_len] = IGNORE_INDEX
984
+
985
+ return dict(input_ids=input_ids, labels=targets)
986
+
987
+
988
+ def preprocess_plain_guided(
989
+ sources: Sequence[str],
990
+ tokenizer: transformers.PreTrainedTokenizer,
991
+ prompt: str = None,
992
+ ) -> Dict:
993
+ # add end signal and concatenate together
994
+ guided_prompt = []
995
+ conversations = []
996
+ for source in sources:
997
+ assert len(source) == 2
998
+ assert DEFAULT_IMAGE_TOKEN in source[0]['value']
999
+ guided_prompt.append(source[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '').replace('\n', ''))
1000
+ source[0]['value'] = DEFAULT_IMAGE_TOKEN
1001
+ conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
1002
+ conversations.append(conversation)
1003
+ # tokenize conversations
1004
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
1005
+ targets = copy.deepcopy(input_ids)
1006
+ for target, source in zip(targets, sources):
1007
+ tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
1008
+ target[:tokenized_len] = IGNORE_INDEX
1009
+
1010
+
1011
+ def preprocess(
1012
+ sources: Sequence[str],
1013
+ tokenizer: transformers.PreTrainedTokenizer,
1014
+ has_image: bool = False,
1015
+ ) -> Dict:
1016
+ """
1017
+ Given a list of sources, each is a conversation list. This transform:
1018
+ 1. Add signal '### ' at the beginning each sentence, with end signal '\n';
1019
+ 2. Concatenate conversations together;
1020
+ 3. Tokenize the concatenated conversation;
1021
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
1022
+ """
1023
+ if conversation_lib.default_conversation.version.startswith("plain_guided"):
1024
+ return preprocess_plain_guided(sources, tokenizer)
1025
+ elif conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
1026
+ return preprocess_plain(sources, tokenizer)
1027
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
1028
+ return preprocess_llama_2(sources, tokenizer, has_image=has_image)
1029
+ if conversation_lib.default_conversation.version.startswith("v1"):
1030
+ return preprocess_v1(sources, tokenizer, has_image=has_image)
1031
+ if conversation_lib.default_conversation.version.startswith("llama_v3"): # for llama 3 tokenizer
1032
+ return preprocess_llama_3(sources, tokenizer, has_image=has_image)
1033
+ if conversation_lib.default_conversation.version == "qwen":
1034
+ return preprocess_qwen(sources, tokenizer, has_image=has_image)
1035
+ elif conversation_lib.default_conversation.version.startswith("imgsp"):
1036
+ return preprocess_imgsp_v1(sources, tokenizer, has_image=has_image)
1037
+ if conversation_lib.default_conversation.version == "mpt":
1038
+ return preprocess_mpt(sources, tokenizer, has_image=has_image)
1039
+ # add end signal and concatenate together
1040
+ conversations = []
1041
+ for source in sources:
1042
+ header = f"{conversation_lib.default_conversation.system}\n\n"
1043
+ conversation = _add_speaker_and_signal(header, source)
1044
+ conversations.append(conversation)
1045
+ # tokenize conversations
1046
+ def get_tokenize_len(prompts):
1047
+ return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
1048
+
1049
+ if has_image:
1050
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
1051
+ else:
1052
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
1053
+ input_ids = conversations_tokenized["input_ids"]
1054
+
1055
+ targets = copy.deepcopy(input_ids)
1056
+ for target, source in zip(targets, sources):
1057
+ if has_image:
1058
+ tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
1059
+ else:
1060
+ tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
1061
+ speakers = [sentence["from"] for sentence in source]
1062
+ _mask_targets(target, tokenized_lens, speakers)
1063
+
1064
+ return dict(input_ids=input_ids, labels=targets)
1065
+
1066
+
1067
+ def read_image_patch(patch_info):
1068
+ if 'img_path' in patch_info.keys():
1069
+ image = Image.open(patch_info['img_path']).convert('RGB')
1070
+ else:
1071
+ image_file_name = patch_info['patch']
1072
+ start_bytes = int(patch_info['start_num'])
1073
+ file_size = int(patch_info['size'])
1074
+
1075
+ with open(image_file_name, 'rb') as f:
1076
+ f.seek(start_bytes)
1077
+ if 'image_encoding' in patch_info.keys() and patch_info['image_encoding'] == 'base64':
1078
+ image = Image.open(io.BytesIO(base64.b64decode(f.read(file_size).decode()))).convert("RGB")
1079
+ else:
1080
+ image = Image.open(io.BytesIO(f.read(file_size))).convert("RGB")
1081
+ return image
1082
+
1083
+
1084
+ def read_video_patch(patch_info):
1085
+ if 'img_path' in patch_info.keys():
1086
+ image = Image.open(patch_info['img_path']).convert('RGB')
1087
+ else:
1088
+ image_file_name = patch_info['patch']
1089
+ start_bytes = int(patch_info['start_num'])
1090
+ file_size = patch_info['size'] # list of int
1091
+ total_file_size = 0
1092
+ images_all = []
1093
+ with open(image_file_name, 'rb') as f:
1094
+ for idx in range(len(file_size)):
1095
+ f.seek(start_bytes + total_file_size)
1096
+ if 'image_encoding' in patch_info.keys() and patch_info['image_encoding'] == 'base64':
1097
+ image = Image.open(io.BytesIO(base64.b64decode(f.read(int(file_size[idx])).decode()))).convert("RGB")
1098
+ else:
1099
+ if 'sharegpt4o' in image_file_name or 'ShareGPT4Video/new_patch' in image_file_name or 'cinepile' in image_file_name or 'nextqa' in image_file_name or 'perceptiontest' in image_file_name:
1100
+ byte_str = io.BytesIO(f.read(int(file_size[idx])))
1101
+ array = np.frombuffer(byte_str.getvalue(), dtype=np.uint8)
1102
+ image = cv2.imdecode(array, cv2.IMREAD_COLOR)
1103
+ image = Image.fromarray(image)
1104
+ else:
1105
+ image = Image.open(io.BytesIO(f.read(int(file_size[idx])))).convert("RGB")
1106
+ images_all.append(image)
1107
+ total_file_size += int(file_size[idx])
1108
+ return images_all
1109
+
1110
+ class LazySupervisedDataset(Dataset):
1111
+ """Dataset for supervised fine-tuning."""
1112
+
1113
+ def __init__(self, data_path: str,
1114
+ tokenizer: transformers.PreTrainedTokenizer,
1115
+ data_args: DataArguments):
1116
+ super(LazySupervisedDataset, self).__init__()
1117
+ list_data_dict = json.load(open(data_path, "r"))
1118
+
1119
+ rank0_print("Formatting inputs...Skip in lazy mode")
1120
+ self.tokenizer = tokenizer
1121
+ self.list_data_dict = list_data_dict
1122
+ self.data_args = data_args
1123
+
1124
+ # if PRETRAIN:
1125
+ self.mapping_dict = json.load(open('/apdcephfs_jn/share_302244400/peterrao/nj3/data/llava/videodata/MovieNet/movienet_mapping.json', "r"))
1126
+ print('loadding mapping dict')
1127
+
1128
+ def __len__(self):
1129
+ return len(self.list_data_dict)
1130
+
1131
+ @property
1132
+ def lengths(self):
1133
+ length_list = []
1134
+ for sample in self.list_data_dict:
1135
+ img_tokens = 128 if 'image' in sample else 0
1136
+ length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
1137
+ return length_list
1138
+
1139
+ @property
1140
+ def modality_lengths(self):
1141
+ length_list = []
1142
+ for sample in self.list_data_dict:
1143
+ try:
1144
+ cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
1145
+ except:
1146
+ cur_len = 1
1147
+ cur_len = cur_len if ('image' in sample) or ('video' in sample) or ('video_long' in sample) else -cur_len
1148
+ length_list.append(cur_len)
1149
+ return length_list
1150
+
1151
+ def process_image(self, image_file):
1152
+ if type(image_file) is str:
1153
+ image = Image.open(image_file).convert('RGB')
1154
+ elif type(image_file) is dict:
1155
+ image = read_image_patch(image_file)
1156
+ else:
1157
+ raise ValueError(f"Unknown image file type: {type(image_file)}, {image_file}")
1158
+ image_size = image.size
1159
+ image, image_padded = process_anyres_highres_image_genli(image, self.data_args.image_processor)
1160
+
1161
+ return (image, image_padded), image_size, "image"
1162
+
1163
+ def process_video(self, video_file):
1164
+ video = read_video_patch(video_file)
1165
+ video_processed = []
1166
+
1167
+ cur_frames_upbound = self.data_args.frames_upbound
1168
+
1169
+ if cur_frames_upbound > 0:
1170
+ if len(video) > cur_frames_upbound:
1171
+ uniform_sampled_frames = np.linspace(0, len(video) - 1, cur_frames_upbound, dtype=int)
1172
+ frame_idx = uniform_sampled_frames.tolist()
1173
+ else:
1174
+ frame_idx = None
1175
+
1176
+ for idx, frame in enumerate(video):
1177
+ frame = process_anyres_video_genli(frame, self.data_args.image_processor)
1178
+ if frame_idx is not None and idx in frame_idx:
1179
+ video_processed.append(frame.unsqueeze(0))
1180
+ elif frame_idx is None:
1181
+ video_processed.append(frame.unsqueeze(0))
1182
+
1183
+ if frame_idx is None:
1184
+ frame_idx = np.arange(0, len(video_processed), dtype=int).tolist()
1185
+
1186
+ video_processed = torch.cat(video_processed, dim=0)
1187
+
1188
+ video_processed = (video_processed, video_processed)
1189
+ return (video_processed, (384, 384), "video"), frame_idx
1190
+
1191
+ def process_video_pretrain(self, video_file, target_idx):
1192
+ video = read_video_patch(video_file)
1193
+
1194
+ cur_frames_upbound = random.randint(self.data_args.frames_upbound * 3, self.data_args.frames_upbound * 4)
1195
+ video_processed = []
1196
+ if cur_frames_upbound > 0:
1197
+ if len(video) > cur_frames_upbound:
1198
+ uniform_sampled_frames = np.linspace(0, len(video) - 1, cur_frames_upbound, dtype=int)
1199
+ frame_idx = uniform_sampled_frames.tolist()
1200
+
1201
+ # process longer case
1202
+ target_idx_new = []
1203
+ target_frame = []
1204
+ if len(target_idx) == 1:
1205
+ target_idx_new.append(np.random.randint(0, len(uniform_sampled_frames)))
1206
+ target_frame.append(video[target_idx[0]])
1207
+ elif len(target_idx) == 2:
1208
+ num1 = np.random.randint(0, len(uniform_sampled_frames) // 2)
1209
+ num2 = np.random.randint(num1 + 1, len(uniform_sampled_frames))
1210
+ target_idx_new.append(num1)
1211
+ target_idx_new.append(num2)
1212
+ target_frame.append(video[target_idx[0]])
1213
+ target_frame.append(video[target_idx[1]])
1214
+
1215
+ else:
1216
+ frame_idx = None
1217
+ target_idx_new = target_idx
1218
+ target_frame = None
1219
+
1220
+ for idx, frame in enumerate(video):
1221
+ frame = process_anyres_video_genli_long(frame, self.data_args.image_processor)
1222
+
1223
+ if frame_idx is not None and idx in frame_idx:
1224
+ video_processed.append(frame.unsqueeze(0))
1225
+ elif frame_idx is None:
1226
+ video_processed.append(frame.unsqueeze(0))
1227
+
1228
+ # process longer case
1229
+ if target_frame is not None:
1230
+ for idx in target_idx_new:
1231
+ frame = target_frame.pop(0)
1232
+ frame = process_anyres_video_genli_long(frame, self.data_args.image_processor)
1233
+ video_processed[idx] = frame.unsqueeze(0)
1234
+
1235
+ if frame_idx is None:
1236
+ frame_idx = np.arange(0, len(video_processed), dtype=int).tolist()
1237
+
1238
+ video_processed = torch.cat(video_processed, dim=0)
1239
+
1240
+ video_processed = (video_processed, video_processed)
1241
+
1242
+ return (video_processed, (384, 384), "video_long"), target_idx_new
1243
+
1244
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
1245
+ # TODO: define number of retries somewhere else
1246
+ num_base_retries = 3
1247
+ num_final_retries = 300
1248
+ # try the current sample first
1249
+ for attempt_idx in range(num_base_retries):
1250
+ try:
1251
+ sample = self._get_item(i)
1252
+ return sample
1253
+ except Exception as e:
1254
+ # sleep 1s in case it is a cloud disk issue
1255
+ print(f'[try #{attempt_idx}] Failed to fetch sample {i}. Exception:', e)
1256
+ time.sleep(1)
1257
+
1258
+ # try other samples, in case it is file corruption issue
1259
+ for attempt_idx in range(num_base_retries):
1260
+ try:
1261
+ sample_idx = random.choice(range(len(self)))
1262
+ sample = self._get_item(sample_idx)
1263
+ return sample
1264
+ except Exception as e:
1265
+ # no need to sleep
1266
+ print(f'[try other #{attempt_idx}] Failed to fetch sample {sample_idx}. Exception:', e)
1267
+ pass
1268
+
1269
+ # still fail, most likely to be path issue or cloud disk issue, retry the same sample for longer
1270
+ for attempt_idx in range(num_final_retries):
1271
+ try:
1272
+ sample = self._get_item(i)
1273
+ return sample
1274
+ except Exception as e:
1275
+ # sleep 1s in case it is a cloud disk issue
1276
+ print(f'[final try #{attempt_idx}] Failed to fetch sample {i}. Exception:', e)
1277
+ time.sleep(1)
1278
+
1279
+ # Finally raise exception on failing.
1280
+ assert False, "Failed to fetch sample."
1281
+
1282
+ def _get_item(self, i) -> Dict[str, torch.Tensor]:
1283
+ sources = self.list_data_dict[i]
1284
+ if isinstance(i, int):
1285
+ sources = [sources]
1286
+ assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
1287
+
1288
+ if 'image' in sources[0]:
1289
+ image_file = self.list_data_dict[i]['image']
1290
+ if type(image_file) is list:
1291
+ image = [self.process_image(f) for f in image_file]
1292
+ else:
1293
+ image = [self.process_image(image_file)]
1294
+ num_frames = 0
1295
+ sources = preprocess_multimodal(
1296
+ copy.deepcopy([e["conversations"] for e in sources]),
1297
+ self.data_args
1298
+ )
1299
+ elif 'video' in sources[0]:
1300
+ video_file = self.list_data_dict[i]['video']
1301
+ video, _ = self.process_video(video_file)
1302
+ video = [video]
1303
+ num_frames = len(video[0][0])
1304
+ sources = preprocess_multimodal(
1305
+ copy.deepcopy([e["conversations"] for e in sources]),
1306
+ self.data_args)
1307
+
1308
+ elif 'video_long' in sources[0]:
1309
+ video_file = self.mapping_dict[self.list_data_dict[i]['video_long']]['video']
1310
+ video, target_idx = self.process_video_pretrain(video_file, self.list_data_dict[i]['idx'])
1311
+ video = [video]
1312
+ num_frames = len(video[0][0][0])
1313
+ question = sources[0]['question']
1314
+ answer = sources[0]['answer']
1315
+ if sources[0]['type'] == 'diff':
1316
+ question = question.replace('<idx1>', str(target_idx[0]))
1317
+ question = question.replace('<idx2>', str(target_idx[1]))
1318
+ elif sources[0]['type'] == 'caption':
1319
+ question = question.replace('<idx>', str(target_idx[0]))
1320
+ else:
1321
+ raise NotImplementedError
1322
+
1323
+ sources[0]['conversations'] = [{'from': 'human', 'value': f'<image>\nThis is a extremely long video with a total of {num_frames} frames sampled from the video. Please carefully read every given frame in this video, identifying the detailed contents in every frame. '+ question},
1324
+ {'from': 'gpt', 'value': answer}]
1325
+ sources = preprocess_multimodal(
1326
+ copy.deepcopy([e["conversations"] for e in sources]),
1327
+ self.data_args)
1328
+ else:
1329
+ sources = copy.deepcopy([e["conversations"] for e in sources])
1330
+
1331
+ has_image = ('image' in self.list_data_dict[i]) or ('video' in self.list_data_dict[i]) or ('video_long' in self.list_data_dict[i])
1332
+ data_dict = preprocess(
1333
+ sources,
1334
+ self.tokenizer,
1335
+ has_image=has_image)
1336
+
1337
+ if isinstance(i, int):
1338
+ data_dict = dict(input_ids=data_dict["input_ids"][0],
1339
+ labels=data_dict["labels"][0])
1340
+
1341
+ # image exist in the data
1342
+ if 'image' in self.list_data_dict[i]:
1343
+ data_dict['image'] = image
1344
+ elif 'video' in self.list_data_dict[i]:
1345
+ data_dict['image'] = video
1346
+ elif 'video_long' in self.list_data_dict[i]:
1347
+ data_dict['image'] = video
1348
+ elif self.data_args.is_multimodal:
1349
+ # image does not exist in the data, but the model is multimodal
1350
+ crop_size = self.data_args.image_processor.crop_size
1351
+ data_dict['image'] = [
1352
+ (
1353
+ (torch.zeros(1, 3, crop_size['height'], crop_size['width']), torch.zeros(1, 3, crop_size['height'], crop_size['width'])),
1354
+ (crop_size['width'], crop_size['height']),
1355
+ "text"
1356
+ ),
1357
+ ]
1358
+ return data_dict
1359
+
1360
+
1361
+ @dataclass
1362
+ class DataCollatorForSupervisedDataset(object):
1363
+ """Collate examples for supervised fine-tuning."""
1364
+
1365
+ tokenizer: transformers.PreTrainedTokenizer
1366
+
1367
+ def pad_sequence(self, input_ids, batch_first, padding_value):
1368
+ if self.tokenizer.padding_side == "left":
1369
+ input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
1370
+ input_ids = torch.nn.utils.rnn.pad_sequence(
1371
+ input_ids,
1372
+ batch_first=batch_first,
1373
+ padding_value=padding_value)
1374
+ if self.tokenizer.padding_side == "left":
1375
+ input_ids = torch.flip(input_ids, [1])
1376
+ return input_ids
1377
+
1378
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
1379
+ # input_ids, labels = tuple([instance[key] for instance in instances]
1380
+ # for key in ("input_ids", "labels"))
1381
+ input_ids, labels = tuple([instance[key] for instance in instances]
1382
+ for key in ("input_ids", "labels"))
1383
+ input_ids = [_input_ids[:self.tokenizer.model_max_length] for _input_ids in input_ids]
1384
+ labels = [_labels[:self.tokenizer.model_max_length] for _labels in labels]
1385
+ if self.tokenizer.pad_token_id is None:
1386
+ if "qwen" in self.tokenizer.name_or_path.lower():
1387
+ print("Setting pad token to bos token for qwen model.")
1388
+ self.tokenizer.pad_token_id = 151643
1389
+ else:
1390
+ self.tokenizer.pad_token_id = self.tokenizer.eos_token_id # FIXME: this could only be triggered for llama3 model.
1391
+ input_ids = self.pad_sequence(
1392
+ input_ids,
1393
+ batch_first=True,
1394
+ padding_value=self.tokenizer.pad_token_id)
1395
+ labels = self.pad_sequence(labels,
1396
+ batch_first=True,
1397
+ padding_value=IGNORE_INDEX)
1398
+
1399
+ batch = dict(
1400
+ input_ids=input_ids,
1401
+ labels=labels,
1402
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id)
1403
+ )
1404
+
1405
+ if 'image' in instances[0]:
1406
+ images = [instance['image'] for instance in instances]
1407
+ batch['image_sizes'] = [im[1] for im_list in images for im in im_list]
1408
+ batch['modalities'] = [im[2] for im_list in images for im in im_list]
1409
+ images_lowres = [im[0][0] for im_list in images for im in im_list]
1410
+ images_highres = [im[0][1] for im_list in images for im in im_list]
1411
+ batch['images_highres'] = images_highres
1412
+ if all(x is not None and x.shape == images_lowres[0].shape for x in images_lowres):
1413
+ batch['images'] = torch.stack(images_lowres)
1414
+ else:
1415
+ batch['images'] = images_lowres
1416
+ return batch
1417
+
1418
+ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
1419
+ data_args) -> Dict:
1420
+ """Make dataset and collator for supervised fine-tuning."""
1421
+ train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
1422
+ data_path=data_args.data_path,
1423
+ data_args=data_args)
1424
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
1425
+ return dict(train_dataset=train_dataset,
1426
+ eval_dataset=None,
1427
+ data_collator=data_collator)
1428
+
1429
+
1430
+ def train():
1431
+ global local_rank
1432
+
1433
+ parser = transformers.HfArgumentParser(
1434
+ (ModelArguments, DataArguments, TrainingArguments))
1435
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
1436
+ local_rank = training_args.local_rank
1437
+ compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
1438
+
1439
+ bnb_model_from_pretrained_args = {}
1440
+ if training_args.bits in [4, 8]:
1441
+ from transformers import BitsAndBytesConfig
1442
+ bnb_model_from_pretrained_args.update(dict(
1443
+ device_map={"": training_args.device},
1444
+ load_in_4bit=training_args.bits == 4,
1445
+ load_in_8bit=training_args.bits == 8,
1446
+ quantization_config=BitsAndBytesConfig(
1447
+ load_in_4bit=training_args.bits == 4,
1448
+ load_in_8bit=training_args.bits == 8,
1449
+ llm_int8_threshold=6.0,
1450
+ llm_int8_has_fp16_weight=False,
1451
+ bnb_4bit_compute_dtype=compute_dtype,
1452
+ bnb_4bit_use_double_quant=training_args.double_quant,
1453
+ bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
1454
+ )
1455
+ ))
1456
+
1457
+ if model_args.vision_tower is not None:
1458
+ print(model_args.vision_tower)
1459
+ if 'qwen' in model_args.model_name_or_path.lower():
1460
+
1461
+ if not model_args.pretrain_mm_mlp_adapter:
1462
+ cfg_pretrained = AutoConfig.from_pretrained(model_args.model_name_or_path)
1463
+ overwrite_config = {}
1464
+ overwrite_config["mm_resampler_type"] = model_args.mm_resampler_type
1465
+
1466
+ print(f"Overwriting config with {overwrite_config}")
1467
+ for k, v in overwrite_config.items():
1468
+ setattr(cfg_pretrained, k, v)
1469
+
1470
+ model = OryxQwenForCausalLM.from_pretrained(
1471
+ model_args.model_name_or_path,
1472
+ config=cfg_pretrained,
1473
+ cache_dir=training_args.cache_dir,
1474
+ attn_implementation="flash_attention_2",
1475
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
1476
+ **bnb_model_from_pretrained_args
1477
+ )
1478
+ else:
1479
+ model = OryxQwenForCausalLM.from_pretrained(
1480
+ model_args.model_name_or_path,
1481
+ cache_dir=training_args.cache_dir,
1482
+ attn_implementation="flash_attention_2",
1483
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
1484
+ **bnb_model_from_pretrained_args
1485
+ )
1486
+
1487
+ else:
1488
+ # finetune from a image trained model
1489
+ # if not model_args.pretrain_mm_mlp_adapter:
1490
+ cfg_pretrained = AutoConfig.from_pretrained(model_args.model_name_or_path)
1491
+ overwrite_config = {}
1492
+ overwrite_config["mm_resampler_type"] = model_args.mm_resampler_type
1493
+
1494
+ print(f"Overwriting config with {overwrite_config}")
1495
+ for k, v in overwrite_config.items():
1496
+ setattr(cfg_pretrained, k, v)
1497
+
1498
+ model = OryxLlamaForCausalLM.from_pretrained(
1499
+ model_args.model_name_or_path,
1500
+ config=cfg_pretrained,
1501
+ cache_dir=training_args.cache_dir,
1502
+ attn_implementation="flash_attention_2",
1503
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
1504
+ **bnb_model_from_pretrained_args
1505
+ )
1506
+
1507
+ else:
1508
+ model = transformers.LlamaForCausalLM.from_pretrained(
1509
+ model_args.model_name_or_path,
1510
+ cache_dir=training_args.cache_dir,
1511
+ attn_implementation="flash_attention_2",
1512
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
1513
+ **bnb_model_from_pretrained_args
1514
+ )
1515
+ model.config.use_cache = False
1516
+
1517
+ if model_args.freeze_backbone:
1518
+ model.model.requires_grad_(False)
1519
+
1520
+ if training_args.bits in [4, 8]:
1521
+ from peft import prepare_model_for_kbit_training
1522
+ model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
1523
+ model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
1524
+
1525
+ if training_args.gradient_checkpointing:
1526
+ if hasattr(model, "enable_input_require_grads"):
1527
+ model.enable_input_require_grads()
1528
+ else:
1529
+ def make_inputs_require_grad(module, input, output):
1530
+ output.requires_grad_(True)
1531
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
1532
+
1533
+ if training_args.lora_enable:
1534
+ from peft import LoraConfig, get_peft_model
1535
+ lora_config = LoraConfig(
1536
+ r=training_args.lora_r,
1537
+ lora_alpha=training_args.lora_alpha,
1538
+ target_modules=find_all_linear_names(model),
1539
+ lora_dropout=training_args.lora_dropout,
1540
+ bias=training_args.lora_bias,
1541
+ task_type="CAUSAL_LM",
1542
+ )
1543
+ if training_args.bits == 16:
1544
+ if training_args.bf16:
1545
+ model.to(torch.bfloat16)
1546
+ if training_args.fp16:
1547
+ model.to(torch.float16)
1548
+ rank0_print("Adding LoRA adapters...")
1549
+ model = get_peft_model(model, lora_config)
1550
+
1551
+ if "qwen" in model_args.model_name_or_path.lower():
1552
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
1553
+ model_args.model_name_or_path,
1554
+ cache_dir=training_args.cache_dir,
1555
+ model_max_length=training_args.model_max_length,
1556
+ padding_side="right")
1557
+ else:
1558
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
1559
+ model_args.model_name_or_path,
1560
+ cache_dir=training_args.cache_dir,
1561
+ model_max_length=training_args.model_max_length,
1562
+ padding_side="right",
1563
+ use_fast=False,
1564
+ )
1565
+
1566
+ if model_args.version == "v0":
1567
+ if tokenizer.pad_token is None:
1568
+ smart_tokenizer_and_embedding_resize(
1569
+ special_tokens_dict=dict(pad_token="[PAD]"),
1570
+ tokenizer=tokenizer,
1571
+ model=model,
1572
+ )
1573
+ elif model_args.version == "v0.5":
1574
+ tokenizer.pad_token = tokenizer.unk_token
1575
+ elif model_args.version == "llava_llama_3":
1576
+ tokenizer.pad_token = "<|reserved_special_token_0|>" # only for llama3
1577
+ conversation_lib.default_conversation = conversation_lib.conv_templates["llava_llama_3"]
1578
+ else:
1579
+ if 'llama-3' in model_args.model_name_or_path.lower():
1580
+ tokenizer.pad_token = "<|reserved_special_token_0|>"
1581
+ else:
1582
+ tokenizer.pad_token = tokenizer.unk_token
1583
+ if model_args.version in conversation_lib.conv_templates:
1584
+ conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
1585
+ else:
1586
+ conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
1587
+
1588
+ if model_args.vision_tower is not None:
1589
+ model.get_model().initialize_vision_modules(
1590
+ model_args=model_args,
1591
+ fsdp=training_args.fsdp
1592
+ )
1593
+
1594
+ vision_tower = model.get_vision_tower()
1595
+ vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
1596
+
1597
+ vision_tower.image_processor.do_resize = training_args.do_resize
1598
+ vision_tower.image_processor.do_center_crop = training_args.do_center_crop
1599
+
1600
+ data_args.image_processor = vision_tower.image_processor
1601
+ data_args.is_multimodal = True
1602
+
1603
+ model.config.tokenizer_padding_side = tokenizer.padding_side
1604
+ model.config.tokenizer_model_max_length = tokenizer.model_max_length
1605
+
1606
+ model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
1607
+ model.config.tune_mm_vision_resampler = training_args.tune_mm_vision_resampler = model_args.tune_mm_vision_resampler
1608
+ if model_args.tune_mm_mlp_adapter or model_args.tune_mm_vision_resampler:
1609
+ model.requires_grad_(False)
1610
+ if model_args.tune_mm_mlp_adapter:
1611
+ for p in model.get_model().mm_projector.parameters():
1612
+ p.requires_grad = True
1613
+ if model_args.tune_mm_vision_resampler:
1614
+ for p in model.get_model().vision_resampler.parameters():
1615
+ p.requires_grad = True
1616
+
1617
+ model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
1618
+ if training_args.freeze_mm_mlp_adapter:
1619
+ for p in model.get_model().mm_projector.parameters():
1620
+ p.requires_grad = False
1621
+
1622
+ model.config.freeze_mm_vision_resampler = training_args.freeze_mm_vision_resampler
1623
+ if training_args.freeze_mm_vision_resampler:
1624
+ for p in model.get_model().vision_resampler.parameters():
1625
+ p.requires_grad = False
1626
+
1627
+ model.config.unfreeze_mm_vision_tower = model_args.unfreeze_mm_vision_tower
1628
+ if model_args.unfreeze_mm_vision_tower:
1629
+ vision_tower.requires_grad_(True)
1630
+
1631
+ if training_args.bits in [4, 8]:
1632
+ model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
1633
+
1634
+ model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
1635
+ model.config.mm_projector_lr = training_args.mm_projector_lr
1636
+ model.config.mm_vision_tower_lr = training_args.mm_vision_tower_lr
1637
+ training_args.use_im_start_end = model_args.mm_use_im_start_end
1638
+ model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
1639
+ model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
1640
+
1641
+ if training_args.bits in [4, 8]:
1642
+ from peft.tuners.lora import LoraLayer
1643
+ for name, module in model.named_modules():
1644
+ if isinstance(module, LoraLayer):
1645
+ if training_args.bf16:
1646
+ module = module.to(torch.bfloat16)
1647
+ if 'norm' in name:
1648
+ module = module.to(torch.float32)
1649
+ if 'lm_head' in name or 'embed_tokens' in name:
1650
+ if hasattr(module, 'weight'):
1651
+ if training_args.bf16 and module.weight.dtype == torch.float32:
1652
+ module = module.to(torch.bfloat16)
1653
+
1654
+ data_module = make_supervised_data_module(tokenizer=tokenizer,
1655
+ data_args=data_args)
1656
+ trainer = OryxTrainer(model=model,
1657
+ tokenizer=tokenizer,
1658
+ args=training_args,
1659
+ **data_module)
1660
+
1661
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
1662
+ trainer.train(resume_from_checkpoint=True)
1663
+ else:
1664
+ trainer.train()
1665
+ trainer.save_state()
1666
+
1667
+ model.config.use_cache = True
1668
+
1669
+ if training_args.lora_enable:
1670
+ state_dict = get_peft_state_maybe_zero_3(
1671
+ model.named_parameters(), training_args.lora_bias
1672
+ )
1673
+ non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
1674
+ model.named_parameters()
1675
+ )
1676
+ if training_args.local_rank == 0 or training_args.local_rank == -1:
1677
+ model.config.save_pretrained(training_args.output_dir)
1678
+ model.save_pretrained(training_args.output_dir, state_dict=state_dict)
1679
+ torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
1680
+ else:
1681
+ safe_save_model_for_hf_trainer(trainer=trainer,
1682
+ output_dir=training_args.output_dir)
1683
+
1684
+
1685
+ if __name__ == "__main__":
1686
+ train()
oryx/train/train_mem.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+
2
+ from oryx.train.train import train
3
+
4
+ if __name__ == "__main__":
5
+ train()
oryx/utils.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import logging
3
+ import logging.handlers
4
+ import os
5
+ import sys
6
+
7
+ import requests
8
+
9
+ from oryx.constants import LOGDIR
10
+
11
+ server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12
+ moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
13
+
14
+ handler = None
15
+
16
+
17
+
18
+ def rank0_print(*args):
19
+ if dist.is_initialized():
20
+ if dist.get_rank() == 0:
21
+ print(f"Rank {dist.get_rank()}: ", *args)
22
+ else:
23
+ print(*args)
24
+
25
+ def build_logger(logger_name, logger_filename):
26
+ global handler
27
+
28
+ formatter = logging.Formatter(
29
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
30
+ datefmt="%Y-%m-%d %H:%M:%S",
31
+ )
32
+
33
+ # Set the format of root handlers
34
+ if not logging.getLogger().handlers:
35
+ logging.basicConfig(level=logging.INFO)
36
+ logging.getLogger().handlers[0].setFormatter(formatter)
37
+
38
+ # Redirect stdout and stderr to loggers
39
+ stdout_logger = logging.getLogger("stdout")
40
+ stdout_logger.setLevel(logging.INFO)
41
+ sl = StreamToLogger(stdout_logger, logging.INFO)
42
+ sys.stdout = sl
43
+
44
+ stderr_logger = logging.getLogger("stderr")
45
+ stderr_logger.setLevel(logging.ERROR)
46
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
47
+ sys.stderr = sl
48
+
49
+ # Get logger
50
+ logger = logging.getLogger(logger_name)
51
+ logger.setLevel(logging.INFO)
52
+
53
+ # Add a file handler for all loggers
54
+ if handler is None:
55
+ os.makedirs(LOGDIR, exist_ok=True)
56
+ filename = os.path.join(LOGDIR, logger_filename)
57
+ handler = logging.handlers.TimedRotatingFileHandler(
58
+ filename, when='D', utc=True)
59
+ handler.setFormatter(formatter)
60
+
61
+ for name, item in logging.root.manager.loggerDict.items():
62
+ if isinstance(item, logging.Logger):
63
+ item.addHandler(handler)
64
+
65
+ return logger
66
+
67
+
68
+ class StreamToLogger(object):
69
+ """
70
+ Fake file-like stream object that redirects writes to a logger instance.
71
+ """
72
+ def __init__(self, logger, log_level=logging.INFO):
73
+ self.terminal = sys.stdout
74
+ self.logger = logger
75
+ self.log_level = log_level
76
+ self.linebuf = ''
77
+
78
+ def __getattr__(self, attr):
79
+ return getattr(self.terminal, attr)
80
+
81
+ def write(self, buf):
82
+ temp_linebuf = self.linebuf + buf
83
+ self.linebuf = ''
84
+ for line in temp_linebuf.splitlines(True):
85
+ # From the io.TextIOWrapper docs:
86
+ # On output, if newline is None, any '\n' characters written
87
+ # are translated to the system default line separator.
88
+ # By default sys.stdout.write() expects '\n' newlines and then
89
+ # translates them so this is still cross platform.
90
+ if line[-1] == '\n':
91
+ self.logger.log(self.log_level, line.rstrip())
92
+ else:
93
+ self.linebuf += line
94
+
95
+ def flush(self):
96
+ if self.linebuf != '':
97
+ self.logger.log(self.log_level, self.linebuf.rstrip())
98
+ self.linebuf = ''
99
+
100
+
101
+ def disable_torch_init():
102
+ """
103
+ Disable the redundant torch default initialization to accelerate model creation.
104
+ """
105
+ import torch
106
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
107
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
108
+
109
+
110
+ def violates_moderation(text):
111
+ """
112
+ Check whether the text violates OpenAI moderation API.
113
+ """
114
+ url = "https://api.openai.com/v1/moderations"
115
+ headers = {"Content-Type": "application/json",
116
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
117
+ text = text.replace("\n", "")
118
+ data = "{" + '"input": ' + f'"{text}"' + "}"
119
+ data = data.encode("utf-8")
120
+ try:
121
+ ret = requests.post(url, headers=headers, data=data, timeout=5)
122
+ flagged = ret.json()["results"][0]["flagged"]
123
+ except requests.exceptions.RequestException as e:
124
+ flagged = False
125
+ except KeyError as e:
126
+ flagged = False
127
+
128
+ return flagged
129
+
130
+
131
+ def pretty_print_semaphore(semaphore):
132
+ if semaphore is None:
133
+ return "None"
134
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"