xzl12306 committed
Commit d6bc023
1 Parent(s): a53e26c

first commit

Files changed (47)
  1. .gitignore +59 -0
  2. app.py +497 -0
  3. images/college.png +0 -0
  4. images/demo.png +0 -0
  5. images/diseases.png +0 -0
  6. images/immigrants.png +0 -0
  7. images/market.png +0 -0
  8. images/sails.png +0 -0
  9. pyproject.toml +37 -0
  10. requirements.txt +23 -0
  11. scripts/calculate_metric.py +72 -0
  12. scripts/merge_jsonl_sort.py +26 -0
  13. scripts/split_jsonl_dataset.py +40 -0
  14. tinychart/__init__.py +1 -0
  15. tinychart/arguments.py +77 -0
  16. tinychart/constants.py +13 -0
  17. tinychart/conversation.py +491 -0
  18. tinychart/data/__init__.py +0 -0
  19. tinychart/data/dataset.py +185 -0
  20. tinychart/data/preprocess/__init__.py +0 -0
  21. tinychart/data/preprocess/default.py +104 -0
  22. tinychart/data/preprocess/phi.py +100 -0
  23. tinychart/data/preprocess/v1.py +120 -0
  24. tinychart/data/process.py +83 -0
  25. tinychart/eval/__init__.py +0 -0
  26. tinychart/eval/eval_metric.py +159 -0
  27. tinychart/eval/eval_model.py +139 -0
  28. tinychart/eval/run_eval.py +72 -0
  29. tinychart/eval/run_tiny_chart.py +127 -0
  30. tinychart/mm_utils.py +111 -0
  31. tinychart/model/__init__.py +1 -0
  32. tinychart/model/builder.py +127 -0
  33. tinychart/model/language_model/__init__.py +0 -0
  34. tinychart/model/language_model/llava_phi.py +164 -0
  35. tinychart/model/language_model/phi/cache_utils.py +322 -0
  36. tinychart/model/language_model/phi/configuration_phi.py +186 -0
  37. tinychart/model/language_model/phi/convert_phi_weights_to_hf.py +175 -0
  38. tinychart/model/language_model/phi/modeling_attn_mask_utils.py +497 -0
  39. tinychart/model/language_model/phi/modeling_phi.py +1345 -0
  40. tinychart/model/language_model/phi/utils.py +1428 -0
  41. tinychart/model/llava_arch.py +383 -0
  42. tinychart/model/model_factory.py +64 -0
  43. tinychart/model/multimodal_encoder/builder.py +7 -0
  44. tinychart/model/multimodal_encoder/merge.py +239 -0
  45. tinychart/model/multimodal_encoder/siglip_encoder.py +751 -0
  46. tinychart/model/multimodal_projector/builder.py +215 -0
  47. tinychart/utils.py +134 -0
.gitignore ADDED
@@ -0,0 +1,59 @@
+ # These are some examples of commonly ignored file patterns.
+ # You should customize this list as applicable to your project.
+ # Learn more about .gitignore:
+ # https://www.atlassian.com/git/tutorials/saving-changes/gitignore
+
+ # Node artifact files
+ node_modules/
+ dist/
+
+ # Compiled Java class files
+ *.class
+
+ # Compiled Python bytecode
+ *.py[cod]
+
+ # Log files
+ *.log
+
+ # Package files
+ *.jar
+
+ # Maven
+ target/
+
+ # JetBrains IDE
+ .idea/
+
+ # Unit test reports
+ TEST*.xml
+
+ # Generated by MacOS
+ .DS_Store
+
+ Thumbs.db
+
+ # Applications
+ *.app
+ *.exe
+ *.war
+
+ # Large media files
+ *.mp4
+ *.tiff
+ *.avi
+ *.flv
+ *.mov
+ *.wmv
+
+ .ipynb_checkpoints
+ __pycache__
+ *.egg-info
+ .vscode/*
+ .idea/*
+ playground/
+
+ checkpoints
+ .logs
+ core-*
+ */.nfs*
app.py ADDED
@@ -0,0 +1,497 @@
+ import argparse
+ import hashlib
+ import json
+ import os
+ import time
+ from threading import Thread
+ import logging
+ import gradio as gr
+ import torch
+
+ from tinychart.model.builder import load_pretrained_model
+ from tinychart.mm_utils import (
+     KeywordsStoppingCriteria,
+     load_image_from_base64,
+     process_images,
+     tokenizer_image_token,
+     get_model_name_from_path,
+ )
+ from PIL import Image
+ from io import BytesIO
+ import base64
+ import torch
+ from transformers import StoppingCriteria
+ from tinychart.constants import (
+     DEFAULT_IM_END_TOKEN,
+     DEFAULT_IM_START_TOKEN,
+     DEFAULT_IMAGE_TOKEN,
+     IMAGE_TOKEN_INDEX,
+ )
+ from tinychart.conversation import SeparatorStyle, conv_templates, default_conversation
+ from tinychart.eval.eval_metric import parse_model_output, evaluate_cmds
+
+ from transformers import TextIteratorStreamer
+ from pathlib import Path
+
+ DEFAULT_MODEL_PATH = "mPLUG/TinyChart-3B-768"
+ DEFAULT_MODEL_NAME = "TinyChart-3B-768"
+
+
+ block_css = """
+
+ #buttons button {
+     min-width: min(120px,100%);
+ }
+ """
+ title_markdown = """
+ # TinyChart: Efficient Chart Understanding with Visual Token Merging and Program-of-Thoughts Learning
+ 🔗 [[Code](https://github.com/X-PLUG/mPLUG-DocOwl/tree/main/TinyChart)] | 📚 [[Paper](https://arxiv.org/abs/2404.16635)]
+
+ **Note:**
+ 1. Currently, this demo only supports English chart understanding and may not work well with other languages.
+ 2. To use Program-of-Thoughts answer, please append "Answer with detailed steps." to your question.
+ """
+ tos_markdown = """
+ ### Terms of use
+ By using this service, users are required to agree to the following terms:
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.
+ For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
+ """
+
+ def regenerate(state, image_process_mode):
+     state.messages[-1][-1] = None
+     prev_human_msg = state.messages[-2]
+     if type(prev_human_msg[1]) in (tuple, list):
+         prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
+     state.skip_next = False
+     return (state, state.to_gradio_chatbot(), "", None)
+
+
+ def clear_history():
+     state = default_conversation.copy()
+     return (state, state.to_gradio_chatbot(), "", None)
+
+
+ def add_text(state, text, image, image_process_mode):
+     if len(text) <= 0 and image is None:
+         state.skip_next = True
+         return (state, state.to_gradio_chatbot(), "", None)
+
+     text = text[:1536]  # Hard cut-off
+     if image is not None:
+         text = text[:1200]  # Hard cut-off for images
+         if "<image>" not in text:
+             # text = '<Image><image></Image>' + text
+             # text = text + "\n<image>"
+             text = "<image>\n"+text
+         text = (text, image, image_process_mode)
+         if len(state.get_images(return_pil=True)) > 0:
+             state = default_conversation.copy()
+     state.append_message(state.roles[0], text)
+     state.append_message(state.roles[1], None)
+     state.skip_next = False
+     return (state, state.to_gradio_chatbot(), "", None)
+
+
+ def load_demo():
+     state = default_conversation.copy()
+     return state
+
+ def is_float(value):
+     try:
+         float(value)
+         return True
+     except ValueError:
+         return False
+
+
+ @torch.inference_mode()
+ def get_response(params):
+     prompt = params["prompt"]
+     ori_prompt = prompt
+     images = params.get("images", None)
+     num_image_tokens = 0
+     if images is not None and len(images) > 0:
+         if len(images) > 0:
+             if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
+                 raise ValueError(
+                     "Number of images does not match number of <image> tokens in prompt"
+                 )
+
+             images = [load_image_from_base64(image) for image in images]
+             images = process_images(images, image_processor, model.config)
+
+             if type(images) is list:
+                 images = [
+                     image.to(model.device, dtype=torch.float16) for image in images
+                 ]
+             else:
+                 images = images.to(model.device, dtype=torch.float16)
+
+             replace_token = DEFAULT_IMAGE_TOKEN
+             if getattr(model.config, "mm_use_im_start_end", False):
+                 replace_token = (
+                     DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+                 )
+             prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
+
+             if hasattr(model.get_vision_tower().config, "tome_r"):
+                 num_image_tokens = (
+                     prompt.count(replace_token) * model.get_vision_tower().num_patches - 26 * model.get_vision_tower().config.tome_r
+                 )
+             else:
+                 num_image_tokens = (
+                     prompt.count(replace_token) * model.get_vision_tower().num_patches
+                 )
+         else:
+             images = None
+         image_args = {"images": images}
+     else:
+         images = None
+         image_args = {}
+
+     temperature = float(params.get("temperature", 1.0))
+     top_p = float(params.get("top_p", 1.0))
+     max_context_length = getattr(model.config, "max_position_embeddings", 2048)
+     max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
+     stop_str = params.get("stop", None)
+     do_sample = True if temperature > 0.001 else False
+     logger.info(prompt)
+     input_ids = (
+         tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
+         .unsqueeze(0)
+         .to(model.device)
+     )
+     keywords = [stop_str]
+
+     stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
+     streamer = TextIteratorStreamer(
+         tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15
+     )
+
+     max_new_tokens = min(
+         max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens
+     )
+
+     if max_new_tokens < 1:
+         yield json.dumps(
+             {
+                 "text": ori_prompt
+                 + "Exceeds max token length. Please start a new conversation, thanks.",
+                 "error_code": 0,
+             }
+         ).encode() + b"\0"
+         return
+
+     # local inference
+     # BUG: If stopping_criteria is set, an error occur:
+     # RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0
+     generate_kwargs = dict(
+         inputs=input_ids,
+         do_sample=do_sample,
+         temperature=temperature,
+         top_p=top_p,
+         max_new_tokens=max_new_tokens,
+         streamer=streamer,
+         # stopping_criteria=[stopping_criteria],
+         use_cache=True,
+         **image_args,
+     )
+     thread = Thread(target=model.generate, kwargs=generate_kwargs)
+     thread.start()
+     logger.debug(ori_prompt)
+     logger.debug(generate_kwargs)
+     generated_text = ori_prompt
+     for new_text in streamer:
+         generated_text += new_text
+         if generated_text.endswith(stop_str):
+             generated_text = generated_text[: -len(stop_str)]
+         yield json.dumps({"text": generated_text, "error_code": 0}).encode()
+
+     if '<step>' in generated_text and '</step>' in generated_text and '<comment>' in generated_text and '</comment>' in generated_text:
+         program = generated_text
+         program = '<comment>#' + program.split('ASSISTANT: <comment>#')[-1]
+         print(program)
+         try:
+             execuate_result = evaluate_cmds(parse_model_output(program))
+             if is_float(execuate_result):
+                 execuate_result = round(float(execuate_result), 4)
+             generated_text += f'\n\nExecute result: {execuate_result}'
+             yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
+         except:
+             # execuate_result = 'Failed.'
+             generated_text += f'\n\nIt seems the execution of the above code encounters bugs. I\'m trying to answer this question directly...'
+             ori_generated_text = generated_text + '\nDirect Answer: '
+
+             direct_prompt = ori_prompt.replace(' Answer with detailed steps.', '')
+             direct_input_ids = (
+                 tokenizer_image_token(direct_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
+                 .unsqueeze(0)
+                 .to(model.device)
+             )
+
+             generate_kwargs = dict(
+                 inputs=direct_input_ids,
+                 do_sample=do_sample,
+                 temperature=temperature,
+                 top_p=top_p,
+                 max_new_tokens=max_new_tokens,
+                 streamer=streamer,
+                 use_cache=True,
+                 **image_args,
+             )
+             thread = Thread(target=model.generate, kwargs=generate_kwargs)
+             thread.start()
+             generated_text = ori_generated_text
+             for new_text in streamer:
+                 generated_text += new_text
+                 if generated_text.endswith(stop_str):
+                     generated_text = generated_text[: -len(stop_str)]
+                 yield json.dumps({"text": generated_text, "error_code": 0}).encode()
+
+
+
+ def http_bot(state, temperature, top_p, max_new_tokens):
+     if state.skip_next:
+         # This generate call is skipped due to invalid inputs
+         yield (state, state.to_gradio_chatbot())
+         return
+
+     if len(state.messages) == state.offset + 2:
+         # First round of conversation
+
+         template_name = 'phi'
+
+         new_state = conv_templates[template_name].copy()
+         new_state.append_message(new_state.roles[0], state.messages[-2][1])
+         new_state.append_message(new_state.roles[1], None)
+         state = new_state
+
+     # Construct prompt
+     prompt = state.get_prompt()
+
+     all_images = state.get_images(return_pil=True)
+     all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
+
+     # Make requests
+     # pload = {"model": model_name, "prompt": prompt, "temperature": float(temperature), "top_p": float(top_p),
+     # "max_new_tokens": min(int(max_new_tokens), 1536), "stop": (
+     # state.sep
+     # if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT]
+     # else state.sep2
+     # ), "images": state.get_images()}
+
+     pload = {
+         "model": model_name,
+         "prompt": prompt,
+         "temperature": float(temperature),
+         "top_p": float(top_p),
+         "max_new_tokens": min(int(max_new_tokens), 1536),
+         "stop": (
+             state.sep
+             if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT]
+             else state.sep2
+         ), "images": state.get_images()}
+
+     state.messages[-1][-1] = "▌"
+     yield (state, state.to_gradio_chatbot())
+
+     # for stream
+     output = get_response(pload)
+     for chunk in output:
+         if chunk:
+             data = json.loads(chunk.decode().replace('\x00',''))
+
+             if data["error_code"] == 0:
+                 output = data["text"][len(prompt) :].strip()
+                 state.messages[-1][-1] = output + "▌"
+                 yield (state, state.to_gradio_chatbot())
+             else:
+                 output = data["text"] + f" (error_code: {data['error_code']})"
+                 state.messages[-1][-1] = output
+                 yield (state, state.to_gradio_chatbot())
+                 return
+             time.sleep(0.03)
+
+     state.messages[-1][-1] = state.messages[-1][-1][:-1]
+     yield (state, state.to_gradio_chatbot())
+
+
+ def build_demo():
+     textbox = gr.Textbox(
+         show_label=False, placeholder="Enter text and press ENTER", container=False
+     )
+     with gr.Blocks(title="TinyLLaVA", theme=gr.themes.Default(), css=block_css) as demo:
+         state = gr.State()
+         gr.Markdown(title_markdown)
+
+         with gr.Row():
+             with gr.Column(scale=5):
+                 with gr.Row(elem_id="Model ID"):
+                     gr.Dropdown(
+                         choices=[DEFAULT_MODEL_NAME],
+                         value=DEFAULT_MODEL_NAME,
+                         interactive=True,
+                         label="Model ID",
+                         container=False,
+                     )
+                 imagebox = gr.Image(type="pil")
+                 image_process_mode = gr.Radio(
+                     ["Crop", "Resize", "Pad", "Default"],
+                     value="Default",
+                     label="Preprocess for non-square image",
+                     visible=False,
+                 )
+
+                 # cur_dir = os.path.dirname(os.path.abspath(__file__))
+                 cur_dir = Path(__file__).parent
+                 gr.Examples(
+                     examples=[
+                         [
+                             f"{cur_dir}/examples/market.png",
+                             "What is the highest number of companies in the domestic market? Answer with detailed steps.",
+                         ],
+                         [
+                             f"{cur_dir}/examples/college.png",
+                             "What is the difference between Asians and Whites degree distribution? Answer with detailed steps."
+                         ],
+                         [
+                             f"{cur_dir}/examples/immigrants.png",
+                             "How many immigrants are there in 1931?",
+                         ],
+                         [
+                             f"{cur_dir}/examples/sails.png",
+                             "By how much percentage wholesale is less than retail? Answer with detailed steps."
+                         ],
+                         [
+                             f"{cur_dir}/examples/diseases.png",
+                             "Is the median value of all the bars greater than 30? Answer with detailed steps.",
+                         ],
+                         [
+                             f"{cur_dir}/examples/economy.png",
+                             "Which team has higher economy in 28 min?"
+                         ],
+                         [
+                             f"{cur_dir}/examples/workers.png",
+                             "Generate underlying data table for the chart."
+                         ],
+                         [
+                             f"{cur_dir}/examples/sports.png",
+                             "Create a brief summarization or extract key insights based on the chart image."
+                         ],
+                         [
+                             f"{cur_dir}/examples/albums.png",
+                             "Redraw the chart with Python code."
+                         ]
+                     ],
+                     inputs=[imagebox, textbox],
+                 )
+
+                 with gr.Accordion("Parameters", open=False) as _:
+                     temperature = gr.Slider(
+                         minimum=0.0,
+                         maximum=1.0,
+                         value=0.1,
+                         step=0.1,
+                         interactive=True,
+                         label="Temperature",
+                     )
+                     top_p = gr.Slider(
+                         minimum=0.0,
+                         maximum=1.0,
+                         value=0.7,
+                         step=0.1,
+                         interactive=True,
+                         label="Top P",
+                     )
+                     max_output_tokens = gr.Slider(
+                         minimum=0,
+                         maximum=1024,
+                         value=1024,
+                         step=64,
+                         interactive=True,
+                         label="Max output tokens",
+                     )
+
+             with gr.Column(scale=8):
+                 chatbot = gr.Chatbot(elem_id="chatbot", label="Chatbot", height=550)
+                 with gr.Row():
+                     with gr.Column(scale=8):
+                         textbox.render()
+                     with gr.Column(scale=1, min_width=50):
+                         submit_btn = gr.Button(value="Send", variant="primary")
+                 with gr.Row(elem_id="buttons") as _:
+                     regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
+                     clear_btn = gr.Button(value="🗑️ Clear", interactive=True)
+
+         gr.Markdown(tos_markdown)
+
+         regenerate_btn.click(
+             regenerate,
+             [state, image_process_mode],
+             [state, chatbot, textbox, imagebox],
+             queue=False,
+         ).then(
+             http_bot, [state, temperature, top_p, max_output_tokens], [state, chatbot]
+         )
+
+         clear_btn.click(
+             clear_history, None, [state, chatbot, textbox, imagebox], queue=False
+         )
+
+         textbox.submit(
+             add_text,
+             [state, textbox, imagebox, image_process_mode],
+             [state, chatbot, textbox, imagebox],
+             queue=False,
+         ).then(
+             http_bot, [state, temperature, top_p, max_output_tokens], [state, chatbot]
+         )
+
+         submit_btn.click(
+             add_text,
+             [state, textbox, imagebox, image_process_mode],
+             [state, chatbot, textbox, imagebox],
+             queue=False,
+         ).then(
+             http_bot, [state, temperature, top_p, max_output_tokens], [state, chatbot]
+         )
+
+         demo.load(load_demo, None, [state], queue=False)
+     return demo
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--host", type=str, default=None)
+     parser.add_argument("--port", type=int, default=None)
+     parser.add_argument("--share", default=None)
+     parser.add_argument("--model-path", type=str, default=DEFAULT_MODEL_PATH)
+     parser.add_argument("--model-name", type=str, default=DEFAULT_MODEL_NAME)
+     parser.add_argument("--load-8bit", action="store_true")
+     parser.add_argument("--load-4bit", action="store_true")
+     args = parser.parse_args()
+     return args
+
+
+ if __name__ == "__main__":
+     logging.basicConfig(
+         level=logging.INFO,
+         format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+     )
+     logger = logging.getLogger(__name__)
+     logger.info(gr.__version__)
+     args = parse_args()
+     model_name = args.model_name
+     tokenizer, model, image_processor, context_len = load_pretrained_model(
+         model_path=args.model_path,
+         model_base=None,
+         model_name=args.model_name,
+         device="cpu",
+         load_4bit=args.load_4bit,
+         load_8bit=args.load_8bit
+     )
+
+     demo = build_demo()
+     demo.queue()
+     demo.launch(server_name=args.host, server_port=args.port, share=args.share)
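
Note on the Program-of-Thoughts path in `get_response` above: the demo switches between PoT and direct answering purely on the question text, and strips the exact suffix `" Answer with detailed steps."` when it falls back to a direct answer. A minimal illustration of the two question styles (hypothetical strings for orientation only, not part of this commit):

    # Illustrative only: the PoT branch is selected by the prompt suffix.
    direct_question = "How many immigrants are there in 1931?"
    pot_question = "How many immigrants are there in 1931? Answer with detailed steps."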
images/college.png ADDED
images/demo.png ADDED
images/diseases.png ADDED
images/immigrants.png ADDED
images/market.png ADDED
images/sails.png ADDED
pyproject.toml ADDED
@@ -0,0 +1,37 @@
+ [build-system]
+ requires = ["setuptools>=61.0"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "tinyllava"
+ version = "1.0.0"
+ description = "A Framework of Small-scale Large Multimodal Models."
+ readme = "README.md"
+ requires-python = ">=3.9"
+ classifiers = [
+     "Programming Language :: Python :: 3",
+     "License :: OSI Approved :: Apache Software License",
+ ]
+ dependencies = [
+     "torch==2.0.1", "torchvision==0.15.2", "tiktoken",
+     "transformers==4.37.2", "tokenizers==0.15.1", "sentencepiece==0.1.99", "shortuuid",
+     "accelerate==0.21.0", "peft==0.4.0", "bitsandbytes==0.41.0",
+     "pydantic<2,>=1", "markdown2[all]", "numpy", "scikit-learn==1.2.2",
+     "gradio==3.35.2", "gradio_client==0.2.9",
+     "requests", "httpx==0.24.0", "uvicorn", "fastapi",
+     "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
+ ]
+
+ [project.optional-dependencies]
+ train = ["deepspeed==0.9.5", "ninja", "wandb"]
+
+ [project.urls]
+ "Homepage" = "https://github.com/X-PLUG/mPLUG-DocOwl/blob/main/TinyChart"
+ "Bug Tracker" = "https://github.com/X-PLUG/mPLUG-DocOwl/issues"
+
+ [tool.setuptools.packages.find]
+ exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
+
+ [tool.wheel]
+ exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
+
requirements.txt ADDED
@@ -0,0 +1,23 @@
+ torch==2.0.1
+ torchvision==0.15.2
+ tiktoken==0.5.2
+ transformers==4.37.2
+ tokenizers==0.15.1
+ sentencepiece==0.1.99
+ shortuuid==1.0.11
+ accelerate==0.21.0
+ peft==0.4.0
+ bitsandbytes==0.41.0
+ pydantic<2,>=1
+ markdown2[all]
+ numpy
+ scikit-learn==1.2.2
+ gradio==3.35.2
+ gradio_client==0.2.9
+ requests
+ httpx==0.24.0
+ uvicorn
+ fastapi
+ einops==0.6.1
+ einops-exts==0.0.4
+ timm==0.6.13
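
Since most dependencies are pinned, a quick sanity check that the installed versions match the pins can save a failed launch later; a hedged helper sketch (not part of the commit, package subset chosen for illustration):

    # Hypothetical check: compare a few installed versions against the pins above.
    import importlib.metadata as md

    for pkg, pinned in [("torch", "2.0.1"), ("transformers", "4.37.2"), ("gradio", "3.35.2")]:
        print(f"{pkg}: installed {md.version(pkg)}, pinned {pinned}")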
scripts/calculate_metric.py ADDED
@@ -0,0 +1,72 @@
+ import os
+ import json
+ import argparse
+ import pandas as pd
+ from collections import defaultdict
+ from tinychart.eval.eval_chartqa_metric import chartqa_evaluator, chartqapot_evaluator
+ from tinychart.eval.eval_chartqa_metric import chartqa_oracle_merger_evaluator, chartqa_rule_merger_evaluator
+
+ def read_jsonl(jsonl_path):
+     with open(jsonl_path, 'r') as f:
+         data = [json.loads(line) for line in f]
+     return data
+
+ def write_jsonl(data, jsonl_path):
+     with open(jsonl_path, 'w', encoding='utf-8') as f:
+         for item in data:
+             f.write(json.dumps(item) + '\n')
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--input', default='./output/')
+
+     args = parser.parse_args()
+
+     result_files = os.listdir(args.input)
+     result_files = [f for f in result_files if f.endswith('.jsonl')]
+     result_files.sort()
+     direct_result, cot_result = None, None
+
+     dataset2metric = defaultdict(float)
+     for result_file in result_files:
+         # print(result_file)
+         dataset_name = '.'.join(result_file.split('.')[:-1])
+         file = os.path.join(args.input, result_file)
+         result_data = read_jsonl(file)
+         if 'chartqa-' in dataset_name:
+             direct_result, direct_acc = chartqa_evaluator(result_data, key='model_answer')
+             write_jsonl(direct_result, file)
+             dataset2metric[dataset_name] = round(direct_acc * 100, 2)
+             print(f'Direct Accuracy: {direct_acc}')
+         elif 'chartqagptpot-' in dataset_name or 'chartqatemplatepot-' in dataset_name:
+             pot_result, pot_acc, error_rate = chartqapot_evaluator(result_data)
+             write_jsonl(pot_result, file)
+             dataset2metric[dataset_name] = round(pot_acc * 100, 2)
+             print(f'PoT Accuracy: {pot_acc}')
+             print(f'PoT Error Rate: {error_rate}')
+
+     if direct_result is not None and pot_result is not None:
+         print("Calculate merging direct and pot results with simple divider")
+         oracle_results, oracle_acc = chartqa_oracle_merger_evaluator(direct_result, pot_result)
+         dataset2metric['merged-oracle'] = round(oracle_acc * 100, 2)
+         print(f'Oracle Merged Accuracy: {oracle_acc}')
+         write_jsonl(oracle_results, os.path.join(args.input, 'merged-oracle.jsonl'))
+         rule_results, rule_acc = chartqa_rule_merger_evaluator(direct_result, pot_result)
+         dataset2metric['merged-rule'] = round(rule_acc * 100, 2)
+         print(f'Rule Merged Accuracy: {rule_acc}')
+         write_jsonl(rule_results, os.path.join(args.input, 'merged-rule.jsonl'))
+
+     # save metrics into tsv with key as the first row
+     df = pd.DataFrame(dataset2metric, index=[0])
+     # if there is a metrics.tsv exists, add one in the name to avoid overwrite
+     tsv_name = os.path.join(args.input, 'metrics.tsv')
+     if os.path.exists(tsv_name):
+         # avoid overwrite. if there is metrics.1.tsv, name it metrics.2.tsv...
+         i = 1
+         tsv_name = os.path.join(args.input, f'metrics.{i}.tsv')
+         while os.path.exists(tsv_name):
+             i += 1
+             tsv_name = os.path.join(args.input, f'metrics.{i}.tsv')
+     df.to_csv(tsv_name, sep='\t', index=False)
+     print(f'Metrics saved at: {tsv_name}')
+     print(df)
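
The script scores every `*.jsonl` prediction file found in `--input` (ChartQA direct and PoT splits), merges the two answer styles, and writes a `metrics.tsv` alongside them. A hedged invocation sketch (the result-directory name is an assumption taken from the argparse default above):

    # Hypothetical driver: equivalent to `python scripts/calculate_metric.py --input ./output/`.
    import subprocess

    subprocess.run(["python", "scripts/calculate_metric.py", "--input", "./output/"], check=True)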
scripts/merge_jsonl_sort.py ADDED
@@ -0,0 +1,26 @@
+ import os
+ import json
+ import argparse
+
+ def read_jsonl(jsonl_path):
+     with open(jsonl_path, 'r') as f:
+         data = [json.loads(line) for line in f]
+     return data
+
+ def write_jsonl(data, jsonl_path):
+     with open(jsonl_path, 'w', encoding='utf-8') as f:
+         for item in data:
+             f.write(json.dumps(item) + '\n')
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--input', default='temp/')
+     parser.add_argument('--output', default='chartqa_val.json')
+
+     args = parser.parse_args()
+     files = os.listdir(args.input)
+     files.sort()
+     data = []
+     for file in files:
+         data.extend(read_jsonl(os.path.join(args.input, file)))
+     write_jsonl(data, args.output)
scripts/split_jsonl_dataset.py ADDED
@@ -0,0 +1,40 @@
+ import json
+ import os
+ import argparse
+ from collections import defaultdict
+
+ def read_jsonl(jsonl_path):
+     with open(jsonl_path, 'r') as f:
+         data = [json.loads(line) for line in f]
+     return data
+
+ def write_jsonl(data, jsonl_path):
+     with open(jsonl_path, 'w', encoding='utf-8') as f:
+         for item in data:
+             f.write(json.dumps(item) + '\n')
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--input', default='all.json')
+     parser.add_argument('--output', default='./output/')
+
+     args = parser.parse_args()
+
+     all_data = read_jsonl(args.input)
+
+     dataset2jsonl = defaultdict(list)
+
+     for item in all_data:
+         int_id = item['id'].split('_')[-1]
+         dataset_name_split = '_'.join(item['id'].split('_')[:-1])
+
+         if '-two_col-' in dataset_name_split:
+             dataset_name_split = dataset_name_split.replace('-two_col-', '-')
+         if '-multi_col-' in dataset_name_split:
+             dataset_name_split = dataset_name_split.replace('-multi_col-', '-')
+
+         dataset2jsonl[dataset_name_split].append(item)
+
+     for dataset_name_split, data in dataset2jsonl.items():
+         data.sort(key=lambda x: int(x['id'].split('_')[-1]))
+         write_jsonl(data, os.path.join(args.output, f'{dataset_name_split}.jsonl'))
tinychart/__init__.py ADDED
@@ -0,0 +1 @@
+ from tinychart.model import *
tinychart/arguments.py ADDED
@@ -0,0 +1,77 @@
+ from dataclasses import dataclass, field
+ from typing import Dict, Optional, Sequence, List
+
+ import transformers
+
+ @dataclass
+ class ModelArguments:
+     model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+     version: Optional[str] = field(default="v0")
+     freeze_backbone: bool = field(default=False)
+     tune_mm_mlp_adapter: bool = field(default=False)
+     vision_tower: Optional[str] = field(default=None)
+     mm_vision_select_layer: Optional[int] = field(default=-1)  # default to the last layer
+     pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
+     mm_projector_type: Optional[str] = field(default='linear')
+     mm_use_im_start_end: bool = field(default=False)
+     mm_use_im_patch_token: bool = field(default=True)
+     mm_patch_merge_type: Optional[str] = field(default='flat')
+     mm_vision_select_feature: Optional[str] = field(default="patch")
+     resampler_hidden_size: Optional[int] = field(default=768)
+     num_queries: Optional[int] = field(default=128)
+     num_resampler_layers: Optional[int] = field(default=3)
+     tune_vision_tower: bool = field(default=False)
+     tune_entire_model: bool = field(default=False)
+     tune_vit_from_layer: Optional[int] = field(default=100)
+     tune_embed_tokens: Optional[int] = field(default=False)
+
+
+ @dataclass
+ class DataArguments:
+     data_path: str = field(default=None,
+                            metadata={"help": "Path to the training data."})
+     eval_data_path: str = field(default=None,
+                                 metadata={"help": "Path to the evaluation data."})
+     lazy_preprocess: bool = False
+     is_multimodal: bool = False
+     image_folder: Optional[str] = field(default=None)
+     image_aspect_ratio: str = 'square'
+
+
+ @dataclass
+ class TrainingArguments(transformers.TrainingArguments):
+     cache_dir: Optional[str] = field(default=None)
+     optim: str = field(default="adamw_torch")
+     remove_unused_columns: bool = field(default=False)
+     freeze_mm_mlp_adapter: bool = field(default=False)
+     mpt_attn_impl: Optional[str] = field(default="triton")
+     model_max_length: int = field(
+         default=512,
+         metadata={
+             "help":
+                 "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+         },
+     )
+     double_quant: bool = field(
+         default=True,
+         metadata={"help": "Compress the quantization statistics through double quantization."}
+     )
+     quant_type: str = field(
+         default="nf4",
+         metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
+     )
+     bits: int = field(
+         default=16,
+         metadata={"help": "How many bits to use."}
+     )
+     lora_enable: bool = False
+     lora_r: int = 64
+     lora_alpha: int = 16
+     lora_dropout: float = 0.05
+     lora_weight_path: str = ""
+     lora_bias: str = "none"
+     mm_projector_lr: Optional[float] = None
+     group_by_modality_length: bool = field(default=False)
+     vision_tower_lr: Optional[float] = None
+     tune_vit_posemb_only: bool = field(default=False)
+     tune_vit_only: bool = field(default=False)
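
These dataclasses are the kind usually consumed through `transformers.HfArgumentParser`; a hedged sketch of how a training entry point might parse them (the actual training launcher is not part of this commit):

    # Hypothetical parsing sketch; the real training script is not included here.
    import transformers
    from tinychart.arguments import ModelArguments, DataArguments, TrainingArguments

    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()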
tinychart/constants.py ADDED
@@ -0,0 +1,13 @@
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
+ WORKER_HEART_BEAT_INTERVAL = 15
+
+ LOGDIR = "."
+
+ # Model Constants
+ IGNORE_INDEX = -100
+ IMAGE_TOKEN_INDEX = -200
+ DEFAULT_IMAGE_TOKEN = "<image>"
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+ DEFAULT_IM_START_TOKEN = "<im_start>"
+ DEFAULT_IM_END_TOKEN = "<im_end>"
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
tinychart/conversation.py ADDED
@@ -0,0 +1,491 @@
+ import dataclasses
+ from enum import auto, Enum
+ from typing import List, Tuple
+
+
+ class SeparatorStyle(Enum):
+     """Different separator style."""
+     SINGLE = auto()
+     TWO = auto()
+     MPT = auto()
+     PLAIN = auto()
+     LLAMA_2 = auto()
+     TINY_LLAMA = auto()
+     QWEN_2 = auto()
+
+
+ @dataclasses.dataclass
+ class Conversation:
+     """A class that keeps all conversation history."""
+     system: str
+     roles: List[str]
+     messages: List[List[str]]
+     offset: int
+     sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+     sep: str = "###"
+     sep2: str = None
+     version: str = "Unknown"
+
+     skip_next: bool = False
+
+     def get_prompt(self):
+         messages = self.messages
+         if len(messages) > 0 and type(messages[0][1]) is tuple:
+             messages = self.messages.copy()
+             init_role, init_msg = messages[0].copy()
+             init_msg = init_msg[0].replace("<image>", "").strip()
+             if 'mmtag' in self.version:
+                 messages[0] = (init_role, init_msg)
+                 messages.insert(0, (self.roles[0], "<Image><image></Image>"))
+                 messages.insert(1, (self.roles[1], "Received."))
+             else:
+                 messages[0] = (init_role, "<image>\n" + init_msg)
+
+         if self.sep_style == SeparatorStyle.SINGLE:
+             ret = self.system + self.sep
+             for role, message in messages:
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += role + ": " + message + self.sep
+                 else:
+                     ret += role + ":"
+         elif self.sep_style == SeparatorStyle.TWO:
+             seps = [self.sep, self.sep2]
+             ret = self.system + seps[0]
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += role + ": " + message + seps[i % 2]
+                 else:
+                     ret += role + ":"
+         elif self.sep_style == SeparatorStyle.MPT:
+             ret = self.system + self.sep
+             for role, message in messages:
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += role + message + self.sep
+                 else:
+                     ret += role
+         elif self.sep_style == SeparatorStyle.LLAMA_2:
+             wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
+             wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+             ret = ""
+
+             for i, (role, message) in enumerate(messages):
+                 if i == 0:
+                     assert message, "first message should not be none"
+                     assert role == self.roles[0], "first message should come from user"
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     if i == 0: message = wrap_sys(self.system) + message
+                     if i % 2 == 0:
+                         message = wrap_inst(message)
+                         ret += self.sep + message
+                     else:
+                         ret += " " + message + " " + self.sep2
+                 else:
+                     ret += ""
+             ret = ret.lstrip(self.sep)
+         elif self.sep_style == SeparatorStyle.TINY_LLAMA:
+             sep = "</s>"
+             wrap_sys = lambda msg: f"<|system|>\n{msg}\n"
+             wrap_user = lambda msg: f"<|user|>\n{msg}\n"
+             wrap_assistant = lambda msg: f"<|assistant|>\n{msg}"
+             ret = ""
+
+             for i, (role, message) in enumerate(messages):
+                 if i == 0:
+                     assert message, "first message should not be none"
+                     assert role == self.roles[0], "first message should come from user"
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     if i % 2 == 0:
+                         message = wrap_user(message)
+                         if i == 0:
+                             message = wrap_sys(self.system) + message
+                         ret += self.sep + message
+                     else:
+                         message = wrap_assistant(message) + self.sep2
+                         ret += message
+                 else:
+                     ret += "<|assistant|>\n"
+             ret = ret.lstrip(self.sep)
+         elif self.sep_style == SeparatorStyle.QWEN_2:
+             ret = self.system + self.sep
+             for role, message in messages:
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += role + message + self.sep
+                 else:
+                     ret += role
+         elif self.sep_style == SeparatorStyle.PLAIN:
+             seps = [self.sep, self.sep2]
+             ret = self.system
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += message + seps[i % 2]
+                 else:
+                     ret += ""
+         else:
+             raise ValueError(f"Invalid style: {self.sep_style}")
+
+         return ret
+
+     def append_message(self, role, message):
+         self.messages.append([role, message])
+
+     def get_images(self, return_pil=False):
+         images = []
+         for i, (role, msg) in enumerate(self.messages[self.offset:]):
+             if i % 2 == 0:
+                 if type(msg) is tuple:
+                     import base64
+                     from io import BytesIO
+                     from PIL import Image
+                     msg, image, image_process_mode = msg
+                     if image_process_mode == "Pad":
+                         def expand2square(pil_img, background_color=(122, 116, 104)):
+                             width, height = pil_img.size
+                             if width == height:
+                                 return pil_img
+                             elif width > height:
+                                 result = Image.new(pil_img.mode, (width, width), background_color)
+                                 result.paste(pil_img, (0, (width - height) // 2))
+                                 return result
+                             else:
+                                 result = Image.new(pil_img.mode, (height, height), background_color)
+                                 result.paste(pil_img, ((height - width) // 2, 0))
+                                 return result
+                         image = expand2square(image)
+                     elif image_process_mode in ["Default", "Crop"]:
+                         pass
+                     elif image_process_mode == "Resize":
+                         image = image.resize((336, 336))
+                     else:
+                         raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+                     max_hw, min_hw = max(image.size), min(image.size)
+                     aspect_ratio = max_hw / min_hw
+                     max_len, min_len = 800, 400
+                     shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                     longest_edge = int(shortest_edge * aspect_ratio)
+                     W, H = image.size
+                     if longest_edge != max(image.size):
+                         if H > W:
+                             H, W = longest_edge, shortest_edge
+                         else:
+                             H, W = shortest_edge, longest_edge
+                         image = image.resize((W, H))
+                     if return_pil:
+                         images.append(image)
+                     else:
+                         buffered = BytesIO()
+                         image.save(buffered, format="PNG")
+                         img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                         images.append(img_b64_str)
+         return images
+
+     def to_gradio_chatbot(self):
+         ret = []
+         for i, (role, msg) in enumerate(self.messages[self.offset:]):
+             if i % 2 == 0:
+                 if type(msg) is tuple:
+                     import base64
+                     from io import BytesIO
+                     msg, image, image_process_mode = msg
+                     max_hw, min_hw = max(image.size), min(image.size)
+                     aspect_ratio = max_hw / min_hw
+                     max_len, min_len = 800, 400
+                     shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                     longest_edge = int(shortest_edge * aspect_ratio)
+                     W, H = image.size
+                     if H > W:
+                         H, W = longest_edge, shortest_edge
+                     else:
+                         H, W = shortest_edge, longest_edge
+                     image = image.resize((W, H))
+                     buffered = BytesIO()
+                     image.save(buffered, format="JPEG")
+                     img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                     img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
+                     msg = img_str + msg.replace('<image>', '').strip()
+                     ret.append([msg, None])
+                 else:
+                     ret.append([msg, None])
+             else:
+                 ret[-1][-1] = msg
+         return ret
+
+     def copy(self):
+         return Conversation(
+             system=self.system,
+             roles=self.roles,
+             messages=[[x, y] for x, y in self.messages],
+             offset=self.offset,
+             sep_style=self.sep_style,
+             sep=self.sep,
+             sep2=self.sep2,
+             version=self.version)
+
+     def dict(self):
+         if len(self.get_images()) > 0:
+             return {
+                 "system": self.system,
+                 "roles": self.roles,
+                 "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+                 "offset": self.offset,
+                 "sep": self.sep,
+                 "sep2": self.sep2,
+             }
+         return {
+             "system": self.system,
+             "roles": self.roles,
+             "messages": self.messages,
+             "offset": self.offset,
+             "sep": self.sep,
+             "sep2": self.sep2,
+         }
+
+
+ conv_vicuna_v0 = Conversation(
+     system="A chat between a curious human and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+     roles=("Human", "Assistant"),
+     messages=(
+         ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
+         ("Assistant",
+          "Renewable energy sources are those that can be replenished naturally in a relatively "
+          "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+          "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+          "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+          "renewable and non-renewable energy sources:\n"
+          "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+          "energy sources are finite and will eventually run out.\n"
+          "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+          "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+          "and other negative effects.\n"
+          "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+          "have lower operational costs than non-renewable sources.\n"
+          "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+          "locations than non-renewable sources.\n"
+          "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+          "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+          "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+          "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
+     ),
+     offset=2,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="###",
+ )
+
+ conv_vicuna_v1 = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+     roles=("USER", "ASSISTANT"),
+     version="v1",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+ )
+
+ conv_llama_2 = Conversation(
+     system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
+     roles=("USER", "ASSISTANT"),
+     version="llama_v2",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.LLAMA_2,
+     sep="<s>",
+     sep2="</s>",
+ )
+
+ conv_llava_llama_2 = Conversation(
+     system="You are a helpful language and vision assistant. "
+            "You are able to understand the visual content that the user provides, "
+            "and assist the user with a variety of tasks using natural language.",
+     roles=("USER", "ASSISTANT"),
+     version="llama_v2",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.LLAMA_2,
+     sep="<s>",
+     sep2="</s>",
+ )
+
+ conv_tiny_llava_tiny_llama = Conversation(
+     system="You are a helpful language and vision assistant. "
+            "You are able to understand the visual content that the user provides, "
+            "and assist the user with a variety of tasks using natural language.",
+     roles=("USER", "ASSISTANT"),
+     version="tiny_llama",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TINY_LLAMA,
+     sep="<s>",
+     sep2="</s>"
+ )
+
+
+ conv_mpt = Conversation(
+     system="""<|im_start|>system
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
+     roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+     version="mpt",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.MPT,
+     sep="<|im_end|>",
+ )
+
+ conv_llava_plain = Conversation(
+     system="",
+     roles=("", ""),
+     messages=(
+     ),
+     version='plain',
+     offset=0,
+     sep_style=SeparatorStyle.PLAIN,
+     sep="\n",
+ )
+
+ conv_llava_v0 = Conversation(
+     system="A chat between a curious human and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+     roles=("Human", "Assistant"),
+     messages=(
+     ),
+     offset=0,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="###",
+ )
+
+ conv_llava_v0_mmtag = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+            "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+            "The visual content will be provided with the following format: <Image>visual content</Image>.",
+     roles=("Human", "Assistant"),
+     messages=(
+     ),
+     offset=0,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="###",
+     version="v0_mmtag",
+ )
+
+ conv_llava_v1 = Conversation(
+     system="A chat between a curious human and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+     roles=("USER", "ASSISTANT"),
+     version="v1",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+ )
+
+ conv_llava_v1_mmtag = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+            "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+            "The visual content will be provided with the following format: <Image>visual content</Image>.",
+     roles=("USER", "ASSISTANT"),
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+     version="v1_mmtag",
+ )
+
+ conv_phi_v0 = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+     roles=("USER", "ASSISTANT"),
+     version="phi",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="<|endoftext|>",
+ )
+
+ conv_stablelm = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+     roles=("USER", "ASSISTANT"),
+     version="stablelm",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="<|endoftext|>",
+ )
+
+ conv_mistral_instruct = Conversation(
+     system="",
+     roles=("USER", "ASSISTANT"),
+     version="llama_v2",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.LLAMA_2,
+     sep="",
+     sep2="</s>",
+ )
+
+ conv_chatml_direct = Conversation(
+     system="""<|im_start|>system
+ Answer the questions.""",
+     roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+     version="mpt",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.MPT,
+     sep="<|im_end|>",
+ )
+
+ conv_qwen2 = Conversation(
+     system="<|im_start|>system\nYou are a helpful assistant",
+     roles=("<im_start>user\n", "<im_start>assistant\n"),
+     version="mpt",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.MPT,
+     sep="<im_end>"
+ )
+
+ default_conversation = conv_vicuna_v1
+ conv_templates = {
+     "default": conv_vicuna_v0,
+     "v0": conv_vicuna_v0,
+     "v1": conv_vicuna_v1,
+     "vicuna_v1": conv_vicuna_v1,
+     "llama_2": conv_llama_2,
+
+     "plain": conv_llava_plain,
+     "v0_plain": conv_llava_plain,
+     "llava_v0": conv_llava_v0,
+     "v0_mmtag": conv_llava_v0_mmtag,
+     "llava_v1": conv_llava_v1,
+     "v1_mmtag": conv_llava_v1_mmtag,
+     "llava_llama_2": conv_llava_llama_2,
+
+     "mpt": conv_mpt,
+
+     "tiny_llama": conv_tiny_llava_tiny_llama,
+     "phi": conv_phi_v0,
+ }
+
+
+ if __name__ == "__main__":
+     print(default_conversation.get_prompt())
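
For reference, the `phi` template registered above (and selected by `http_bot` in app.py) assembles a single-turn prompt like this; a minimal sketch using only the classes defined in this file:

    # Minimal sketch: build a one-turn prompt with the phi conversation template.
    from tinychart.conversation import conv_templates

    conv = conv_templates["phi"].copy()
    conv.append_message(conv.roles[0], "<image>\nHow many immigrants are there in 1931?")
    conv.append_message(conv.roles[1], None)
    # SeparatorStyle.TWO yields: "<system text> USER: <image>\n<question> ASSISTANT:"
    print(conv.get_prompt())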
tinychart/data/__init__.py ADDED
File without changes
tinychart/data/dataset.py ADDED
@@ -0,0 +1,185 @@
+ import copy
+ from dataclasses import dataclass
+ import json
+ from typing import Dict, Sequence
+
+
+ import transformers
+ import torch
+ from torch.utils.data import Dataset
+ from PIL import Image, ImageFile
+
+ from tinychart.arguments import *
+ from tinychart.utils import *
+ from tinychart.data.process import *
+ from tinychart.constants import *
+
+
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+ class LazySupervisedDataset(Dataset):
+     """Dataset for supervised fine-tuning."""
+
+     def __init__(self, data_path: str,
+                  tokenizer: transformers.PreTrainedTokenizer,
+                  data_args: DataArguments):
+         super(LazySupervisedDataset, self).__init__()
+         list_data_dict = json.load(open(data_path, "r"))
+
+         rank0_print("Formatting inputs...Skip in lazy mode")
+         self.tokenizer = tokenizer
+         self.list_data_dict = list_data_dict
+         self.data_args = data_args
+
+     def __len__(self):
+         return len(self.list_data_dict)
+
+     @property
+     def lengths(self):
+         length_list = []
+         for sample in self.list_data_dict:
+             img_tokens = 128 if 'image' in sample else 0
+             length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
+         return length_list
+
+     @property
+     def modality_lengths(self):
+         length_list = []
+         for sample in self.list_data_dict:
+             cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
+             cur_len = cur_len if 'image' in sample else -cur_len
+             length_list.append(cur_len)
+         return length_list
+
+     def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+         sources = self.list_data_dict[i]
+         if isinstance(i, int):
+             sources = [sources]
+         assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
+         if 'image' in sources[0]:
+             image_file = self.list_data_dict[i]['image']
+             image_folder = self.data_args.image_folder
+             processor = self.data_args.image_processor
+             image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
+             if self.data_args.image_aspect_ratio == 'pad':
+                 def expand2square(pil_img, background_color):
+                     width, height = pil_img.size
+                     if width == height:
+                         return pil_img
+                     elif width > height:
+                         result = Image.new(pil_img.mode, (width, width), background_color)
+                         result.paste(pil_img, (0, (width - height) // 2))
+                         return result
+                     else:
+                         result = Image.new(pil_img.mode, (height, height), background_color)
+                         result.paste(pil_img, ((height - width) // 2, 0))
+                         return result
+
+                 image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
+                 image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+             else:
+                 image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+             sources = preprocess_multimodal(
+                 copy.deepcopy([e["conversations"] for e in sources]),
+                 self.data_args)
+         else:
+             sources = copy.deepcopy([e["conversations"] for e in sources])
+         data_dict = preprocess(
+             sources,
+             self.tokenizer,
+             has_image=('image' in self.list_data_dict[i]))
+         if isinstance(i, int):
+             data_dict = dict(input_ids=data_dict["input_ids"][0],
+                              labels=data_dict["labels"][0])
+
+         # image exist in the data
+         if 'image' in self.list_data_dict[i]:
+             data_dict['image'] = image
+         elif self.data_args.is_multimodal:
+             # image does not exist in the data, but the model is multimodal
+             crop_size = self.data_args.image_processor.crop_size
+             data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
+         return data_dict
+
+
+ @dataclass
+ class DataCollatorForSupervisedDataset(object):
+     """Collate examples for supervised fine-tuning."""
+
+     tokenizer: transformers.PreTrainedTokenizer
+
+     def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+         input_ids, labels = tuple([instance[key] for instance in instances]
+                                   for key in ("input_ids", "labels"))
+         if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
+             for input_id in input_ids:
+                 input_id[input_id == self.tokenizer.eos_token_id] = -300
+         input_ids = torch.nn.utils.rnn.pad_sequence(
+             input_ids,
+             batch_first=True,
+             padding_value=self.tokenizer.pad_token_id)
+         labels = torch.nn.utils.rnn.pad_sequence(labels,
+                                                  batch_first=True,
+                                                  padding_value=IGNORE_INDEX)
+         input_ids = input_ids[:, :self.tokenizer.model_max_length]
+         attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
+         labels = labels[:, :self.tokenizer.model_max_length]
+         # FIXME: This is a hack for handling phi and stablelm, as they have the same eos, pad and unk. We want the model
+         # FIXME: to predict the eos in the input ids, but we also use the id of eos to pad sequence, so we use a temp
+         # FIXME: eos id first, and convert them back.
+         if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
+             for input_id in input_ids:
+                 input_id[input_id == -300] = self.tokenizer.eos_token_id
+
+         batch = dict(
+             input_ids=input_ids,
+             labels=labels,
+             attention_mask=attention_mask,
+         )
+
+         if 'image' in instances[0]:
+             images = [instance['image'] for instance in instances]
+             if all(x is not None and x.shape == images[0].shape for x in images):
+                 batch['images'] = torch.stack(images)
+             else:
+                 batch['images'] = images
+
+         if 'question' in instances[0]:
+             questions = [instance['question'] for instance in instances]
+             batch['questions'] = questions
+
+         if 'answer' in instances[0]:
+             answers = [instance['answer'] for instance in instances]
+             batch['answers'] = answers
+
+         return batch
+
+
+ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
+                                 data_args) -> Dict:
+     """Make dataset and collator for supervised fine-tuning."""
+     train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
+                                           data_path=data_args.data_path,
+                                           data_args=data_args)
+     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
+     return dict(train_dataset=train_dataset,
+                 eval_dataset=None,
+                 data_collator=data_collator)
+
+ def make_supervised_data_module_with_eval(tokenizer: transformers.PreTrainedTokenizer,
+                                           data_args) -> Dict:
+     """Make dataset and collator for supervised fine-tuning."""
+     train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
+                                           data_path=data_args.data_path,
+                                           data_args=data_args)
+     if data_args.eval_data_path is None or data_args.eval_data_path == "":
+         print('Evaluation dataset not specified, skipping...')
+         eval_dataset = None
+     else:
+         eval_dataset = LazySupervisedDataset(tokenizer=tokenizer,
+                                              data_path=data_args.eval_data_path,
+                                              data_args=data_args)
+     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
+     return dict(train_dataset=train_dataset,
+                 eval_dataset=eval_dataset,
+                 data_collator=data_collator)
tinychart/data/preprocess/__init__.py ADDED
File without changes
tinychart/data/preprocess/default.py ADDED
@@ -0,0 +1,104 @@
1
+ from typing import Dict, Optional, Sequence, List
2
+ import copy
3
+
4
+ import transformers
5
+ import torch
6
+
7
+ from tinychart.data.process import register_preprocess
8
+ from tinychart.mm_utils import tokenizer_image_token
9
+ from tinychart import conversation as conversation_lib
10
+ from tinychart.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, \
11
+ DEFAULT_IM_END_TOKEN
12
+
13
+
14
+ @register_preprocess('default')
15
+ def preprocess_default(
16
+ sources: Sequence[str],
17
+ tokenizer: transformers.PreTrainedTokenizer,
18
+ has_image: bool = False
19
+ ) -> Dict:
20
+ conversations = []
21
+ for source in sources:
22
+ header = f"{conversation_lib.default_conversation.system}\n\n"
23
+ conversation = _add_speaker_and_signal(header, source)
24
+ conversations.append(conversation)
25
+
26
+ # tokenize conversations
27
+ def get_tokenize_len(prompts):
28
+ return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
29
+
30
+ if has_image:
31
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
32
+ else:
33
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
34
+ input_ids = conversations_tokenized["input_ids"]
35
+
36
+ targets = copy.deepcopy(input_ids)
37
+ for target, source in zip(targets, sources):
38
+ if has_image:
39
+ tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
40
+ else:
41
+ tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
42
+ speakers = [sentence["from"] for sentence in source]
43
+ _mask_targets(target, tokenized_lens, speakers)
44
+
45
+ return dict(input_ids=input_ids, labels=targets)
46
+
47
+
48
+ def _tokenize_fn(strings: Sequence[str],
49
+ tokenizer: transformers.PreTrainedTokenizer) -> Dict:
50
+ """Tokenize a list of strings."""
51
+ tokenized_list = [
52
+ tokenizer(
53
+ text,
54
+ return_tensors="pt",
55
+ padding="longest",
56
+ max_length=tokenizer.model_max_length,
57
+ truncation=True,
58
+ ) for text in strings
59
+ ]
60
+ input_ids = labels = [
61
+ tokenized.input_ids[0] for tokenized in tokenized_list
62
+ ]
63
+ input_ids_lens = labels_lens = [
64
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
65
+ for tokenized in tokenized_list
66
+ ]
67
+ return dict(
68
+ input_ids=input_ids,
69
+ labels=labels,
70
+ input_ids_lens=input_ids_lens,
71
+ labels_lens=labels_lens,
72
+ )
73
+
74
+
75
+ def _add_speaker_and_signal(header, source, get_conversation=True):
76
+ """Add speaker and start/end signal on each round."""
77
+ BEGIN_SIGNAL = "### "
78
+ END_SIGNAL = "\n"
79
+ conversation = header
80
+ for sentence in source:
81
+ from_str = sentence["from"]
82
+ if from_str.lower() == "human":
83
+ from_str = conversation_lib.default_conversation.roles[0]
84
+ elif from_str.lower() == "gpt":
85
+ from_str = conversation_lib.default_conversation.roles[1]
86
+ else:
87
+ from_str = 'unknown'
88
+ sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
89
+ sentence["value"] + END_SIGNAL)
90
+ if get_conversation:
91
+ conversation += sentence["value"]
92
+ conversation += BEGIN_SIGNAL
93
+ return conversation
94
+
95
+
96
+ def _mask_targets(target, tokenized_lens, speakers):
97
+ # cur_idx = 0
98
+ cur_idx = tokenized_lens[0]
99
+ tokenized_lens = tokenized_lens[1:]
100
+ target[:cur_idx] = IGNORE_INDEX
101
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
102
+ if speaker == "human":
103
+ target[cur_idx + 2:cur_idx + tokenized_len] = IGNORE_INDEX
104
+ cur_idx += tokenized_len
tinychart/data/preprocess/phi.py ADDED
@@ -0,0 +1,100 @@
1
+ import time
+ from typing import Dict, Optional, Sequence, List
+ import copy
+
+ import transformers
+ import torch
6
+
7
+ from tinychart.data.process import register_preprocess
8
+ from tinychart.mm_utils import tokenizer_image_token
9
+ from tinychart import conversation as conversation_lib
10
+ from tinychart.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, \
11
+ DEFAULT_IM_END_TOKEN
12
+
13
+
14
+ @register_preprocess('phi')
15
+ def preprocess_phi(
16
+ sources,
17
+ tokenizer: transformers.PreTrainedTokenizer,
18
+ has_image: bool = False
19
+ ) -> Dict:
20
+ conv = conversation_lib.default_conversation.copy()
21
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
22
+
23
+ # print('00000000000', sources)
24
+ # Apply prompt templates
25
+ conversations = []
26
+
27
+ for i, source in enumerate(sources):
28
+ if roles[source[0]["from"]] != conv.roles[0]:
29
+ # Skip the first one if it is not from human
30
+ source = source[1:]
31
+
32
+ conv.messages = []
33
+ for j, sentence in enumerate(source):
34
+ role = roles[sentence["from"]]
35
+ assert role == conv.roles[j % 2], f"{i}"
36
+ conv.append_message(role, sentence["value"])
37
+ conversations.append(conv.get_prompt())
38
+ # Tokenize conversations
39
+ if has_image:
40
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
41
+ else:
42
+ input_ids = tokenizer(
43
+ conversations,
44
+ return_tensors="pt",
45
+ padding="longest",
46
+ max_length=tokenizer.model_max_length,
47
+ truncation=True,
48
+ ).input_ids
49
+
50
+ targets = input_ids.clone()
51
+
52
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
53
+ # print(tokenizer)
54
+ # Mask targets
55
+ sep = conv.sep + conv.roles[1] + ": "
56
+ for conversation, target in zip(conversations, targets):
57
+ total_len = int(target.ne(tokenizer.pad_token_id).sum()) + conversation.count(conv.sep2)
58
+
59
+ rounds = conversation.split(conv.sep2)
60
+ cur_len = 0
61
+ # target[:cur_len] = IGNORE_INDEX
62
+ for i, rou in enumerate(rounds):
63
+ if rou == "":
64
+ break
65
+
66
+ parts = rou.split(sep)
67
+ if len(parts) != 2:
68
+ break
69
+ parts[0] += sep
70
+
71
+ if has_image:
72
+ round_len = len(tokenizer_image_token(rou, tokenizer)) + 1
73
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
74
+ else:
75
+ round_len = len(tokenizer(rou).input_ids) + 1
76
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 1
77
+
78
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
79
+
80
+ cur_len += round_len
81
+ target[cur_len:] = IGNORE_INDEX
82
+
83
+ if cur_len < tokenizer.model_max_length:
84
+ if cur_len != total_len:
85
+ print(
86
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
87
+ f" (ignored)"
88
+ )
89
+ print("number of rounds: ", len(rounds) - 1)
90
+ print("rounds: ", rounds[:-1])
91
+ print("conversation: ", conversations)
92
+ print(target)
93
+ print(input_ids)
94
+ time.sleep(5)
95
+ target[:] = IGNORE_INDEX
96
+
97
+ return dict(
98
+ input_ids=input_ids,
99
+ labels=targets,
100
+ )
tinychart/data/preprocess/v1.py ADDED
@@ -0,0 +1,120 @@
1
+ import time
2
+ from typing import Dict, Optional, Sequence, List
3
+ import copy
4
+
5
+ import transformers
6
+ import tokenizers
7
+ import torch
8
+
9
+ from tinychart.data.process import register_preprocess
10
+ from tinychart.mm_utils import tokenizer_image_token
11
+ from tinychart import conversation as conversation_lib
12
+ from tinychart.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, \
13
+ DEFAULT_IM_END_TOKEN
14
+
15
+ from packaging import version
16
+
17
+ # IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')
18
+
19
+
20
+ @register_preprocess('v1')
21
+ def preprocess_v1(
22
+ sources,
23
+ tokenizer: transformers.PreTrainedTokenizer,
24
+ has_image: bool = False
25
+ ) -> Dict:
26
+ # conv = conversation_lib.default_conversation.copy()
27
+ conv = conversation_lib.conv_phi_v0.copy()
28
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
29
+
30
+ # Apply prompt templates
31
+ conversations = []
32
+ for i, source in enumerate(sources):
33
+ if roles[source[0]["from"]] != conv.roles[0]:
34
+ # Skip the first one if it is not from human
35
+ source = source[1:]
36
+
37
+ conv.messages = []
38
+ for j, sentence in enumerate(source):
39
+ role = roles[sentence["from"]]
40
+ assert role == conv.roles[j % 2], f"{i}"
41
+ conv.append_message(role, sentence["value"])
42
+ conversations.append(conv.get_prompt())
43
+
44
+ # Tokenize conversations
45
+
46
+ if has_image:
47
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
48
+ else:
49
+ input_ids = tokenizer(
50
+ conversations,
51
+ return_tensors="pt",
52
+ padding="longest",
53
+ max_length=tokenizer.model_max_length,
54
+ truncation=True,
55
+ ).input_ids
56
+
57
+ targets = input_ids.clone()
58
+
59
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
60
+
61
+ # Mask targets
62
+ sep = conv.sep + conv.roles[1] + ": "
63
+ for conversation, target in zip(conversations, targets):
64
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
65
+ # total_len = len(target)
66
+
67
+ rounds = conversation.split(conv.sep2)
68
+ cur_len = 0
69
+ # cur_len = 1
70
+ # cur_len = 1 + 1
71
+ target[:cur_len] = IGNORE_INDEX
72
+ for i, rou in enumerate(rounds):
73
+ if rou == "":
74
+ break
75
+
76
+ parts = rou.split(sep)
77
+ if len(parts) != 2:
78
+ break
79
+ parts[0] += sep
80
+
81
+ if has_image:
82
+ round_len = len(tokenizer_image_token(rou, tokenizer))
83
+ # round_len = len(tokenizer_image_token(rou, tokenizer)) - 2 + 1
84
+ # instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
85
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
86
+ else:
87
+ round_len = len(tokenizer(rou).input_ids)
88
+ # round_len = len(tokenizer(rou).input_ids) - 2 + 1
89
+ # instruction_len = len(tokenizer(parts[0]).input_ids) - 2
90
+ instruction_len = len(tokenizer(parts[0]).input_ids)
91
+
92
+ # if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
93
+ # round_len -= 1
94
+ # instruction_len -= 1
95
+ instruction_len -= 1
96
+
97
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
98
+
99
+ cur_len += round_len
100
+ # target[cur_len:] = IGNORE_INDEX
101
+ # import pdb;pdb.set_trace()
102
+
103
+ if cur_len < tokenizer.model_max_length:
104
+ if cur_len != total_len:
105
+
106
+ print(
107
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
108
+ f" (ignored)"
109
+ )
110
+ print("number of rounds: ", len(rounds) - 1)
111
+ print("rounds: ", rounds[:-1])
112
+ print("conversation: ", conversations)
113
+ print(target)
114
+ print(input_ids)
115
+ time.sleep(5)
116
+ target[:] = IGNORE_INDEX
117
+ return dict(
118
+ input_ids=input_ids,
119
+ labels=targets,
120
+ )
tinychart/data/process.py ADDED
@@ -0,0 +1,83 @@
1
+ import os
2
+ import importlib
3
+ from typing import Dict, Optional, Sequence, List
4
+
5
+ import transformers
6
+
7
+ from tinychart.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
8
+ from tinychart import conversation as conversation_lib
9
+ from tinychart.arguments import *
10
+
11
+ PREPROCESS_REGISTRY = {}
12
+
13
+ def register_preprocess(name):
14
+ def register_preprocess_cls(cls):
15
+ if name in PREPROCESS_REGISTRY:
16
+ return PREPROCESS_REGISTRY[name]
17
+
18
+ PREPROCESS_REGISTRY[name] = cls
19
+ return cls
20
+
21
+ return register_preprocess_cls
22
+
23
+
24
+ def import_modules(modules_dir, namespace):
25
+ for file in os.listdir(modules_dir):
26
+ path = os.path.join(modules_dir, file)
27
+
28
+ if (
29
+ not file.startswith("_")
30
+ and not file.startswith(".")
31
+ and (file.endswith(".py") or os.path.isdir(path))
32
+ ):
33
+ module_name = file[: file.find(".py")] if file.endswith(".py") else file
34
+ importlib.import_module(namespace + "." + module_name)
35
+
36
+ models_dir = os.path.join(os.path.dirname(__file__), 'preprocess')
37
+ import_modules(models_dir, "tinychart.data.preprocess")
38
+
39
+
40
+ def PreprocessSelect(version):
41
+ result = PREPROCESS_REGISTRY.get(version, None)
42
+ if result is None:
43
+ for name in PREPROCESS_REGISTRY.keys():
44
+ if version in name:
45
+ result = PREPROCESS_REGISTRY[name]
46
+ break
47
+ if result is None:
48
+ result = PREPROCESS_REGISTRY['default']
49
+ return result
50
+
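+ # Example (sketch): with conversation version 'phi', PreprocessSelect('phi') returns the function
+ # registered via @register_preprocess('phi'); an unknown version falls back to a substring match
+ # over registered names and finally to the 'default' preprocessor.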
51
+
52
+
53
+ def preprocess_multimodal(
54
+ sources: Sequence[str],
55
+ data_args: DataArguments
56
+ ) -> Dict:
57
+ is_multimodal = data_args.is_multimodal
58
+ if not is_multimodal:
59
+ return sources
60
+
61
+ for source in sources:
62
+ for sentence in source:
63
+ if DEFAULT_IMAGE_TOKEN in sentence['value']:
64
+ sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
65
+ sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
66
+ sentence['value'] = sentence['value'].strip()
67
+ if "mmtag" in conversation_lib.default_conversation.version:
68
+ sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN,
69
+ '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')
70
+ replace_token = DEFAULT_IMAGE_TOKEN
71
+ if data_args.mm_use_im_start_end:
72
+ replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
73
+ sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
74
+
75
+ return sources
76
+
77
+
78
+ def preprocess(
79
+ sources: Sequence[str],
80
+ tokenizer: transformers.PreTrainedTokenizer,
81
+ has_image: bool = False
82
+ ) -> Dict:
83
+ return PreprocessSelect(conversation_lib.default_conversation.version)(sources, tokenizer, has_image)
tinychart/eval/__init__.py ADDED
File without changes
tinychart/eval/eval_metric.py ADDED
@@ -0,0 +1,159 @@
1
+ import os
+ import json
+ import math
+ import copy
+ import argparse
+ import numpy as np
8
+
9
+ def write_jsonl(data, filename):
10
+ with open(filename, 'w') as f:
11
+ for item in data:
12
+ f.write(json.dumps(item) + '\n')
13
+
14
+ def RelaxedAccuracy(pred, gt):
15
+ try:
16
+ gt = float(gt)
17
+ pred = float(pred)
18
+ if gt == 0.0:
19
+ if pred == gt:
20
+ return 1.0
21
+ else:
22
+ return 0.0
23
+ else:
24
+             if abs(pred - gt) / abs(gt) <= 0.05:  # relative error within 5% counts as correct
25
+ return 1.0
26
+ else:
27
+ return 0.0
28
+ except:
29
+ if str(gt) == str(pred):
30
+ return 1.0
31
+ else:
32
+ return 0.0
33
+
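+ # Example (sketch): RelaxedAccuracy('2.04', '2.0') == 1.0 because |2.04 - 2.0| / 2.0 = 0.02 <= 0.05,
+ # while RelaxedAccuracy('Yes', 'No') == 0.0 since the non-numeric strings differ.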
34
+ def evaluate_cmds(cmds):
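+     # Executes the generated program line by line with exec() and reads back the variable
+     # `Answer` that the program is expected to define; list and bool results are then
+     # normalized into a single answer string (e.g. True -> 'Yes').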
35
+ for cmd in cmds:
36
+ exec(cmd)
37
+ answer = eval('Answer')
38
+ if (isinstance(answer, list) or isinstance(answer, np.ndarray)) and len(answer) == 1:
39
+ answer = answer[0]
40
+ if isinstance(answer, list) or isinstance(answer, np.ndarray):
41
+ new_answer = answer[0]
42
+ for i in range(1, len(answer)-1):
43
+ new_answer = new_answer + ', ' + answer[i]
44
+ new_answer += ' and ' + answer[-1]
45
+ answer = new_answer
46
+ if isinstance(answer, bool) or isinstance(answer, np.bool_):
47
+ if answer == True:
48
+ answer = 'Yes'
49
+ elif answer == False:
50
+ answer = 'No'
51
+ return answer
52
+
53
+ def parse_model_output(cmdstr):
54
+ lines = cmdstr.split('\n')
55
+ new_lines = []
56
+ for line in lines:
57
+ if '<step>' in line or '</step>' in line:
58
+ line = line.replace('<step>', '').replace('</step>', '')
59
+ new_lines.append(line)
60
+ return new_lines
61
+
62
+ def chartqa_evaluator(data, key='final_model_answer'):
63
+ acc = 0
64
+ for item in data:
65
+ item['relaxed_acc'] = RelaxedAccuracy(item[key], item['gt_answer'].split('<pot_note>')[0])
66
+ if item['relaxed_acc'] == 1.0:
67
+ acc += 1
68
+ accuracy = acc/len(data)
69
+ return data, accuracy
70
+
71
+ def chartqapot_evaluator(output_data):
72
+ correct_items = []
73
+ wrong_items = []
74
+ error_items = []
75
+ output_data = copy.deepcopy(output_data)
76
+ acc = 0
77
+ for item in output_data:
78
+ # cmds = parse_gpt_cmd(gpt_item['eval_cmd'])
79
+ eval_cmds = parse_model_output(item['model_answer'])
80
+ try:
81
+ answer = evaluate_cmds(eval_cmds)
82
+ item['final_model_answer'] = str(answer)
83
+ except:
84
+ error_items.append(item)
85
+ item['final_model_answer'] = 'Execute <error>'
86
+ item['relaxed_acc'] = 0.0
87
+ continue
88
+ item['gt_answer'] = item['gt_answer'].split('<cot_note>')[0]
89
+ item['relaxed_acc'] = RelaxedAccuracy(str(answer), item['gt_answer'])
90
+
91
+ if item['relaxed_acc'] == 1.0:
92
+ correct_items.append(item)
93
+ else:
94
+ wrong_items.append(item)
95
+ total = len(output_data)
96
+ accuracy = len(correct_items)/total
97
+ error_rate = len(error_items)/total
98
+
99
+ return output_data, accuracy, error_rate
100
+
101
+ def rule_based_divider(question):
102
+ calculate_words = [
103
+ 'sum', 'difference', 'times', 'summation', 'exceed',
104
+ 'below', 'addition', 'fewer', 'subtract', ' mode ',
105
+ 'ratio', 'division', 'average', 'mean', 'bigger',
106
+ 'greater', ' less ', 'tallest', 'number', 'divide',
107
+ ' add ', 'absolute', 'dividing', 'differ', ' minus ',
108
+ 'how many colors', 'lowest', 'what is the value', 'higher',
109
+ 'longer', ' biggest ', 'lowest'
110
+ ]
111
+
112
+ for w in calculate_words:
113
+ if w in question.lower():
114
+ return 'pot'
115
+ return 'direct'
116
+
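+ # Example (sketch): "What is the difference between the two bars?" contains 'difference' and is
+ # routed to 'pot', while "What does the blue line represent?" matches no keyword and stays 'direct'.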
117
+ def chartqa_rule_merger_evaluator(direct_data, pot_data):
118
+ direct_data, _ = chartqa_evaluator(direct_data, key='model_answer')
119
+ assert len(direct_data) == len(pot_data), 'direct and pot num inconsistent'
120
+ acc_count = 0
121
+ merged_data = []
122
+ for datum1, datum2 in zip(direct_data, pot_data):
123
+ if rule_based_divider(datum1['question']) == 'pot' and '<error>' not in datum2['final_model_answer'] and datum2['final_model_answer'] not in ['inf', '-inf', 'nan', 'np.nan', 'np.inf', '-np.inf']:
124
+ acc_count += datum2['relaxed_acc']
125
+ merged_data.append(datum2)
126
+ else:
127
+ acc_count += datum1['relaxed_acc']
128
+ merged_data.append(datum1)
129
+ accuracy = acc_count/len(direct_data)
130
+ return merged_data, accuracy
131
+
132
+ def chartqa_oracle_merger_evaluator(direct_data, pot_data):
133
+ direct_data, _ = chartqa_evaluator(direct_data, key='model_answer')
134
+ assert len(direct_data) == len(pot_data), 'direct and pot num inconsistent'
135
+ acc_count = 0
136
+ merged_data = []
137
+ for datum1, datum2 in zip(direct_data, pot_data):
138
+ if datum1['relaxed_acc'] != 1.0:
139
+ acc_count += datum2['relaxed_acc']
140
+ merged_data.append(datum2)
141
+ else:
142
+ acc_count += datum1['relaxed_acc']
143
+ merged_data.append(datum1)
144
+ accuracy = acc_count/len(direct_data)
145
+ return merged_data, accuracy
146
+
147
+
148
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--direct', default='../eval_iter12000_0226/ChartQA_test_12000_pred.jsonl')
+     parser.add_argument('--pot', default='../eval_iter12000_0226/ChartQA_test_pot_12000_eval.jsonl')
+     parser.add_argument('--output', default='../eval_iter12000_0226/ChartQA_test_pot_12000_merged.jsonl')
+
+     args = parser.parse_args()
+
+     # The --pot file is expected to already contain chartqapot_evaluator outputs
+     # (i.e. 'final_model_answer' and 'relaxed_acc' fields).
+     direct_data = [json.loads(line) for line in open(args.direct)]
+     pot_data = [json.loads(line) for line in open(args.pot)]
+
+     _, oracle_acc = chartqa_oracle_merger_evaluator(direct_data, pot_data)
+     print(f'Oracle merged accuracy: {oracle_acc:.4f}')
+     merged, rule_acc = chartqa_rule_merger_evaluator(direct_data, pot_data)
+     print(f'Rule merged accuracy: {rule_acc:.4f}')
+
+     write_jsonl(merged, args.output)
tinychart/eval/eval_model.py ADDED
@@ -0,0 +1,139 @@
1
+ import argparse
2
+ import time
3
+
4
+ import torch
5
+ import os
6
+ import json
7
+ from tqdm import tqdm
8
+ import shortuuid
9
+
10
+ from tinychart.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
11
+ from tinychart.conversation import conv_templates, SeparatorStyle
12
+ from tinychart.model.builder import load_pretrained_model
13
+ from tinychart.utils import disable_torch_init
14
+ from tinychart.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path, KeywordsStoppingCriteria
15
+ from torch.utils.data import Dataset, DataLoader
16
+
17
+ from PIL import Image
18
+ import math
19
+
20
+
21
+ def split_list(lst, n):
22
+ """Split a list into n (roughly) equal-sized chunks"""
23
+     chunk_size = math.ceil(len(lst) / n)  # ceiling division so every item lands in a chunk
24
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
25
+
26
+
27
+ def get_chunk(lst, n, k):
28
+ chunks = split_list(lst, n)
29
+ return chunks[k]
30
+
31
+ class EvalDataset(Dataset):
32
+ def __init__(self, data_items, image_folder, tokenizer, image_processor, model_config):
33
+ self.data_items = data_items
34
+ self.image_folder = image_folder
35
+ self.tokenizer = tokenizer
36
+ self.image_processor = image_processor
37
+ self.model_config = model_config
38
+
39
+ def __getitem__(self, index):
40
+ line = self.data_items[index]
41
+ image_file = line["image"]
42
+ qs = line["conversations"][0]["value"]
43
+ # if self.model_config.mm_use_im_start_end:
44
+ # qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
45
+ # else:
46
+ # qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
47
+
48
+ conv = conv_templates[args.conv_mode].copy()
49
+ conv.append_message(conv.roles[0], qs)
50
+ conv.append_message(conv.roles[1], None)
51
+ prompt = conv.get_prompt()
52
+
53
+ image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB')
54
+ image_tensor = process_images([image], self.image_processor, self.model_config)[0]
55
+
56
+ input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
57
+
58
+ return input_ids, image_tensor, image.size
59
+
60
+ def __len__(self):
61
+ return len(self.data_items)
62
+
63
+
64
+ def collate_fn(batch):
65
+ input_ids, image_tensors, image_sizes = zip(*batch)
66
+ input_ids = torch.stack(input_ids, dim=0)
67
+ image_tensors = torch.stack(image_tensors, dim=0)
68
+ return input_ids, image_tensors, image_sizes
69
+
70
+
71
+ # DataLoader
72
+ def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=4):
73
+ assert batch_size == 1, "batch_size must be 1"
74
+ dataset = EvalDataset(questions, image_folder, tokenizer, image_processor, model_config)
75
+ data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, collate_fn=collate_fn)
76
+ return data_loader
77
+
78
+
79
+ def eval_model(args):
80
+ disable_torch_init()
81
+ model_path = os.path.expanduser(args.model_path)
82
+ model_name = get_model_name_from_path(model_path)
83
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
84
+
85
+ all_data = json.load(open(args.data_path, "r"))
86
+ all_data = get_chunk(all_data, args.num_chunks, args.chunk_idx)
87
+ answers_file = os.path.expanduser(args.output_path)
88
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
89
+ ans_file = open(answers_file, "w")
90
+
91
+ if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
92
+ args.conv_mode = args.conv_mode + '_mmtag'
93
+ print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')
94
+
95
+ data_loader = create_data_loader(all_data, args.image_folder, tokenizer, image_processor, model.config)
96
+ for (input_ids, image_tensor, image_sizes), line in tqdm(zip(data_loader, all_data), total=len(all_data)):
97
+ idx = line["id"]
98
+ cur_prompt = line["conversations"][0]["value"]
99
+ input_ids = input_ids.to(device='cuda', non_blocking=True)
100
+ with torch.inference_mode():
101
+ output_ids = model.generate(
102
+ input_ids,
103
+ images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
104
+ pad_token_id=tokenizer.pad_token_id,
105
+ do_sample=True if args.temperature > 0 else False,
106
+ temperature=args.temperature,
107
+ top_p=args.top_p,
108
+ num_beams=args.num_beams,
109
+ max_new_tokens=args.max_new_tokens,
110
+ min_new_tokens=args.min_new_tokens,
111
+ use_cache=True)
112
+
113
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
114
+ ans_id = shortuuid.uuid()
115
+ ans_file.write(json.dumps({"id": idx,
116
+ "question": cur_prompt,
117
+ "gt_answer": line["conversations"][1]["value"],
118
+ "model_answer": outputs}) + "\n")
119
+ ans_file.flush()
120
+ ans_file.close()
121
+
122
+ if __name__ == "__main__":
123
+ parser = argparse.ArgumentParser()
124
+ parser.add_argument("--model_path", type=str, default="facebook/opt-350m")
125
+ parser.add_argument("--model_base", type=str, default=None)
126
+ parser.add_argument("--image_folder", type=str, default="")
127
+ parser.add_argument("--data_path", type=str, default="./data/test_chartqa+cot_shuffle.json")
128
+ parser.add_argument("--output_path", type=str, default="./output/")
129
+ parser.add_argument("--conv_mode", type=str, default="phi")
130
+ parser.add_argument("--num_chunks", type=int, default=1)
131
+ parser.add_argument("--chunk_idx", type=int, default=0)
132
+ parser.add_argument("--temperature", type=float, default=0.0)
133
+ parser.add_argument("--top_p", type=float, default=None)
134
+ parser.add_argument("--num_beams", type=int, default=1)
135
+ parser.add_argument("--max_new_tokens", type=int, default=1024)
136
+ parser.add_argument("--min_new_tokens", type=int, default=0)
137
+ args = parser.parse_args()
138
+
139
+ eval_model(args)
tinychart/eval/run_eval.py ADDED
@@ -0,0 +1,72 @@
1
+ import os
2
+ import json
3
+ import argparse
4
+ import pandas as pd
5
+ from collections import defaultdict
6
+ from tinychart.eval.eval_metric import chartqa_evaluator, chartqapot_evaluator
7
+ from tinychart.eval.eval_metric import chartqa_oracle_merger_evaluator, chartqa_rule_merger_evaluator
8
+
9
+ def read_jsonl(jsonl_path):
10
+ with open(jsonl_path, 'r') as f:
11
+ data = [json.loads(line) for line in f]
12
+ return data
13
+
14
+ def write_jsonl(data, jsonl_path):
15
+ with open(jsonl_path, 'w', encoding='utf-8') as f:
16
+ for item in data:
17
+ f.write(json.dumps(item) + '\n')
18
+
19
+ if __name__ == '__main__':
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument('--input', default='./output/')
22
+
23
+ args = parser.parse_args()
24
+
25
+ result_files = os.listdir(args.input)
26
+ result_files = [f for f in result_files if f.endswith('.jsonl')]
27
+ result_files.sort()
28
+ direct_result, pot_result = None, None
29
+
30
+ dataset2metric = defaultdict(float)
31
+ for result_file in result_files:
32
+ # print(result_file)
33
+ dataset_name = '.'.join(result_file.split('.')[:-1])
34
+ file = os.path.join(args.input, result_file)
35
+ result_data = read_jsonl(file)
36
+ if 'chartqa-' in dataset_name:
37
+ direct_result, direct_acc = chartqa_evaluator(result_data, key='model_answer')
38
+ write_jsonl(direct_result, file)
39
+ dataset2metric[dataset_name] = round(direct_acc * 100, 2)
40
+ print(f'Direct Accuracy: {direct_acc}')
41
+ elif 'chartqagptpot-' in dataset_name or 'chartqatemplatepot-' in dataset_name:
42
+ pot_result, pot_acc, error_rate = chartqapot_evaluator(result_data)
43
+ write_jsonl(pot_result, file)
44
+ dataset2metric[dataset_name] = round(pot_acc * 100, 2)
45
+ print(f'PoT Accuracy: {pot_acc}')
46
+ print(f'PoT Error Rate: {error_rate}')
47
+
48
+ if direct_result is not None and pot_result is not None:
49
+ print("Calculate merging direct and pot results with simple divider")
50
+ oracle_results, oracle_acc = chartqa_oracle_merger_evaluator(direct_result, pot_result)
51
+ dataset2metric['merged-oracle'] = round(oracle_acc * 100, 2)
52
+ print(f'Oracle Merged Accuracy: {oracle_acc}')
53
+ write_jsonl(oracle_results, os.path.join(args.input, 'merged-oracle.jsonl'))
54
+ rule_results, rule_acc = chartqa_rule_merger_evaluator(direct_result, pot_result)
55
+ dataset2metric['merged-rule'] = round(rule_acc * 100, 2)
56
+ print(f'Rule Merged Accuracy: {rule_acc}')
57
+ write_jsonl(rule_results, os.path.join(args.input, 'merged-rule.jsonl'))
58
+
59
+ # save metrics into tsv with key as the first row
60
+ df = pd.DataFrame(dataset2metric, index=[0])
61
+ # if there is a metrics.tsv exists, add one in the name to avoid overwrite
62
+ tsv_name = os.path.join(args.input, 'metrics.tsv')
63
+ if os.path.exists(tsv_name):
64
+ # avoid overwrite. if there is metrics.1.tsv, name it metrics.2.tsv...
65
+ i = 1
66
+ tsv_name = os.path.join(args.input, f'metrics.{i}.tsv')
67
+ while os.path.exists(tsv_name):
68
+ i += 1
69
+ tsv_name = os.path.join(args.input, f'metrics.{i}.tsv')
70
+ df.to_csv(tsv_name, sep='\t', index=False)
71
+ print(f'Metrics saved at: {tsv_name}')
72
+ print(df)
tinychart/eval/run_tiny_chart.py ADDED
@@ -0,0 +1,127 @@
1
+ import argparse
2
+ import torch
3
+
4
+ from tinychart.constants import (
5
+ IMAGE_TOKEN_INDEX,
6
+ DEFAULT_IMAGE_TOKEN,
7
+ DEFAULT_IM_START_TOKEN,
8
+ DEFAULT_IM_END_TOKEN,
9
+ IMAGE_PLACEHOLDER,
10
+ )
11
+ from tinychart.conversation import conv_templates, SeparatorStyle
12
+ from tinychart.model.builder import load_pretrained_model
13
+ from tinychart.utils import disable_torch_init
14
+ from tinychart.mm_utils import (
15
+ process_images,
16
+ tokenizer_image_token,
17
+ get_model_name_from_path,
18
+ KeywordsStoppingCriteria,
19
+ )
20
+
21
+ from PIL import Image
22
+
23
+ import requests
24
+ from PIL import Image
25
+ from io import BytesIO
26
+ import re
27
+
28
+
29
+ def image_parser(args):
30
+ out = args.image_file.split(args.sep)
31
+ return out
32
+
33
+
34
+ def load_image(image_file):
35
+ if image_file.startswith("http") or image_file.startswith("https"):
36
+ response = requests.get(image_file)
37
+ image = Image.open(BytesIO(response.content)).convert("RGB")
38
+ else:
39
+ image = Image.open(image_file).convert("RGB")
40
+ return image
41
+
42
+
43
+ def load_images(image_files):
44
+ out = []
45
+ for image_file in image_files:
46
+ image = load_image(image_file)
47
+ out.append(image)
48
+ return out
49
+
50
+
51
+ def inference_model(image_files, query, model, tokenizer, image_processor, context_len, conv_mode, temperature=0, max_new_tokens=100):
52
+ qs = query
53
+ image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
54
+ if IMAGE_PLACEHOLDER in qs:
55
+ if model.config.mm_use_im_start_end:
56
+ qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
57
+ else:
58
+ qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
59
+ else:
60
+ if model.config.mm_use_im_start_end:
61
+ qs = image_token_se + "\n" + qs
62
+ else:
63
+ qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
64
+
65
+ conv = conv_templates[conv_mode].copy()
66
+ conv.append_message(conv.roles[0], qs)
67
+ conv.append_message(conv.roles[1], None)
68
+ prompt = conv.get_prompt()
69
+
70
+ images = load_images(image_files)
71
+ images_tensor = process_images(
72
+ images,
73
+ image_processor,
74
+ model.config
75
+ ).to(model.device, dtype=torch.float16)
76
+
77
+ input_ids = (
78
+ tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
79
+ .unsqueeze(0)
80
+ .cuda()
81
+ )
82
+
83
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
84
+ keywords = [stop_str]
85
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
86
+
87
+ with torch.inference_mode():
88
+ output_ids = model.generate(
89
+ input_ids,
90
+ images=images_tensor,
91
+ do_sample=True if temperature > 0 else False,
92
+ temperature=temperature,
93
+ # top_p=top_p,
94
+ # num_beams=args.num_beams,
95
+ pad_token_id=tokenizer.pad_token_id,
96
+ max_new_tokens=max_new_tokens,
97
+ use_cache=True,
98
+ stopping_criteria=[stopping_criteria],
99
+ )
100
+
101
+ outputs = tokenizer.batch_decode(
102
+ output_ids, skip_special_tokens=True
103
+ )[0]
104
+ outputs = outputs.strip()
105
+ if outputs.endswith(stop_str):
106
+ outputs = outputs[: -len(stop_str)]
107
+ outputs = outputs.strip()
108
+ print(outputs)
109
+ return outputs
110
+
111
+
112
+
113
+ if __name__ == "__main__":
114
+ parser = argparse.ArgumentParser()
115
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
116
+ parser.add_argument("--model-base", type=str, default=None)
117
+ parser.add_argument("--image-file", type=str, required=True)
118
+ parser.add_argument("--query", type=str, required=True)
119
+ parser.add_argument("--conv-mode", type=str, default=None)
120
+ parser.add_argument("--sep", type=str, default=",")
121
+ parser.add_argument("--temperature", type=float, default=0.2)
122
+ parser.add_argument("--top_p", type=float, default=None)
123
+ parser.add_argument("--num_beams", type=int, default=1)
124
+ parser.add_argument("--max_new_tokens", type=int, default=512)
125
+     args = parser.parse_args()
+
+     model_name = get_model_name_from_path(args.model_path)
+     tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name)
+     # conv_mode must name a template in conv_templates (e.g. 'phi' for TinyChart).
+     inference_model(image_parser(args), args.query, model, tokenizer, image_processor, context_len,
+                     args.conv_mode, temperature=args.temperature, max_new_tokens=args.max_new_tokens)
tinychart/mm_utils.py ADDED
@@ -0,0 +1,111 @@
1
+ from PIL import Image
2
+ from io import BytesIO
3
+ import base64
4
+
5
+ import torch
6
+ from transformers import StoppingCriteria
7
+ from tinychart.constants import IMAGE_TOKEN_INDEX
8
+ import math
9
+ import ast
10
+
11
+
12
+ def load_image_from_base64(image):
13
+ return Image.open(BytesIO(base64.b64decode(image)))
14
+
15
+
16
+ def expand2square(pil_img, background_color):
17
+ width, height = pil_img.size
18
+ if width == height:
19
+ return pil_img
20
+ elif width > height:
21
+ result = Image.new(pil_img.mode, (width, width), background_color)
22
+ result.paste(pil_img, (0, (width - height) // 2))
23
+ return result
24
+ else:
25
+ result = Image.new(pil_img.mode, (height, height), background_color)
26
+ result.paste(pil_img, ((height - width) // 2, 0))
27
+ return result
28
+
29
+
30
+ def process_images(images, image_processor, model_cfg):
31
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
32
+ new_images = []
33
+ if image_aspect_ratio == 'pad':
34
+ for image in images:
35
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
36
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
37
+ new_images.append(image)
38
+ elif image_aspect_ratio == "anyres":
39
+ for image in images:
40
+ image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
41
+ new_images.append(image)
42
+ else:
43
+ return image_processor(images, return_tensors='pt')['pixel_values']
44
+ if all(x.shape == new_images[0].shape for x in new_images):
45
+ new_images = torch.stack(new_images, dim=0)
46
+ return new_images
47
+
48
+
49
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
50
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
51
+
52
+ def insert_separator(X, sep):
53
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
54
+
55
+ input_ids = []
56
+ offset = 0
57
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
58
+ offset = 1
59
+ input_ids.append(prompt_chunks[0][0])
60
+
61
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
62
+ input_ids.extend(x[offset:])
63
+
64
+ if return_tensors is not None:
65
+ if return_tensors == 'pt':
66
+ return torch.tensor(input_ids, dtype=torch.long)
67
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
68
+ return input_ids
69
+
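+ # Example (sketch): for a prompt like "<image>\nDescribe the chart", the text on either side of
+ # '<image>' is tokenized separately and the sentinel id IMAGE_TOKEN_INDEX is spliced in between
+ # the chunks; that position is later swapped for image features by the multimodal preprocessing.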
70
+
71
+ def get_model_name_from_path(model_path):
72
+ model_path = model_path.strip("/")
73
+ model_paths = model_path.split("/")
74
+ if model_paths[-1].startswith('checkpoint-'):
75
+ return model_paths[-2] + "_" + model_paths[-1]
76
+ else:
77
+ return model_paths[-1]
78
+
79
+
80
+ class KeywordsStoppingCriteria(StoppingCriteria):
81
+ def __init__(self, keywords, tokenizer, input_ids):
82
+ self.keywords = keywords
83
+ self.keyword_ids = []
84
+ self.max_keyword_len = 0
85
+ for keyword in keywords:
86
+ cur_keyword_ids = tokenizer(keyword).input_ids
87
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
88
+ cur_keyword_ids = cur_keyword_ids[1:]
89
+ if len(cur_keyword_ids) > self.max_keyword_len:
90
+ self.max_keyword_len = len(cur_keyword_ids)
91
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
92
+ self.tokenizer = tokenizer
93
+ self.start_len = input_ids.shape[1]
94
+
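+     # Stops generation once either (a) the most recent output ids exactly match one of the
+     # keyword id sequences, or (b) a keyword string appears in the decoded tail of the output.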
95
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
96
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
97
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
98
+ for keyword_id in self.keyword_ids:
99
+ if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
100
+ return True
101
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
102
+ for keyword in self.keywords:
103
+ if keyword in outputs:
104
+ return True
105
+ return False
106
+
107
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
108
+ outputs = []
109
+ for i in range(output_ids.shape[0]):
110
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
111
+ return all(outputs)
tinychart/model/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from tinychart.model.language_model.llava_phi import TinyChartPhiForCausalLM, TinyChartPhiConfig
tinychart/model/builder.py ADDED
@@ -0,0 +1,127 @@
1
+ # Copyright 2024 Baichuan Zhou , Junlong Jia, Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import os
17
+ import warnings
18
+ import shutil
19
+
20
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
21
+ import torch
22
+ from tinychart.model import *
23
+ from tinychart.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
+
25
+
26
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto",
27
+ device="cuda", **kwargs):
28
+ kwargs = {"device_map": device_map, **kwargs}
29
+
30
+ if device != "cuda":
31
+ kwargs['device_map'] = {"": device}
32
+
33
+ if load_8bit:
34
+ kwargs['load_in_8bit'] = True
35
+ elif load_4bit:
36
+ kwargs['load_in_4bit'] = True
37
+ kwargs['quantization_config'] = BitsAndBytesConfig(
38
+ load_in_4bit=True,
39
+ bnb_4bit_compute_dtype=torch.float16,
40
+ bnb_4bit_use_double_quant=True,
41
+ bnb_4bit_quant_type='nf4'
42
+ )
43
+ else:
44
+ kwargs['torch_dtype'] = torch.float16
45
+
46
+ # Load LLaVA model
47
+ if 'lora' in model_name.lower() and model_base is None:
48
+ warnings.warn(
49
+ 'There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
50
+ if 'lora' in model_name.lower() and model_base is not None:
51
+ lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
52
+
53
+ print('Loading LLaVA from base model...')
54
+ tokenizer = AutoTokenizer.from_pretrained(model_base, padding_side="right")
55
+ model = TinyChartPhiForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
56
+ config=lora_cfg_pretrained, **kwargs)
57
+         token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
+         if model.lm_head.weight.shape[0] != token_num:
+             model.lm_head.weight = torch.nn.Parameter(
+                 torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
+             model.model.embed_tokens.weight = torch.nn.Parameter(
+                 torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
63
+
64
+ print('Loading additional LLaVA weights...')
65
+ if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
66
+ non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
67
+ else:
68
+ # this is probably from HF Hub
69
+ from huggingface_hub import hf_hub_download
70
+ def load_from_hf(repo_id, filename, subfolder=None):
71
+ cache_file = hf_hub_download(
72
+ repo_id=repo_id,
73
+ filename=filename,
74
+ subfolder=subfolder)
75
+ return torch.load(cache_file, map_location='cpu')
76
+
77
+ non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
78
+ non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in
79
+ non_lora_trainables.items()}
80
+ if any(k.startswith('model.model.') for k in non_lora_trainables):
81
+ non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
82
+ model.load_state_dict(non_lora_trainables, strict=False)
83
+
84
+ from peft import PeftModel
85
+ print('Loading LoRA weights...')
86
+ model = PeftModel.from_pretrained(model, model_path)
87
+ print('Merging LoRA weights...')
88
+ model = model.merge_and_unload()
89
+ print('Model is loaded...')
90
+ elif model_base is not None:
91
+ # this may be mm projector only
92
+ print('Loading LLaVA from base model...')
93
+
94
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, padding_side="right")
95
+ cfg_pretrained = TinyChartPhiConfig.from_pretrained(model_path)
96
+ model = TinyChartPhiForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained,
97
+ **kwargs)
98
+
99
+ mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
100
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
101
+ model.load_state_dict(mm_projector_weights, strict=False)
102
+ else:
103
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side="right")
104
+ model = TinyChartPhiForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
105
+
106
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
107
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
108
+ if mm_use_im_patch_token:
109
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
110
+ if mm_use_im_start_end:
111
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
112
+ model.resize_token_embeddings(len(tokenizer))
113
+ vision_tower = model.get_vision_tower()
114
+ if not vision_tower.is_loaded:
115
+ vision_tower.load_model()
116
+
117
+ if device != "auto":
118
+ vision_tower.to(device=device, dtype=torch.float16)
119
+
120
+ image_processor = vision_tower.image_processor
121
+
122
+ if hasattr(model.config, "max_sequence_length"):
123
+ context_len = model.config.max_sequence_length
124
+ else:
125
+ context_len = 2048
126
+
127
+ return tokenizer, model, image_processor, context_len
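+ # Usage sketch (the checkpoint path below is illustrative, not a shipped artifact):
+ #   tokenizer, model, image_processor, context_len = load_pretrained_model(
+ #       "path/to/TinyChart-3B", model_base=None, model_name="TinyChart-3B")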
tinychart/model/language_model/__init__.py ADDED
File without changes
tinychart/model/language_model/llava_phi.py ADDED
@@ -0,0 +1,164 @@
1
+ # Copyright 2024 Baichuan Zhou , Junlong Jia, Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+
22
+ from transformers import AutoConfig, AutoModelForCausalLM
23
+
24
+ from transformers import PhiConfig, PhiModel, PhiForCausalLM
25
+
26
+ from transformers.generation.utils import GenerateOutput
27
+
28
+ from transformers.modeling_outputs import CausalLMOutputWithPast
29
+
30
+ from tinychart.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
31
+ from tinychart.model.model_factory import *
32
+
33
+ class TinyChartPhiConfig(PhiConfig):
34
+ model_type = "tiny_chart_phi"
35
+
36
+
37
+ class TinyChartPhiModel(LlavaMetaModel, PhiModel):
38
+ config_class = TinyChartPhiConfig
39
+
40
+ def __init__(self, config: PhiConfig):
41
+ super(TinyChartPhiModel, self).__init__(config)
42
+ self.gradient_checkpointing = False
43
+
44
+ @register_model('tinychart-3b')
45
+ class TinyChartPhiForCausalLM(PhiForCausalLM, LlavaMetaForCausalLM):
46
+ config_class = TinyChartPhiConfig
47
+
48
+ def __init__(self, config):
49
+ super(PhiForCausalLM, self).__init__(config)
50
+ self.model = TinyChartPhiModel(config)
51
+ self.vocab_size = config.vocab_size
52
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
53
+
54
+ # Initialize weights and apply final processing
55
+ self.post_init()
56
+
57
+ def get_model(self):
58
+ return self.model
59
+
60
+ def forward(
61
+ self,
62
+ input_ids: torch.LongTensor = None,
63
+ attention_mask: Optional[torch.Tensor] = None,
64
+ position_ids: Optional[torch.LongTensor] = None,
65
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
66
+ inputs_embeds: Optional[torch.FloatTensor] = None,
67
+ labels: Optional[torch.LongTensor] = None,
68
+ use_cache: Optional[bool] = None,
69
+ output_attentions: Optional[bool] = None,
70
+ output_hidden_states: Optional[bool] = None,
71
+ images: Optional[torch.FloatTensor] = None,
72
+ return_dict: Optional[bool] = None,
73
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
74
+
75
+ if inputs_embeds is None:
76
+ (
77
+ input_ids,
78
+ position_ids,
79
+ attention_mask,
80
+ past_key_values,
81
+ inputs_embeds,
82
+ labels
83
+ ) = self.prepare_inputs_labels_for_multimodal(
84
+ input_ids,
85
+ position_ids,
86
+ attention_mask,
87
+ past_key_values,
88
+ labels,
89
+ images,
90
+ )
91
+
92
+ return super().forward(
93
+ input_ids=input_ids,
94
+ attention_mask=attention_mask,
95
+ position_ids=position_ids,
96
+ past_key_values=past_key_values,
97
+ inputs_embeds=inputs_embeds,
98
+ labels=labels,
99
+ use_cache=use_cache,
100
+ output_attentions=output_attentions,
101
+ output_hidden_states=output_hidden_states,
102
+ return_dict=return_dict
103
+ )
104
+
105
+ @torch.no_grad()
106
+ def generate(
107
+ self,
108
+ inputs: Optional[torch.Tensor] = None,
109
+ images: Optional[torch.Tensor] = None,
110
+ **kwargs,
111
+ ) -> Union[GenerateOutput, torch.LongTensor]:
112
+ position_ids = kwargs.pop("position_ids", None)
113
+ attention_mask = kwargs.pop("attention_mask", None)
114
+ if "inputs_embeds" in kwargs:
115
+ raise NotImplementedError("`inputs_embeds` is not supported")
116
+
117
+ if images is not None:
118
+ (
119
+ inputs,
120
+ position_ids,
121
+ attention_mask,
122
+ _,
123
+ inputs_embeds,
124
+ _
125
+ ) = self.prepare_inputs_labels_for_multimodal(
126
+ inputs,
127
+ position_ids,
128
+ attention_mask,
129
+ None,
130
+ None,
131
+ images,
132
+ )
133
+ else:
134
+ inputs_embeds = self.get_model().embed_tokens(inputs)
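+             # Text-only path: embed the token ids directly. When images are given, the branch
+             # above has already folded the visual features into `inputs_embeds`, so the base
+             # PhiForCausalLM.generate only ever receives embeddings.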
135
+
136
+ return super().generate(
137
+ position_ids=position_ids,
138
+ attention_mask=attention_mask,
139
+ inputs_embeds=inputs_embeds,
140
+ **kwargs
141
+ )
142
+
143
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
144
+ inputs_embeds=None, **kwargs):
145
+ images = kwargs.pop("images", None)
146
+ image_sizes = kwargs.pop("image_sizes", None)
147
+ inputs = super().prepare_inputs_for_generation(
148
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
149
+ )
150
+ if images is not None:
151
+ inputs['images'] = images
152
+ if image_sizes is not None:
153
+ inputs['image_sizes'] = image_sizes
154
+ return inputs
155
+
156
+ @register_tokenizer('phi')
157
+ def get_tokenizer():
158
+ from transformers import AutoTokenizer
159
+ def post_init(tokenizer):
160
+ return tokenizer
161
+ return AutoTokenizer, post_init
162
+
163
+ AutoConfig.register("tiny_chart_phi", TinyChartPhiConfig)
164
+ AutoModelForCausalLM.register(TinyChartPhiConfig, TinyChartPhiForCausalLM)
tinychart/model/language_model/phi/cache_utils.py ADDED
@@ -0,0 +1,322 @@
1
+ from typing import Any, Dict, List, Optional, Tuple
2
+
3
+ import torch
4
+
5
+
6
+ class Cache:
7
+ """
8
+ Base, abstract class for all caches. The actual data structure is specific to each subclass.
9
+ """
10
+
11
+ def update(
12
+ self,
13
+ key_states: torch.Tensor,
14
+ value_states: torch.Tensor,
15
+ layer_idx: int,
16
+ cache_kwargs: Optional[Dict[str, Any]] = None,
17
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
18
+ """
19
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
20
+
21
+ Parameters:
22
+ key_states (`torch.Tensor`):
23
+ The new key states to cache.
24
+ value_states (`torch.Tensor`):
25
+ The new value states to cache.
26
+ layer_idx (`int`):
27
+ The index of the layer to cache the states for.
28
+ cache_kwargs (`Dict[str, Any]`, `optional`):
29
+ Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
30
+ cache to be created.
31
+
32
+ Return:
33
+ A tuple containing the updated key and value states.
34
+ """
35
+ raise NotImplementedError("Make sure to implement `update` in a subclass.")
36
+
37
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
38
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
39
+ raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
40
+
41
+ def get_max_length(self) -> Optional[int]:
42
+ """Returns the maximum sequence length of the cached states, if there is any."""
43
+ raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")
44
+
45
+ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
46
+ """Given the sequence length of the new inputs, returns the usable length of the cache."""
47
+ # Cache without size limit -> all cache is usable
48
+ # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
49
+ # length, we will need to evict part of the cache (and thus not all cache is usable)
50
+ max_length = self.get_max_length()
51
+ previous_seq_length = self.get_seq_length(layer_idx)
52
+ if max_length is not None and previous_seq_length + new_seq_length > max_length:
53
+ return max_length - new_seq_length
54
+ return previous_seq_length
55
+
56
+
57
+ class DynamicCache(Cache):
58
+ """
59
+ A cache that grows dynamically as more tokens are generated. This is the default for generative models.
60
+
61
+ It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
62
+ `[batch_size, num_heads, seq_len, head_dim]`.
63
+ """
64
+
65
+ def __init__(self) -> None:
66
+ self.key_cache: List[torch.Tensor] = []
67
+ self.value_cache: List[torch.Tensor] = []
68
+ self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
69
+
70
+ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
71
+ """
72
+ Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
73
+ sequence length.
74
+ """
75
+ if layer_idx < len(self):
76
+ return (self.key_cache[layer_idx], self.value_cache[layer_idx])
77
+ else:
78
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
79
+
80
+ def __iter__(self):
81
+ """
82
+ Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
83
+ keys and values
84
+ """
85
+ for layer_idx in range(len(self)):
86
+ yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
87
+
88
+ def __len__(self):
89
+ """
90
+ Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
91
+ to the number of layers in the model.
92
+ """
93
+ return len(self.key_cache)
94
+
95
+ def update(
96
+ self,
97
+ key_states: torch.Tensor,
98
+ value_states: torch.Tensor,
99
+ layer_idx: int,
100
+ cache_kwargs: Optional[Dict[str, Any]] = None,
101
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
102
+ """
103
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
104
+
105
+ Parameters:
106
+ key_states (`torch.Tensor`):
107
+ The new key states to cache.
108
+ value_states (`torch.Tensor`):
109
+ The new value states to cache.
110
+ layer_idx (`int`):
111
+ The index of the layer to cache the states for.
112
+ cache_kwargs (`Dict[str, Any]`, `optional`):
113
+ Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
114
+
115
+ Return:
116
+ A tuple containing the updated key and value states.
117
+ """
118
+ # Update the number of seen tokens
119
+ if layer_idx == 0:
120
+ self.seen_tokens += key_states.shape[-2]
121
+
122
+ # Update the cache
123
+ if len(self.key_cache) <= layer_idx:
124
+ self.key_cache.append(key_states)
125
+ self.value_cache.append(value_states)
126
+ else:
127
+ self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
128
+ self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
129
+
130
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
131
+
132
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
133
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
134
+ if len(self.key_cache) <= layer_idx:
135
+ return 0
136
+ return self.key_cache[layer_idx].shape[-2]
137
+
138
+ def get_max_length(self) -> Optional[int]:
139
+ """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
140
+ return None
141
+
142
+ def reorder_cache(self, beam_idx: torch.LongTensor):
143
+ """Reorders the cache for beam search, given the selected beam indices."""
144
+ for layer_idx in range(len(self.key_cache)):
145
+ device = self.key_cache[layer_idx].device
146
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
147
+ device = self.value_cache[layer_idx].device
148
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
149
+
150
+ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
151
+ """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
152
+ legacy_cache = ()
153
+ for layer_idx in range(len(self)):
154
+ legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
155
+ return legacy_cache
156
+
157
+ @classmethod
158
+ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
159
+ """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
160
+ cache = cls()
161
+ if past_key_values is not None:
162
+ for layer_idx in range(len(past_key_values)):
163
+ key_states, value_states = past_key_values[layer_idx]
164
+ cache.update(key_states, value_states, layer_idx)
165
+ return cache
166
+
167
+
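Editorial note: a minimal usage sketch of the `DynamicCache` defined above (not part of this commit). The batch/head/sequence sizes are arbitrary example values, and the import path assumes the repository root is on `PYTHONPATH`.

```python
import torch

from tinychart.model.language_model.phi.cache_utils import DynamicCache

cache = DynamicCache()
batch, num_heads, head_dim = 1, 4, 8

# Prefill: cache 5 prompt tokens for two layers; the first update() per layer appends.
for layer_idx in range(2):
    k = torch.randn(batch, num_heads, 5, head_dim)
    v = torch.randn(batch, num_heads, 5, head_dim)
    cache.update(k, v, layer_idx)

# Decode step: one new token per layer; update() concatenates along the sequence dim.
for layer_idx in range(2):
    k = torch.randn(batch, num_heads, 1, head_dim)
    v = torch.randn(batch, num_heads, 1, head_dim)
    keys, values = cache.update(k, v, layer_idx)

print(len(cache))                 # 2 layers
print(cache.get_seq_length())     # 6 cached positions for layer 0
legacy = cache.to_legacy_cache()  # tuple of (key, value) pairs, one per layer
restored = DynamicCache.from_legacy_cache(legacy)
```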
168
+ class SinkCache(Cache):
169
+ """
170
+ A cache as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
171
+ generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
172
+ tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
173
+
174
+ It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
175
+ `[batch_size, num_heads, seq_len, head_dim]`.
176
+
177
+ Parameters:
178
+ window_length (`int`):
179
+ The length of the context window.
180
+ num_sink_tokens (`int`):
181
+ The number of sink tokens. See the original paper for more information.
182
+ """
183
+
184
+ def __init__(self, window_length: int, num_sink_tokens: int) -> None:
185
+ self.key_cache: List[torch.Tensor] = []
186
+ self.value_cache: List[torch.Tensor] = []
187
+ self.window_length = window_length
188
+ self.num_sink_tokens = num_sink_tokens
189
+ self.cos_sin_cache = {}
190
+ self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
191
+
192
+ @staticmethod
193
+ def _rotate_half(x):
194
+ x1 = x[..., : x.shape[-1] // 2]
195
+ x2 = x[..., x.shape[-1] // 2 :]
196
+ return torch.cat((-x2, x1), dim=-1)
197
+
198
+ def _apply_key_rotary_pos_emb(
199
+ self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
200
+ ) -> torch.Tensor:
201
+ rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
202
+ return rotated_key_states
203
+
204
+ def _get_rerotation_cos_sin(
205
+ self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
206
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
207
+ if key_states.shape[-2] not in self.cos_sin_cache:
208
+ # Upcast to float32 temporarily for better accuracy
209
+ cos = cos.to(torch.float32)
210
+ sin = sin.to(torch.float32)
211
+
212
+ # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
213
+ original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
214
+ shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
215
+ original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
216
+ shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
217
+ rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
218
+ rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
219
+
220
+ self.cos_sin_cache[key_states.shape[-2]] = (
221
+ rerotation_cos.to(key_states.dtype).unsqueeze(0),
222
+ rerotation_sin.to(key_states.dtype).unsqueeze(0),
223
+ )
224
+ return self.cos_sin_cache[key_states.shape[-2]]
225
+
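Editorial note: `_get_rerotation_cos_sin` above is an application of the angle-difference identities, so a cached key can be moved from its old RoPE position to its shifted one in a single rotation. A small numeric check of that identity (example angles chosen arbitrarily, not part of this commit):

```python
import torch

old, new = torch.tensor(1.3), torch.tensor(0.4)

# Same combination as rerotation_cos / rerotation_sin above:
rerotation_cos = old.cos() * new.cos() + old.sin() * new.sin()
rerotation_sin = -old.sin() * new.cos() + old.cos() * new.sin()

# They equal cos(new - old) and sin(new - old): rotating by (new - old) shifts the key.
assert torch.allclose(rerotation_cos, (new - old).cos())
assert torch.allclose(rerotation_sin, (new - old).sin())
```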
226
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
227
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
228
+ # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
229
+ if len(self.key_cache) <= layer_idx:
230
+ return 0
231
+ return self.key_cache[layer_idx].shape[-2]
232
+
233
+ def get_max_length(self) -> Optional[int]:
234
+ """Returns the maximum sequence length of the cached states."""
235
+ return self.window_length
236
+
237
+ def update(
238
+ self,
239
+ key_states: torch.Tensor,
240
+ value_states: torch.Tensor,
241
+ layer_idx: int,
242
+ cache_kwargs: Optional[Dict[str, Any]] = None,
243
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
244
+ """
245
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
246
+
247
+ Parameters:
248
+ key_states (`torch.Tensor`):
249
+ The new key states to cache.
250
+ value_states (`torch.Tensor`):
251
+ The new value states to cache.
252
+ layer_idx (`int`):
253
+ The index of the layer to cache the states for.
254
+ cache_kwargs (`Dict[str, Any]`, `optional`):
255
+ Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
256
+ `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
257
+ rotation as the tokens are shifted.
258
+
259
+ Return:
260
+ A tuple containing the updated key and value states.
261
+ """
262
+ # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
263
+ # with partially rotated position embeddings, like Phi or Persimmon.
264
+ sin = cache_kwargs.get("sin")
265
+ cos = cache_kwargs.get("cos")
266
+ partial_rotation_size = cache_kwargs.get("partial_rotation_size")
267
+ using_rope = cos is not None and sin is not None
268
+
269
+ # Update the number of seen tokens
270
+ if layer_idx == 0:
271
+ self.seen_tokens += key_states.shape[-2]
272
+
273
+ # [bsz, num_heads, seq_len, head_dim]
274
+ if len(self.key_cache) <= layer_idx:
275
+ # Empty cache
276
+ self.key_cache.append(key_states)
277
+ self.value_cache.append(value_states)
278
+
279
+ elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
280
+ # Growing cache
281
+ self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
282
+ self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
283
+
284
+ else:
285
+ # Shifting cache
286
+ keys_to_keep = self.key_cache[layer_idx][
287
+ :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
288
+ ]
289
+
290
+ # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
291
+ if using_rope:
292
+ rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
293
+ key_states, cos[: self.window_length], sin[: self.window_length]
294
+ )
295
+ if partial_rotation_size is not None:
296
+ keys_to_keep, keys_pass = (
297
+ keys_to_keep[..., :partial_rotation_size],
298
+ keys_to_keep[..., partial_rotation_size:],
299
+ )
300
+ keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
301
+ if partial_rotation_size is not None:
302
+ keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
303
+
304
+ # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
305
+ sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
306
+ self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
307
+
308
+ sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
309
+ values_to_keep = self.value_cache[layer_idx][
310
+ :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
311
+ ]
312
+ self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
313
+
314
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
315
+
316
+ def reorder_cache(self, beam_idx: torch.LongTensor):
317
+ """Reorders the cache for beam search, given the selected beam indices."""
318
+ for layer_idx in range(len(self.key_cache)):
319
+ device = self.key_cache[layer_idx].device
320
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
321
+ device = self.value_cache[layer_idx].device
322
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
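Editorial note: an illustrative sketch (not part of this commit) of the sliding behaviour implemented by `SinkCache.update` above. An empty `cache_kwargs` dict is passed because the method reads the optional RoPE arguments from it; the sizes are arbitrary examples.

```python
import torch

from tinychart.model.language_model.phi.cache_utils import SinkCache

cache = SinkCache(window_length=8, num_sink_tokens=2)
batch, num_heads, head_dim = 1, 4, 16

# Feed 12 single-token steps into layer 0 (no RoPE kwargs, so keys are kept as-is).
for _ in range(12):
    k = torch.randn(batch, num_heads, 1, head_dim)
    v = torch.randn(batch, num_heads, 1, head_dim)
    keys, values = cache.update(k, v, layer_idx=0, cache_kwargs={})

# The cache never grows past the window: 2 sink tokens + the most recent tokens.
print(cache.get_seq_length(0), cache.get_max_length())  # 8 8
print(cache.seen_tokens)                                 # 12
```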
tinychart/model/language_model/phi/configuration_phi.py ADDED
@@ -0,0 +1,186 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ Phi model configuration"""
17
+
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+ PHI_PRETRAINED_CONFIG_ARCHIVE_MAP = {
26
+ "microsoft/phi-2": "https://huggingface.co/microsoft/phi-2/resolve/main/config.json",
27
+ }
28
+
29
+
30
+ class PhiConfig(PretrainedConfig):
31
+ r"""
32
+ This is the configuration class to store the configuration of a [`PhiModel`]. It is used to instantiate a Phi
33
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
34
+ defaults will yield a similar configuration to that of the Phi
35
+ [microsoft/phi-1](https://huggingface.co/microsoft/phi-1).
36
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
37
+ documentation from [`PretrainedConfig`] for more information.
38
+ Args:
39
+ vocab_size (`int`, *optional*, defaults to 51200):
40
+ Vocabulary size of the Phi model. Defines the number of different tokens that can be represented by the
41
+ `inputs_ids` passed when calling [`PhiModel`].
42
+ hidden_size (`int`, *optional*, defaults to 2048):
43
+ Dimension of the hidden representations.
44
+ intermediate_size (`int`, *optional*, defaults to 8192):
45
+ Dimension of the MLP representations.
46
+ num_hidden_layers (`int`, *optional*, defaults to 24):
47
+ Number of hidden layers in the Transformer decoder.
48
+ num_attention_heads (`int`, *optional*, defaults to 32):
49
+ Number of attention heads for each attention layer in the Transformer decoder.
50
+ num_key_value_heads (`int`, *optional*):
51
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
52
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
53
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
54
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
55
+ by meanpooling all the original heads within that group. For more details checkout [this
56
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
57
+ `num_attention_heads`.
58
+ resid_pdrop (`float`, *optional*, defaults to 0.0):
59
+ Dropout probability for mlp outputs.
60
+ embd_pdrop (`int`, *optional*, defaults to 0.0):
61
+ The dropout ratio for the embeddings.
62
+ attention_dropout (`float`, *optional*, defaults to 0.0):
63
+ The dropout ratio after computing the attention scores.
64
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_new"`):
65
+ The non-linear activation function (function or string) in the decoder.
66
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
67
+ The maximum sequence length that this model might ever be used with. Phi-1 and Phi-1.5 supports up to 2048
68
+ tokens.
69
+ initializer_range (`float`, *optional*, defaults to 0.02):
70
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
71
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
72
+ The epsilon used by the layer normalization layers.
73
+ use_cache (`bool`, *optional*, defaults to `True`):
74
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
75
+ relevant if `config.is_decoder=True`.
76
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
77
+ Whether to tie weight embeddings
78
+ rope_theta (`float`, *optional*, defaults to 10000.0):
79
+ The base period of the RoPE embeddings.
80
+ rope_scaling (`Dict`, *optional*):
81
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
82
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format
83
+ is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
84
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
85
+ these scaling strategies behave:
86
+ https://www.reddit.com/r/LocalPersimmon/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This
87
+ is an experimental feature, subject to breaking API changes in future versions.
88
+ partial_rotary_factor (`float`, *optional*, defaults to 0.5):
89
+ Percentage of the query and keys which will have rotary embedding.
90
+ qk_layernorm (`bool`, *optional*, defaults to `False`):
91
+ Whether or not to normalize the Queries and Keys after projecting the hidden states.
92
+ bos_token_id (`int`, *optional*, defaults to 1):
93
+ Denotes beginning of sequences token id.
94
+ eos_token_id (`int`, *optional*, defaults to 2):
95
+ Denotes end of sequences token id.
96
+ Example:
97
+ ```python
98
+ >>> from transformers import PhiModel, PhiConfig
99
+ >>> # Initializing a Phi-1 style configuration
100
+ >>> configuration = PhiConfig.from_pretrained("microsoft/phi-1")
101
+ >>> # Initializing a model from the configuration
102
+ >>> model = PhiModel(configuration)
103
+ >>> # Accessing the model configuration
104
+ >>> configuration = model.config
105
+ ```"""
106
+
107
+ model_type = "phi"
108
+ keys_to_ignore_at_inference = ["past_key_values"]
109
+
110
+ def __init__(
111
+ self,
112
+ vocab_size=51200,
113
+ hidden_size=2048,
114
+ intermediate_size=8192,
115
+ num_hidden_layers=24,
116
+ num_attention_heads=32,
117
+ num_key_value_heads=None,
118
+ resid_pdrop=0.0,
119
+ embd_pdrop=0.0,
120
+ attention_dropout=0.0,
121
+ hidden_act="gelu_new",
122
+ max_position_embeddings=2048,
123
+ initializer_range=0.02,
124
+ layer_norm_eps=1e-5,
125
+ use_cache=True,
126
+ tie_word_embeddings=False,
127
+ rope_theta=10000.0,
128
+ rope_scaling=None,
129
+ partial_rotary_factor=0.5,
130
+ qk_layernorm=False,
131
+ bos_token_id=1,
132
+ eos_token_id=2,
133
+ **kwargs,
134
+ ):
135
+ self.vocab_size = vocab_size
136
+ self.hidden_size = hidden_size
137
+ self.intermediate_size = intermediate_size
138
+ self.num_hidden_layers = num_hidden_layers
139
+ self.num_attention_heads = num_attention_heads
140
+
141
+ if num_key_value_heads is None:
142
+ num_key_value_heads = num_attention_heads
143
+
144
+ self.num_key_value_heads = num_key_value_heads
145
+ self.resid_pdrop = resid_pdrop
146
+ self.embd_pdrop = embd_pdrop
147
+ self.attention_dropout = attention_dropout
148
+ self.hidden_act = hidden_act
149
+ self.max_position_embeddings = max_position_embeddings
150
+ self.initializer_range = initializer_range
151
+ self.layer_norm_eps = layer_norm_eps
152
+ self.use_cache = use_cache
153
+ self.rope_theta = rope_theta
154
+ self.rope_scaling = rope_scaling
155
+ self.partial_rotary_factor = partial_rotary_factor
156
+ self.qk_layernorm = qk_layernorm
157
+ self._rope_scaling_validation()
158
+
159
+ super().__init__(
160
+ bos_token_id=bos_token_id,
161
+ eos_token_id=eos_token_id,
162
+ tie_word_embeddings=tie_word_embeddings,
163
+ **kwargs,
164
+ )
165
+
166
+ # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
167
+ def _rope_scaling_validation(self):
168
+ """
169
+ Validate the `rope_scaling` configuration.
170
+ """
171
+ if self.rope_scaling is None:
172
+ return
173
+
174
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
175
+ raise ValueError(
176
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
177
+ f"got {self.rope_scaling}"
178
+ )
179
+ rope_scaling_type = self.rope_scaling.get("type", None)
180
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
181
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
182
+ raise ValueError(
183
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
184
+ )
185
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
186
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
tinychart/model/language_model/phi/convert_phi_weights_to_hf.py ADDED
@@ -0,0 +1,175 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """
17
+ Weights conversion script for Phi
18
+
19
+ This script downloads both Phi-1 and Phi-1.5 checkpoints to "checkpoint_path" and then converts the weights to
20
+ HuggingFace model's format and saves them in "pytorch_dump_folder_path".
21
+ """
22
+
23
+ import argparse
24
+ import gc
25
+ import os
26
+
27
+ import torch
28
+ from huggingface_hub import hf_hub_download
29
+
30
+ from modeling_phi import PhiConfig, PhiForCausalLM
31
+
32
+
33
+ _MODELS = {
34
+ "microsoft/phi-1": "https://huggingface.co/microsoft/phi-1/blob/main/pytorch_model.bin",
35
+ "microsoft/phi-1_5": "https://huggingface.co/microsoft/phi-1_5/blob/main/pytorch_model.bin",
36
+ }
37
+
38
+
39
+ PHI_MAPPING = {
40
+ "layers.0.wte.weight": "model.embed_tokens.weight",
41
+ "layers.25.linear.bias": "lm_head.bias",
42
+ "layers.25.linear.weight": "lm_head.weight",
43
+ "layers.25.ln.bias": "model.final_layernorm.bias",
44
+ "layers.25.ln.weight": "model.final_layernorm.weight",
45
+ "layers": "model.layers",
46
+ "ln": "input_layernorm",
47
+ "mixer": "self_attn",
48
+ "Wqkv": "query_key_value",
49
+ "out_proj": "dense",
50
+ }
51
+
52
+
53
+ def convert_weights(original_weights, mapping, config):
54
+ converted_weights = {}
55
+ original_weights_keys = sorted(original_weights.keys())
56
+
57
+ # rename layers (1-24) -> layers (0-23) for the Phi model layers
58
+ range_change = {
59
+ f"layers.{k}.": f"layers.{v}."
60
+ for k, v in zip(range(1, config.num_hidden_layers + 1), range(0, config.num_hidden_layers))
61
+ }
62
+
63
+ mapping.update(**range_change)
64
+
65
+ for original_weights_key in original_weights_keys:
66
+ new_key = original_weights_key
67
+
68
+ if "rotary_emb" in new_key:
69
+ continue
70
+
71
+ if "Wqkv" in new_key:
72
+ if "weight" in new_key:
73
+ weight = original_weights[new_key]
74
+ weights_shape = weight.shape
75
+ weight = (
76
+ weight.view(3, config.num_attention_heads, -1, config.hidden_size)
77
+ .transpose(0, 1)
78
+ .reshape(*weights_shape)
79
+ )
80
+ original_weights[new_key] = weight
81
+ elif "bias" in new_key:
82
+ bias = original_weights[new_key]
83
+ bias_shape = bias.shape
84
+ bias = bias.view(3, config.num_attention_heads, -1).transpose(0, 1).reshape(*bias_shape)
85
+ original_weights[new_key] = bias
86
+
87
+ for k, v in mapping.items():
88
+ if k in new_key:
89
+ new_key = new_key.replace(k, v)
90
+
91
+ converted_weights[new_key] = original_weights.pop(original_weights_key)
92
+
93
+ return converted_weights
94
+
95
+
96
+ def _download(url: str, root: str):
97
+ repo_id = f"{url.split('/')[3]}/{url.split('/')[4]}"
98
+ filename = f"{url.split('/')[-1]}"
99
+ hf_hub_download(
100
+ repo_id=repo_id,
101
+ filename=filename,
102
+ force_filename=root,
103
+ local_dir_use_symlinks=False,
104
+ )
105
+
106
+
107
+ def convert_phi_weights(checkpoint_path, pytorch_dump_folder_path, use_cuda, save_weights_directly):
108
+ device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
109
+ for each_model_name, each_model_url in _MODELS.items():
110
+ converted_checkpoint = {}
111
+
112
+ model_path = os.path.join(checkpoint_path, each_model_name + "_" + each_model_url.split("/")[-1])
113
+ if not os.path.exists(model_path):
114
+ print(f"\n{each_model_name} was not found! Downloading it to {model_path}")
115
+ _download(url=each_model_url, root=model_path)
116
+ model_checkpoint = torch.load(model_path, map_location=device)
117
+ model_type = each_model_name.split("/")[1] # phi-1 or phi-1_5
118
+ config = PhiConfig.from_pretrained(f"susnato/{model_type}_dev")
119
+
120
+ # Converting the weights
121
+ converted_checkpoint.update(**convert_weights(model_checkpoint, PHI_MAPPING, config))
122
+
123
+ # Save either the whole model or the converted weights
124
+ if save_weights_directly:
125
+ save_weights_path = os.path.join(
126
+ pytorch_dump_folder_path, each_model_name.split("/")[-1] + "_" + each_model_url.split("/")[-1]
127
+ )
128
+ torch.save(converted_checkpoint, save_weights_path)
129
+ print(f"Model weights saved at {save_weights_path}!")
130
+
131
+ else:
132
+ model = PhiForCausalLM(config).to(device)
133
+ model.load_state_dict(converted_checkpoint, strict=True)
134
+ save_model_path = os.path.join(pytorch_dump_folder_path, model_type)
135
+ model.save_pretrained(save_model_path)
136
+ print(f"Model saved at {save_model_path}!")
137
+
138
+ # release GPU memory for the 2nd model if cuda was used.
139
+ del config, model
140
+
141
+ # release GPU memory for the 2nd model if cuda was used.
142
+ del model_checkpoint, converted_checkpoint
143
+ if use_cuda:
144
+ torch.cuda.empty_cache()
145
+ gc.collect()
146
+
147
+
148
+ if __name__ == "__main__":
149
+ parser = argparse.ArgumentParser()
150
+ # # Required parameters
151
+ parser.add_argument(
152
+ "--checkpoint_path", type=str, help="Path to the folder of downloaded checkpoints. (Please enter full path)"
153
+ )
154
+ parser.add_argument(
155
+ "--pytorch_dump_folder_path",
156
+ default=None,
157
+ type=str,
158
+ help="Path to the output PyTorch model. (Please enter full path)",
159
+ )
160
+ parser.add_argument(
161
+ "--use_cuda",
162
+ default=False,
163
+ type=bool,
164
+ help="Whether to load the weights on GPU during conversion or not, False by default",
165
+ )
166
+ parser.add_argument(
167
+ "--save_weights_directly",
168
+ default=True,
169
+ type=bool,
170
+ help="Whether to save the weights directly after conversion or load the weight to the Phi model and then save "
171
+ "the Phi model along with weights. True by default",
172
+ )
173
+
174
+ args = parser.parse_args()
175
+ convert_phi_weights(args.checkpoint_path, args.pytorch_dump_folder_path, args.use_cuda, args.save_weights_directly)
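Editorial note: the renaming performed by `convert_weights` above is substring replacement driven by `PHI_MAPPING` plus the generated layer-index shift. The sketch below (not part of this commit, with a hypothetical checkpoint key and a reduced mapping) reproduces that behaviour:

```python
# Reduced mapping in the same spirit as PHI_MAPPING + range_change above.
mapping = {
    "layers": "model.layers",
    "ln": "input_layernorm",
    "mixer": "self_attn",
    "Wqkv": "query_key_value",
    "out_proj": "dense",
    "layers.1.": "layers.0.",  # generated shift: original layer 1 becomes layer 0
}

def rename(key: str) -> str:
    # Apply every matching substring replacement, as convert_weights does.
    for old, new in mapping.items():
        if old in key:
            key = key.replace(old, new)
    return key

print(rename("layers.1.mixer.Wqkv.weight"))
# -> model.layers.0.self_attn.query_key_value.weight
```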
tinychart/model/language_model/phi/modeling_attn_mask_utils.py ADDED
@@ -0,0 +1,497 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import List, Optional, Tuple, Union
16
+
17
+ import torch
18
+
19
+
20
+ @dataclass
21
+ class AttentionMaskConverter:
22
+ """
23
+ A utility attention mask class that allows one to:
24
+ - Create a causal 4d mask
25
+ - Create a causal 4d mask with sliding window
26
+ - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
27
+ key_value_length) that can be multiplied with attention scores
28
+
29
+ Examples:
30
+
31
+ ```python
32
+ >>> import torch
33
+ >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
34
+
35
+ >>> converter = AttentionMaskConverter(True)
36
+ >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
37
+ tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
38
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
39
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
40
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38],
41
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]])
42
+ ```
43
+
44
+ Parameters:
45
+ is_causal (`bool`):
46
+ Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
47
+
48
+ sliding_window (`int`, *optional*):
49
+ Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
50
+ """
51
+
52
+ is_causal: bool
53
+ sliding_window: int
54
+
55
+ def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
56
+ self.is_causal = is_causal
57
+ self.sliding_window = sliding_window
58
+
59
+ if self.sliding_window is not None and self.sliding_window <= 0:
60
+ raise ValueError(
61
+ f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
62
+ )
63
+
64
+ def to_causal_4d(
65
+ self,
66
+ batch_size: int,
67
+ query_length: int,
68
+ key_value_length: int,
69
+ dtype: torch.dtype,
70
+ device: Union[torch.device, "str"] = "cpu",
71
+ ) -> Optional[torch.Tensor]:
72
+ """
73
+ Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
74
+ bias to upper right hand triangular matrix (causal mask).
75
+ """
76
+ if not self.is_causal:
77
+ raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
78
+
79
+ # If shape is not cached, create a new causal mask and cache it
80
+ input_shape = (batch_size, query_length)
81
+ past_key_values_length = key_value_length - query_length
82
+
83
+ # create causal mask
84
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
85
+ causal_4d_mask = None
86
+ if input_shape[-1] > 1 or self.sliding_window is not None:
87
+ causal_4d_mask = self._make_causal_mask(
88
+ input_shape,
89
+ dtype,
90
+ device=device,
91
+ past_key_values_length=past_key_values_length,
92
+ sliding_window=self.sliding_window,
93
+ )
94
+
95
+ return causal_4d_mask
96
+
97
+ def to_4d(
98
+ self,
99
+ attention_mask_2d: torch.Tensor,
100
+ query_length: int,
101
+ dtype: torch.dtype,
102
+ key_value_length: Optional[int] = None,
103
+ ) -> torch.Tensor:
104
+ """
105
+ Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
106
+ key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
107
+ causal, a causal mask will be added.
108
+ """
109
+ input_shape = (attention_mask_2d.shape[0], query_length)
110
+
111
+ # create causal mask
112
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
113
+ causal_4d_mask = None
114
+ if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
115
+ if key_value_length is None:
116
+ raise ValueError(
117
+ "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
118
+ )
119
+
120
+ past_key_values_length = key_value_length - query_length
121
+ causal_4d_mask = self._make_causal_mask(
122
+ input_shape,
123
+ dtype,
124
+ device=attention_mask_2d.device,
125
+ past_key_values_length=past_key_values_length,
126
+ sliding_window=self.sliding_window,
127
+ )
128
+ elif self.sliding_window is not None:
129
+ raise NotImplementedError("Sliding window is currently only implemented for causal masking")
130
+
131
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
132
+ expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
133
+ attention_mask_2d.device
134
+ )
135
+ if causal_4d_mask is not None:
136
+ expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
137
+
138
+ # expanded_attn_mask + causal_4d_mask can cause some overflow
139
+ expanded_4d_mask = expanded_attn_mask
140
+
141
+ return expanded_4d_mask
142
+
143
+ @staticmethod
144
+ def _make_causal_mask(
145
+ input_ids_shape: torch.Size,
146
+ dtype: torch.dtype,
147
+ device: torch.device,
148
+ past_key_values_length: int = 0,
149
+ sliding_window: Optional[int] = None,
150
+ ):
151
+ """
152
+ Make the causal mask used for causal (uni-directional) self-attention.
153
+ """
154
+ bsz, tgt_len = input_ids_shape
155
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
156
+ mask_cond = torch.arange(mask.size(-1), device=device)
157
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
158
+
159
+ mask = mask.to(dtype)
160
+
161
+ if past_key_values_length > 0:
162
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
163
+
164
+ # add lower triangular sliding window mask if necessary
165
+ if sliding_window is not None:
166
+ diagonal = past_key_values_length - sliding_window + 1
167
+
168
+ context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
169
+ mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)
170
+
171
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
172
+
173
+ @staticmethod
174
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
175
+ """
176
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
177
+ """
178
+ bsz, src_len = mask.size()
179
+ tgt_len = tgt_len if tgt_len is not None else src_len
180
+
181
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
182
+
183
+ inverted_mask = 1.0 - expanded_mask
184
+
185
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
186
+
187
+ @staticmethod
188
+ def _unmask_unattended(
189
+ expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float]
190
+ ):
191
+ # fmt: off
192
+ """
193
+ Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
194
+ using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
195
+ Details: https://github.com/pytorch/pytorch/issues/110213
196
+
197
+ `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
198
+ `attention_mask` is [bsz, src_seq_len].
199
+
200
+ The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
201
+
202
+ For example, if `attention_mask` is
203
+ ```
204
+ [[0, 0, 1],
205
+ [1, 1, 1],
206
+ [0, 1, 1]]
207
+ ```
208
+ and `expanded_mask` is (e.g. here left-padding case)
209
+ ```
210
+ [[[[0, 0, 0],
211
+ [0, 0, 0],
212
+ [0, 0, 1]]],
213
+ [[[1, 0, 0],
214
+ [1, 1, 0],
215
+ [1, 1, 1]]],
216
+ [[[0, 0, 0],
217
+ [0, 1, 0],
218
+ [0, 1, 1]]]]
219
+ ```
220
+ then the modified `expanded_mask` will be
221
+ ```
222
+ [[[[1, 1, 1], <-- modified
223
+ [1, 1, 1], <-- modified
224
+ [0, 0, 1]]],
225
+ [[[1, 0, 0],
226
+ [1, 1, 0],
227
+ [1, 1, 1]]],
228
+ [[[1, 1, 1], <-- modified
229
+ [0, 1, 0],
230
+ [0, 1, 1]]]]
231
+ ```
232
+ """
233
+ # fmt: on
234
+
235
+ # Get the index of the first non-zero value for every sample in the batch.
236
+ # In the above example, indices = [[2], [0], [1]]]
237
+ tmp = torch.arange(attention_mask.shape[1], 0, -1)
238
+ indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True)
239
+
240
+ # Find the batch indexes that have unattended tokens on the leftmost side (e.g. [0, 0, 1, 1, 1]), for which the first rows of the
241
+ # expanded mask will be completely unattended.
242
+ left_masked_rows = torch.where(indices > 0)[0]
243
+
244
+ if left_masked_rows.shape[0] == 0:
245
+ return expanded_mask
246
+ indices = indices[left_masked_rows]
247
+
248
+ max_len = torch.max(indices)
249
+ range_tensor = torch.arange(max_len).unsqueeze(0)
250
+ range_tensor = range_tensor.repeat(indices.size(0), 1)
251
+
252
+ # Avoid unmasking tokens at relevant target positions (on the row axis), by rather unmasking possibly several times the first row that should always be unmasked as we filtered out the batch above.
253
+ range_tensor[range_tensor >= indices] = 0
254
+
255
+ # TODO: we may drop support for 3D attention mask as the refactor from Patrick maybe dropped this case
256
+ if expanded_mask.dim() == 4:
257
+ num_masks = expanded_mask.shape[1]
258
+ if num_masks == 1:
259
+ # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
260
+ mask_slice = (left_masked_rows[:, None], 0, range_tensor)
261
+ else:
262
+ # Broadcast [left_masked_rows, 1, 1], [1, num_masks, 1], [left_masked_rows, 1, max_len]
263
+ mask_slice = (
264
+ left_masked_rows[:, None, None],
265
+ torch.arange(num_masks)[None, :, None],
266
+ range_tensor[:, None, :],
267
+ )
268
+ else:
269
+ # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
270
+ mask_slice = (left_masked_rows[:, None], range_tensor)
271
+
272
+ expanded_mask[mask_slice] = unmasked_value
273
+
274
+ return expanded_mask
275
+
276
+
277
+ def _prepare_4d_causal_attention_mask(
278
+ attention_mask: Optional[torch.Tensor],
279
+ input_shape: Union[torch.Size, Tuple, List],
280
+ inputs_embeds: torch.Tensor,
281
+ past_key_values_length: int,
282
+ sliding_window: Optional[int] = None,
283
+ ):
284
+ """
285
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
286
+ `(batch_size, key_value_length)`
287
+
288
+ Args:
289
+ attention_mask (`torch.Tensor` or `None`):
290
+ A 2D attention mask of shape `(batch_size, key_value_length)`
291
+ input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
292
+ The input shape should be a tuple that defines `(batch_size, query_length)`.
293
+ inputs_embeds (`torch.Tensor`):
294
+ The embedded inputs as a torch Tensor.
295
+ past_key_values_length (`int`):
296
+ The length of the key value cache.
297
+ sliding_window (`int`, *optional*):
298
+ If the model uses windowed attention, a sliding window should be passed.
299
+ """
300
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
301
+
302
+ key_value_length = input_shape[-1] + past_key_values_length
303
+
304
+ # 4d mask is passed through the layers
305
+ if attention_mask is not None and len(attention_mask.shape) == 2:
306
+ attention_mask = attn_mask_converter.to_4d(
307
+ attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
308
+ )
309
+ elif attention_mask is not None and len(attention_mask.shape) == 4:
310
+ expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
311
+ if tuple(attention_mask.shape) != expected_shape:
312
+ raise ValueError(
313
+ f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
314
+ )
315
+ else:
316
+ # if the 4D mask has correct shape - invert it and fill with negative infinity
317
+ inverted_mask = 1.0 - attention_mask
318
+ attention_mask = inverted_mask.masked_fill(
319
+ inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
320
+ )
321
+ else:
322
+ attention_mask = attn_mask_converter.to_causal_4d(
323
+ input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
324
+ )
325
+
326
+ return attention_mask
327
+
328
+
329
+ # Adapted from _prepare_4d_causal_attention_mask
330
+ def _prepare_4d_causal_attention_mask_for_sdpa(
331
+ attention_mask: Optional[torch.Tensor],
332
+ input_shape: Union[torch.Size, Tuple, List],
333
+ inputs_embeds: torch.Tensor,
334
+ past_key_values_length: int,
335
+ sliding_window: Optional[int] = None,
336
+ ):
337
+ """
338
+ Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.
339
+
340
+ In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
341
+ `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
342
+ allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
343
+ """
344
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
345
+
346
+ key_value_length = input_shape[-1] + past_key_values_length
347
+ batch_size, query_length = input_shape
348
+
349
+ # torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
350
+ # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
351
+ # TODO: Fix this as well when using torchdynamo with fullgraph=True.
352
+ is_tracing = torch.jit.is_tracing()
353
+
354
+ if attention_mask is not None:
355
+ # 4d mask is passed through
356
+ if len(attention_mask.shape) == 4:
357
+ expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
358
+ if tuple(attention_mask.shape) != expected_shape:
359
+ raise ValueError(
360
+ f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
361
+ )
362
+ else:
363
+ # if the 4D mask has correct shape - invert it and fill with negative infinity
364
+ inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
365
+ attention_mask = inverted_mask.masked_fill(
366
+ inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
367
+ )
368
+ return attention_mask
369
+
370
+ elif torch.all(attention_mask == 1):
371
+ if is_tracing:
372
+ pass
373
+ elif query_length == 1:
374
+ # For query_length == 1, causal attention and bi-directional attention are the same.
375
+ attention_mask = None
376
+ elif key_value_length == query_length:
377
+ attention_mask = None
378
+ else:
379
+ # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
380
+ # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
381
+ # Reference: https://github.com/pytorch/pytorch/issues/108108
382
+ pass
383
+ elif query_length > 1 and key_value_length != query_length:
384
+ # See the comment above (https://github.com/pytorch/pytorch/issues/108108).
385
+ # Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`.
386
+ attention_mask = True
387
+ elif is_tracing:
388
+ raise ValueError(
389
+ 'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
390
+ )
391
+
392
+ if attention_mask is None:
393
+ expanded_4d_mask = None
394
+ elif attention_mask is True:
395
+ expanded_4d_mask = attn_mask_converter.to_causal_4d(
396
+ input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
397
+ )
398
+ else:
399
+ expanded_4d_mask = attn_mask_converter.to_4d(
400
+ attention_mask,
401
+ input_shape[-1],
402
+ dtype=inputs_embeds.dtype,
403
+ key_value_length=key_value_length,
404
+ )
405
+
406
+ # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
407
+ # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
408
+ if query_length > 1:
409
+ expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
410
+ expanded_4d_mask, attention_mask, unmasked_value=0.0
411
+ )
412
+
413
+ return expanded_4d_mask
414
+
415
+
416
+ def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
417
+ """
418
+ Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
419
+ `(batch_size, key_value_length)`
420
+
421
+ Args:
422
+ mask (`torch.Tensor` or `None`):
423
+ A 2D attention mask of shape `(batch_size, key_value_length)`
424
+ dtype (`torch.dtype`):
425
+ The torch dtype the created mask shall have.
426
+ tgt_len (`int`):
427
+ The target length or query length the created mask shall have.
428
+ """
429
+ return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
430
+
431
+
432
+ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
433
+ """
434
+ Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
435
+ `(batch_size, key_value_length)`
436
+
437
+ Args:
438
+ mask (`torch.Tensor` or `None`):
439
+ A 2D attention mask of shape `(batch_size, key_value_length)`
440
+ dtype (`torch.dtype`):
441
+ The torch dtype the created mask shall have.
442
+ tgt_len (`int`):
443
+ The target length or query length the created mask shall have.
444
+ """
445
+ batch_size, key_value_length = mask.shape
446
+ tgt_len = tgt_len if tgt_len is not None else key_value_length
447
+
448
+ # torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
449
+ # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
450
+ # TODO: Fix this as well when using torchdynamo with fullgraph=True.
451
+ is_tracing = torch.jit.is_tracing()
452
+
453
+ if torch.all(mask == 1):
454
+ if is_tracing:
455
+ pass
456
+ elif tgt_len == 1:
457
+ # For query_length == 1, causal attention and bi-directional attention are the same.
458
+ return None
459
+ elif key_value_length == tgt_len:
460
+ return None
461
+ else:
462
+ # Unfortunately, for query_length > 1 and key_value_length != query_length, we can not generally ignore the attention mask, as SDPA causal mask generation
463
+ # may be wrong. We will set is_causal=False in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
464
+ # Reference: https://github.com/pytorch/pytorch/issues/108108
465
+ return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
466
+ else:
467
+ return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
468
+
469
+
470
+ def _create_4d_causal_attention_mask(
471
+ input_shape: Union[torch.Size, Tuple, List],
472
+ dtype: torch.dtype,
473
+ device: torch.device,
474
+ past_key_values_length: int = 0,
475
+ sliding_window: Optional[int] = None,
476
+ ) -> Optional[torch.Tensor]:
477
+ """
478
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
479
+
480
+ Args:
481
+ input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
482
+ The input shape should be a tuple that defines `(batch_size, query_length)`.
483
+ dtype (`torch.dtype`):
484
+ The torch dtype the created mask shall have.
485
+ device (`int`):
486
+ The torch device the created mask shall have.
487
+ sliding_window (`int`, *optional*):
488
+ If the model uses windowed attention, a sliding window should be passed.
489
+ """
490
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
491
+
492
+ key_value_length = past_key_values_length + input_shape[-1]
493
+ attention_mask = attn_mask_converter.to_causal_4d(
494
+ input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
495
+ )
496
+
497
+ return attention_mask
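Editorial note: a small sketch (not part of this commit) of the 2D-to-4D expansion these helpers implement, using a single left-padded sequence; shapes are arbitrary examples and the import path assumes the repository root is on `PYTHONPATH`.

```python
import torch

from tinychart.model.language_model.phi.modeling_attn_mask_utils import (
    AttentionMaskConverter,
    _prepare_4d_causal_attention_mask,
)

# One sequence of length 4 whose first position is padding.
attention_mask = torch.tensor([[0, 1, 1, 1]])
inputs_embeds = torch.zeros(1, 4, 8)  # only dtype/device are read from this tensor

mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask, (1, 4), inputs_embeds, past_key_values_length=0
)
print(mask_4d.shape)  # torch.Size([1, 1, 4, 4]); 0 where attending is allowed,
                      # large negative values on future and padded positions

# The converter alone can also build a plain causal mask.
converter = AttentionMaskConverter(is_causal=True)
causal = converter.to_causal_4d(1, 4, 4, dtype=torch.float32)
print(causal[0, 0])   # upper triangle filled with the dtype's most negative value
```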
tinychart/model/language_model/phi/modeling_phi.py ADDED
@@ -0,0 +1,1345 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ PyTorch Phi model."""
17
+
18
+
19
+ import math
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.cache_utils import Cache, DynamicCache
30
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
31
+ from transformers.modeling_outputs import (
32
+ BaseModelOutputWithPast,
33
+ CausalLMOutputWithPast,
34
+ SequenceClassifierOutputWithPast,
35
+ TokenClassifierOutput,
36
+ )
37
+ from transformers.modeling_utils import PreTrainedModel
38
+ from transformers.utils import (
39
+ add_code_sample_docstrings,
40
+ add_start_docstrings,
41
+ add_start_docstrings_to_model_forward,
42
+ is_flash_attn_2_available,
43
+ is_flash_attn_greater_or_equal_2_10,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from .configuration_phi import PhiConfig
48
+
49
+
50
+ try:
51
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
52
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
53
+ except:
54
+ pass
55
+
56
+
57
+ logger = logging.get_logger(__name__)
58
+
59
+ _CHECKPOINT_FOR_DOC = "microsoft/phi-2"
60
+ _CONFIG_FOR_DOC = "PhiConfig"
61
+
62
+ PHI_PRETRAINED_MODEL_ARCHIVE_LIST = [
63
+ "microsoft/phi-2",
64
+ # See all Phi models at https://huggingface.co/models?filter=phi
65
+ ]
66
+
67
+
68
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
69
+ def _get_unpad_data(attention_mask):
70
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
71
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
72
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
73
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
74
+ return (
75
+ indices,
76
+ cu_seqlens,
77
+ max_seqlen_in_batch,
78
+ )
79
+
80
+
81
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Phi
82
+ class PhiRotaryEmbedding(nn.Module):
83
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
84
+ super().__init__()
85
+
86
+ self.dim = dim
87
+ self.max_position_embeddings = max_position_embeddings
88
+ self.base = base
89
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
90
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
91
+
92
+ # Build here to make `torch.jit.trace` work.
93
+ self._set_cos_sin_cache(
94
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
95
+ )
96
+
97
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
98
+ self.max_seq_len_cached = seq_len
99
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
100
+
101
+ freqs = torch.outer(t, self.inv_freq)
102
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
103
+ emb = torch.cat((freqs, freqs), dim=-1)
104
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
105
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
106
+
107
+ def forward(self, x, seq_len=None):
108
+ # x: [bs, num_attention_heads, seq_len, head_size]
109
+ if seq_len > self.max_seq_len_cached:
110
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
111
+
112
+ return (
113
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
114
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
115
+ )
116
+
117
+
118
+ # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Phi
119
+ class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
120
+ """PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
121
+
122
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
123
+ self.scaling_factor = scaling_factor
124
+ super().__init__(dim, max_position_embeddings, base, device)
125
+
126
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
127
+ self.max_seq_len_cached = seq_len
128
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
129
+ t = t / self.scaling_factor
130
+
131
+ freqs = torch.outer(t, self.inv_freq)
132
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
133
+ emb = torch.cat((freqs, freqs), dim=-1)
134
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
135
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
136
+
137
+
138
+ # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Phi
139
+ class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
140
+ """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
141
+
142
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
143
+ self.scaling_factor = scaling_factor
144
+ super().__init__(dim, max_position_embeddings, base, device)
145
+
146
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
147
+ self.max_seq_len_cached = seq_len
148
+
149
+ if seq_len > self.max_position_embeddings:
150
+ base = self.base * (
151
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
152
+ ) ** (self.dim / (self.dim - 2))
153
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
154
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
155
+
156
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
157
+
158
+ freqs = torch.outer(t, self.inv_freq)
159
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
160
+ emb = torch.cat((freqs, freqs), dim=-1)
161
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
162
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
163
+
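As a quick numeric check of the dynamic-NTK rescaling above (the numbers are illustrative, not taken from any real Phi config): once `seq_len` exceeds `max_position_embeddings`, the RoPE base is increased, which lowers the rotation frequencies and stretches the usable context.

```python
# Hypothetical values for illustration only.
dim, base, max_pos, scaling_factor = 32, 10000.0, 2048, 1.0
seq_len = 4096  # twice the trained context

new_base = base * ((scaling_factor * seq_len / max_pos) - (scaling_factor - 1)) ** (dim / (dim - 2))
print(new_base)  # roughly 2.09x the original base of 10000
```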
164
+
165
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
166
+ def rotate_half(x):
167
+ """Rotates half the hidden dims of the input."""
168
+ x1 = x[..., : x.shape[-1] // 2]
169
+ x2 = x[..., x.shape[-1] // 2 :]
170
+ return torch.cat((-x2, x1), dim=-1)
171
+
172
+
173
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
174
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
175
+ """Applies Rotary Position Embedding to the query and key tensors.
176
+ Args:
177
+ q (`torch.Tensor`): The query tensor.
178
+ k (`torch.Tensor`): The key tensor.
179
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
180
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
181
+ position_ids (`torch.Tensor`):
182
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
183
+ used to pass offset position ids when working with a KV-cache.
184
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
185
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
186
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
187
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
188
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
189
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
190
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
191
+ Returns:
192
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
193
+ """
194
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
195
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
196
+ q_embed = (q * cos) + (rotate_half(q) * sin)
197
+ k_embed = (k * cos) + (rotate_half(k) * sin)
198
+ return q_embed, k_embed
199
+
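A short sketch that ties the pieces above together (it assumes `PhiRotaryEmbedding` and `apply_rotary_pos_emb` from this file are in scope): the cos/sin tables are indexed by `position_ids` and broadcast over the head dimension.

```python
import torch

bsz, num_heads, seq_len, head_dim = 1, 4, 6, 32
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)

rope = PhiRotaryEmbedding(head_dim)
cos, sin = rope(k, seq_len=seq_len)                 # each of shape [seq_len, head_dim]
position_ids = torch.arange(seq_len).unsqueeze(0)   # [1, seq_len]

q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
print(q_embed.shape, k_embed.shape)                 # unchanged: [1, 4, 6, 32] each
```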
200
+
201
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Phi
202
+ class PhiMLP(nn.Module):
203
+ def __init__(self, config):
204
+ super().__init__()
205
+ self.config = config
206
+ self.activation_fn = ACT2FN[config.hidden_act]
207
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
208
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
209
+
210
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
211
+ hidden_states = self.fc1(hidden_states)
212
+ hidden_states = self.activation_fn(hidden_states)
213
+ hidden_states = self.fc2(hidden_states)
214
+ return hidden_states
215
+
216
+
217
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
218
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
219
+ """
220
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
221
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
222
+ """
223
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
224
+ if n_rep == 1:
225
+ return hidden_states
226
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
227
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
228
+
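For illustration, `repeat_kv` (assuming the function above is in scope) with `n_rep=4` expands grouped key/value heads so each attention head gets its own copy:

```python
import torch

kv = torch.randn(2, 8, 16, 64)       # (batch, num_key_value_heads, seq_len, head_dim)
expanded = repeat_kv(kv, n_rep=4)    # -> (batch, num_key_value_heads * 4, seq_len, head_dim)
print(expanded.shape)                # torch.Size([2, 32, 16, 64])
```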
229
+
230
+ class PhiAttention(nn.Module):
231
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
232
+
233
+ def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
234
+ super().__init__()
235
+ self.config = config
236
+ self.layer_idx = layer_idx
237
+ if layer_idx is None:
238
+ logger.warning_once(
239
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
240
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
241
+ "when creating this class."
242
+ )
243
+
244
+ self.attention_dropout = config.attention_dropout
245
+ self.hidden_size = config.hidden_size
246
+ self.num_heads = config.num_attention_heads
247
+ self.head_dim = self.hidden_size // self.num_heads
248
+ self.num_key_value_heads = config.num_key_value_heads
249
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
250
+ self.max_position_embeddings = config.max_position_embeddings
251
+ self.rope_theta = config.rope_theta
252
+ self.partial_rotary_factor = config.partial_rotary_factor
253
+ self.is_causal = True
254
+
255
+ if (self.head_dim * self.num_heads) != self.hidden_size:
256
+ raise ValueError(
257
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
258
+ f" and `num_heads`: {self.num_heads})."
259
+ )
260
+
261
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
262
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
263
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
264
+ self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
265
+
266
+ self.qk_layernorm = config.qk_layernorm
267
+ if self.qk_layernorm:
268
+ self.q_layernorm = nn.LayerNorm(
269
+ config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
270
+ )
271
+ self.k_layernorm = nn.LayerNorm(
272
+ config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
273
+ )
274
+
275
+ self._init_rope()
276
+
277
+ def _init_rope(self):
278
+ if self.config.rope_scaling is None:
279
+ self.rotary_emb = PhiRotaryEmbedding(
280
+ int(self.partial_rotary_factor * self.head_dim),
281
+ max_position_embeddings=self.max_position_embeddings,
282
+ base=self.rope_theta,
283
+ )
284
+ else:
285
+ scaling_type = self.config.rope_scaling["type"]
286
+ scaling_factor = self.config.rope_scaling["factor"]
287
+ if scaling_type == "linear":
288
+ self.rotary_emb = PhiLinearScalingRotaryEmbedding(
289
+ int(self.partial_rotary_factor * self.head_dim),
290
+ max_position_embeddings=self.max_position_embeddings,
291
+ scaling_factor=scaling_factor,
292
+ base=self.rope_theta,
293
+ )
294
+ elif scaling_type == "dynamic":
295
+ self.rotary_emb = PhiDynamicNTKScalingRotaryEmbedding(
296
+ int(self.partial_rotary_factor * self.head_dim),
297
+ max_position_embeddings=self.max_position_embeddings,
298
+ scaling_factor=scaling_factor,
299
+ base=self.rope_theta,
300
+ )
301
+ else:
302
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
303
+
304
+ # Phi-2 has an attention overflow issue (with FP16) and requires autocast to be disabled
305
+ @torch.autocast("cpu", enabled=False)
306
+ @torch.autocast("cuda", enabled=False)
307
+ def forward(
308
+ self,
309
+ hidden_states: torch.Tensor,
310
+ attention_mask: Optional[torch.Tensor] = None,
311
+ position_ids: Optional[torch.LongTensor] = None,
312
+ past_key_value: Optional[Cache] = None,
313
+ output_attentions: bool = False,
314
+ use_cache: bool = False,
315
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
316
+ bsz, q_len, _ = hidden_states.size()
317
+
318
+ query_states = self.q_proj(hidden_states)
319
+ key_states = self.k_proj(hidden_states)
320
+ value_states = self.v_proj(hidden_states)
321
+
322
+ if self.qk_layernorm:
323
+ query_states = self.q_layernorm(query_states)
324
+ key_states = self.k_layernorm(key_states)
325
+
326
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
327
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
328
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
329
+
330
+ kv_seq_len = key_states.shape[-2]
331
+ if past_key_value is not None:
332
+ if self.layer_idx is None:
333
+ raise ValueError(
334
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
335
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
336
+ "with a layer index."
337
+ )
338
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
339
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
340
+
341
+ # Partial rotary embedding
342
+ query_rot, query_pass = (
343
+ query_states[..., : self.rotary_emb.dim],
344
+ query_states[..., self.rotary_emb.dim :],
345
+ )
346
+ key_rot, key_pass = (
347
+ key_states[..., : self.rotary_emb.dim],
348
+ key_states[..., self.rotary_emb.dim :],
349
+ )
350
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
351
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
352
+
353
+ # [batch_size, seq_length, num_heads, head_dim]
354
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
355
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
356
+
357
+ if past_key_value is not None:
358
+ cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
359
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
360
+
361
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
362
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
363
+
364
+ # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
365
+ attn_weights = torch.matmul(
366
+ query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
367
+ ) / math.sqrt(self.head_dim)
368
+
369
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
370
+ raise ValueError(
371
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
372
+ f" {attn_weights.size()}"
373
+ )
374
+
375
+ if attention_mask is not None:
376
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
377
+ raise ValueError(
378
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
379
+ )
380
+ attn_weights = attn_weights + attention_mask
381
+
382
+ # upcast attention to fp32
383
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
384
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
385
+
386
+ attn_output = torch.matmul(attn_weights, value_states)
387
+
388
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
389
+ raise ValueError(
390
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
391
+ f" {attn_output.size()}"
392
+ )
393
+
394
+ attn_output = attn_output.transpose(1, 2).contiguous()
395
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
396
+
397
+ attn_output = self.dense(attn_output)
398
+
399
+ if not output_attentions:
400
+ attn_weights = None
401
+
402
+ return attn_output, attn_weights, past_key_value
403
+
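The attention above applies rotary embeddings only to the first `partial_rotary_factor * head_dim` channels of each head. A standalone sketch of that split (the numbers are illustrative, not read from a real config):

```python
import torch

head_dim, partial_rotary_factor = 80, 0.4
rotary_dim = int(partial_rotary_factor * head_dim)   # 32 channels get rotated, 48 pass through

q = torch.randn(1, 4, 6, head_dim)
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
# In the real forward pass q_rot is rotated with cos/sin while q_pass is left untouched,
# and the two halves are concatenated back along the last dimension:
q_reassembled = torch.cat((q_rot, q_pass), dim=-1)
print(torch.equal(q, q_reassembled))                 # True (no rotation applied in this sketch)
```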
404
+
405
+ class PhiFlashAttention2(PhiAttention):
406
+ """
407
+ Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stay
408
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
409
+ flash attention and deal with padding tokens in case the input contains any of them.
410
+ """
411
+
412
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
413
+ def __init__(self, *args, **kwargs):
414
+ super().__init__(*args, **kwargs)
415
+
416
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
417
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
418
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
419
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
420
+
421
+ def forward(
422
+ self,
423
+ hidden_states: torch.Tensor,
424
+ attention_mask: Optional[torch.LongTensor] = None,
425
+ position_ids: Optional[torch.LongTensor] = None,
426
+ past_key_value: Optional[Cache] = None,
427
+ output_attentions: bool = False,
428
+ use_cache: bool = False,
429
+ **kwargs,
430
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
431
+ # PhiFlashAttention2 attention does not support output_attentions
432
+
433
+ output_attentions = False
434
+
435
+ bsz, q_len, _ = hidden_states.size()
436
+
437
+ query_states = self.q_proj(hidden_states)
438
+ key_states = self.k_proj(hidden_states)
439
+ value_states = self.v_proj(hidden_states)
440
+
441
+ if self.qk_layernorm:
442
+ query_states = self.q_layernorm(query_states)
443
+ key_states = self.k_layernorm(key_states)
444
+
445
+ # Flash attention requires the input to have the shape
446
+ # batch_size x seq_length x num_heads x head_dim
447
+ # therefore we just need to keep the original shape
448
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
449
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
450
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
451
+
452
+ kv_seq_len = key_states.shape[-2]
453
+ if past_key_value is not None:
454
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
455
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
456
+
457
+ # Partial rotary embedding
458
+ query_rot, query_pass = (
459
+ query_states[..., : self.rotary_emb.dim],
460
+ query_states[..., self.rotary_emb.dim :],
461
+ )
462
+ key_rot, key_pass = (
463
+ key_states[..., : self.rotary_emb.dim],
464
+ key_states[..., self.rotary_emb.dim :],
465
+ )
466
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
467
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
468
+
469
+ # [batch_size, seq_length, num_heads, head_dim]
470
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
471
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
472
+
473
+ if past_key_value is not None:
474
+ cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
475
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
476
+
477
+ # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
478
+ # to be able to avoid many of these transpose/reshape/view.
479
+ query_states = query_states.transpose(1, 2)
480
+ key_states = key_states.transpose(1, 2)
481
+ value_states = value_states.transpose(1, 2)
482
+
483
+ attn_dropout = self.attention_dropout if self.training else 0.0
484
+
485
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
486
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
487
+ # cast them back to the correct dtype just to be sure everything works as expected.
488
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
489
+ # in fp32.
490
+
491
+ if query_states.dtype == torch.float32:
492
+ if torch.is_autocast_enabled():
493
+ target_dtype = torch.get_autocast_gpu_dtype()
494
+ # Handle the case where the model is quantized
495
+ elif hasattr(self.config, "_pre_quantization_dtype"):
496
+ target_dtype = self.config._pre_quantization_dtype
497
+ else:
498
+ target_dtype = self.q_proj.weight.dtype
499
+
500
+ logger.warning_once(
501
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
502
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
503
+ f" {target_dtype}."
504
+ )
505
+
506
+ query_states = query_states.to(target_dtype)
507
+ key_states = key_states.to(target_dtype)
508
+ value_states = value_states.to(target_dtype)
509
+
510
+ attn_output = self._flash_attention_forward(
511
+ query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=None
512
+ )
513
+
514
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
515
+ attn_output = self.dense(attn_output)
516
+
517
+ if not output_attentions:
518
+ attn_weights = None
519
+
520
+ return attn_output, attn_weights, past_key_value
521
+
522
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
523
+ def _flash_attention_forward(
524
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
525
+ ):
526
+ """
527
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
528
+ it first unpads the input, then computes the attention scores, and finally pads the attention scores back.
529
+ Args:
530
+ query_states (`torch.Tensor`):
531
+ Input query states to be passed to Flash Attention API
532
+ key_states (`torch.Tensor`):
533
+ Input key states to be passed to Flash Attention API
534
+ value_states (`torch.Tensor`):
535
+ Input value states to be passed to Flash Attention API
536
+ attention_mask (`torch.Tensor`):
537
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
538
+ position of padding tokens and 1 for the position of non-padding tokens.
539
+ dropout (`float`, *optional*):
540
+ Attention dropout
541
+ softmax_scale (`float`, *optional*):
542
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
543
+ """
544
+ if not self._flash_attn_uses_top_left_mask:
545
+ causal = self.is_causal
546
+ else:
547
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
548
+ causal = self.is_causal and query_length != 1
549
+
550
+ # Contains at least one padding token in the sequence
551
+ if attention_mask is not None:
552
+ batch_size = query_states.shape[0]
553
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
554
+ query_states, key_states, value_states, attention_mask, query_length
555
+ )
556
+
557
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
558
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
559
+
560
+ attn_output_unpad = flash_attn_varlen_func(
561
+ query_states,
562
+ key_states,
563
+ value_states,
564
+ cu_seqlens_q=cu_seqlens_q,
565
+ cu_seqlens_k=cu_seqlens_k,
566
+ max_seqlen_q=max_seqlen_in_batch_q,
567
+ max_seqlen_k=max_seqlen_in_batch_k,
568
+ dropout_p=dropout,
569
+ softmax_scale=softmax_scale,
570
+ causal=causal,
571
+ )
572
+
573
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
574
+ else:
575
+ attn_output = flash_attn_func(
576
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
577
+ )
578
+
579
+ return attn_output
580
+
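The `causal` flag logic above is easy to miss; a tiny hypothetical helper (`pick_causal` is not part of this file) spells it out: with older flash-attn (<2.1) the causal mask is top-left aligned, so single-token decoding must disable it.

```python
def pick_causal(is_causal: bool, uses_top_left_mask: bool, query_length: int) -> bool:
    # Mirrors the branch in _flash_attention_forward above.
    if not uses_top_left_mask:
        return is_causal
    return is_causal and query_length != 1

print(pick_causal(True, uses_top_left_mask=True, query_length=1))   # False
print(pick_causal(True, uses_top_left_mask=False, query_length=1))  # True
```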
581
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
582
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
583
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
584
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
585
+
586
+ key_layer = index_first_axis(
587
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
588
+ )
589
+ value_layer = index_first_axis(
590
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
591
+ )
592
+ if query_length == kv_seq_len:
593
+ query_layer = index_first_axis(
594
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
595
+ )
596
+ cu_seqlens_q = cu_seqlens_k
597
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
598
+ indices_q = indices_k
599
+ elif query_length == 1:
600
+ max_seqlen_in_batch_q = 1
601
+ cu_seqlens_q = torch.arange(
602
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
603
+ ) # There is a memcpy here, that is very bad.
604
+ indices_q = cu_seqlens_q[:-1]
605
+ query_layer = query_layer.squeeze(1)
606
+ else:
607
+ # The -q_len: slice assumes left padding.
608
+ attention_mask = attention_mask[:, -query_length:]
609
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
610
+
611
+ return (
612
+ query_layer,
613
+ key_layer,
614
+ value_layer,
615
+ indices_q,
616
+ (cu_seqlens_q, cu_seqlens_k),
617
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
618
+ )
619
+
620
+
621
+ PHI_ATTENTION_CLASSES = {
622
+ "eager": PhiAttention,
623
+ "flash_attention_2": PhiFlashAttention2,
624
+ }
625
+
626
+
627
+ class PhiDecoderLayer(nn.Module):
628
+ def __init__(self, config: PhiConfig, layer_idx: int):
629
+ super().__init__()
630
+ self.self_attn = PHI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
631
+ self.mlp = PhiMLP(config)
632
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
633
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
634
+
635
+ def forward(
636
+ self,
637
+ hidden_states: torch.Tensor,
638
+ attention_mask: Optional[torch.Tensor] = None,
639
+ position_ids: Optional[torch.LongTensor] = None,
640
+ output_attentions: Optional[bool] = False,
641
+ use_cache: Optional[bool] = False,
642
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
643
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
644
+ """
645
+ Args:
646
+ hidden_states (`torch.FloatTensor`):
647
+ input to the layer of shape `(batch, seq_len, embed_dim)`
648
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
649
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
650
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
651
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
652
+ `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
653
+ output_attentions (`bool`, *optional*):
654
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
655
+ returned tensors for more detail.
656
+ use_cache (`bool`, *optional*):
657
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
658
+ (see `past_key_values`).
659
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
660
+ """
661
+
662
+ residual = hidden_states
663
+
664
+ hidden_states = self.input_layernorm(hidden_states)
665
+
666
+ # Self Attention
667
+ attn_outputs, self_attn_weights, present_key_value = self.self_attn(
668
+ hidden_states=hidden_states,
669
+ attention_mask=attention_mask,
670
+ position_ids=position_ids,
671
+ past_key_value=past_key_value,
672
+ output_attentions=output_attentions,
673
+ use_cache=use_cache,
674
+ )
675
+ attn_outputs = self.resid_dropout(attn_outputs)
676
+
677
+ feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
678
+ hidden_states = attn_outputs + feed_forward_hidden_states + residual
679
+ outputs = (hidden_states,)
680
+
681
+ if output_attentions:
682
+ outputs += (self_attn_weights,)
683
+
684
+ if use_cache:
685
+ outputs += (present_key_value,)
686
+
687
+ return outputs
688
+
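Note that the decoder layer above uses a parallel residual: attention and MLP both read the same layer-normed input, and their dropout-regularized outputs are added to the residual in one step. A shape-only sketch with stand-in tensors (the sizes are illustrative):

```python
import torch

hidden = torch.randn(2, 6, 256)          # (batch, seq_len, hidden_size)
normed = hidden                          # stands in for input_layernorm(hidden)
attn_out = torch.zeros_like(normed)      # stands in for self_attn(normed, ...)[0]
mlp_out = torch.zeros_like(normed)       # stands in for mlp(normed)
out = attn_out + mlp_out + hidden        # parallel residual combination
print(out.shape)                         # torch.Size([2, 6, 256])
```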
689
+
690
+ PHI_START_DOCSTRING = r"""
691
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
692
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
693
+ etc.)
694
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
695
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
696
+ and behavior.
697
+ Parameters:
698
+ config ([`PhiConfig`]):
699
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
700
+ load the weights associated with the model, only the configuration. Check out the
701
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
702
+ """
703
+
704
+
705
+ @add_start_docstrings(
706
+ "The bare Phi Model outputting raw hidden-states without any specific head on top.",
707
+ PHI_START_DOCSTRING,
708
+ )
709
+ class PhiPreTrainedModel(PreTrainedModel):
710
+ config_class = PhiConfig
711
+ base_model_prefix = "model"
712
+ supports_gradient_checkpointing = True
713
+ _no_split_modules = ["PhiDecoderLayer"]
714
+ _skip_keys_device_placement = "past_key_values"
715
+ _supports_flash_attn_2 = True
716
+ _supports_cache_class = True
717
+
718
+ def _init_weights(self, module):
719
+ std = self.config.initializer_range
720
+ if isinstance(module, nn.Linear):
721
+ module.weight.data.normal_(mean=0.0, std=std)
722
+ if module.bias is not None:
723
+ module.bias.data.zero_()
724
+ elif isinstance(module, nn.Embedding):
725
+ module.weight.data.normal_(mean=0.0, std=std)
726
+ if module.padding_idx is not None:
727
+ module.weight.data[module.padding_idx].zero_()
728
+
729
+
730
+ PHI_INPUTS_DOCSTRING = r"""
731
+ Args:
732
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
733
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
734
+ it.
735
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
736
+ [`PreTrainedTokenizer.__call__`] for details.
737
+ [What are input IDs?](../glossary#input-ids)
738
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
739
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
740
+ - 1 for tokens that are **not masked**,
741
+ - 0 for tokens that are **masked**.
742
+ [What are attention masks?](../glossary#attention-mask)
743
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
744
+ [`PreTrainedTokenizer.__call__`] for details.
745
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
746
+ `past_key_values`).
747
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
748
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
749
+ information on the default strategy.
750
+ - 1 indicates the head is **not masked**,
751
+ - 0 indicates the head is **masked**.
752
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
753
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
754
+ config.n_positions - 1]`.
755
+ [What are position IDs?](../glossary#position-ids)
756
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
757
+ Pre-computed hidden-states (keys and values in the self-attention blocks and in the cross-attention
758
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
759
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
760
+ Two formats are allowed:
761
+ - a [`~cache_utils.Cache`] instance;
762
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
763
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
764
+ cache format.
765
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
766
+ legacy cache format will be returned.
767
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
768
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
769
+ of shape `(batch_size, sequence_length)`.
770
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
771
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
772
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
773
+ model's internal embedding lookup matrix.
774
+ use_cache (`bool`, *optional*):
775
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
776
+ `past_key_values`).
777
+ output_attentions (`bool`, *optional*):
778
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
779
+ tensors for more detail.
780
+ output_hidden_states (`bool`, *optional*):
781
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
782
+ more detail.
783
+ return_dict (`bool`, *optional*):
784
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
785
+ """
786
+
787
+
788
+ @add_start_docstrings(
789
+ "The bare Phi Model outputting raw hidden-states without any specific head on top.",
790
+ PHI_START_DOCSTRING,
791
+ )
792
+ class PhiModel(PhiPreTrainedModel):
793
+ """
794
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhiDecoderLayer`]
795
+ Args:
796
+ config: PhiConfig
797
+ """
798
+
799
+ def __init__(self, config: PhiConfig):
800
+ super().__init__(config)
801
+ self.padding_idx = config.pad_token_id
802
+ self.vocab_size = config.vocab_size
803
+
804
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
805
+ self.embed_dropout = nn.Dropout(config.embd_pdrop)
806
+ self.layers = nn.ModuleList(
807
+ [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
808
+ )
809
+ self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
810
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
811
+
812
+ self.gradient_checkpointing = False
813
+ # Initialize weights and apply final processing
814
+ self.post_init()
815
+
816
+ def get_input_embeddings(self):
817
+ return self.embed_tokens
818
+
819
+ def set_input_embeddings(self, value):
820
+ self.embed_tokens = value
821
+
822
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
823
+ def forward(
824
+ self,
825
+ input_ids: torch.LongTensor = None,
826
+ attention_mask: Optional[torch.Tensor] = None,
827
+ position_ids: Optional[torch.LongTensor] = None,
828
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
829
+ inputs_embeds: Optional[torch.FloatTensor] = None,
830
+ use_cache: Optional[bool] = None,
831
+ output_attentions: Optional[bool] = None,
832
+ output_hidden_states: Optional[bool] = None,
833
+ return_dict: Optional[bool] = None,
834
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
835
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
836
+ output_hidden_states = (
837
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
838
+ )
839
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
840
+
841
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
842
+
843
+ # retrieve input_ids and inputs_embeds
844
+ if input_ids is not None and inputs_embeds is not None:
845
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
846
+ elif input_ids is not None:
847
+ batch_size, seq_length = input_ids.shape[:2]
848
+ elif inputs_embeds is not None:
849
+ batch_size, seq_length = inputs_embeds.shape[:2]
850
+ else:
851
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
852
+
853
+ past_key_values_length = 0
854
+
855
+ if self.gradient_checkpointing and self.training:
856
+ if use_cache:
857
+ logger.warning_once(
858
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
859
+ )
860
+ use_cache = False
861
+
862
+ if use_cache:
863
+ use_legacy_cache = not isinstance(past_key_values, Cache)
864
+ if use_legacy_cache:
865
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
866
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
867
+
868
+ if position_ids is None:
869
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
870
+ position_ids = torch.arange(
871
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
872
+ )
873
+ position_ids = position_ids.unsqueeze(0)
874
+
875
+ if inputs_embeds is None:
876
+ inputs_embeds = self.embed_tokens(input_ids)
877
+
878
+ inputs_embeds = self.embed_dropout(inputs_embeds)
879
+
880
+ # Attention mask.
881
+ if self._use_flash_attention_2:
882
+ # 2d mask is passed through the layers
883
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
884
+ else:
885
+ # 4d mask is passed through the layers
886
+ attention_mask = _prepare_4d_causal_attention_mask(
887
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
888
+ )
889
+
890
+ hidden_states = inputs_embeds
891
+
892
+ # decoder layers
893
+ all_hidden_states = () if output_hidden_states else None
894
+ all_self_attns = () if output_attentions else None
895
+ next_decoder_cache = None
896
+
897
+ for decoder_layer in self.layers:
898
+ if output_hidden_states:
899
+ all_hidden_states += (hidden_states,)
900
+
901
+ if self.gradient_checkpointing and self.training:
902
+ layer_outputs = self._gradient_checkpointing_func(
903
+ decoder_layer.__call__,
904
+ hidden_states,
905
+ attention_mask,
906
+ position_ids,
907
+ past_key_values,
908
+ output_attentions,
909
+ )
910
+ else:
911
+ layer_outputs = decoder_layer(
912
+ hidden_states,
913
+ attention_mask=attention_mask,
914
+ position_ids=position_ids,
915
+ past_key_value=past_key_values,
916
+ output_attentions=output_attentions,
917
+ use_cache=use_cache,
918
+ )
919
+
920
+ hidden_states = layer_outputs[0]
921
+
922
+ if use_cache:
923
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
924
+
925
+ if output_attentions:
926
+ all_self_attns += (layer_outputs[1],)
927
+
928
+ hidden_states = self.final_layernorm(hidden_states)
929
+
930
+ # add hidden states from the last decoder layer
931
+ if output_hidden_states:
932
+ all_hidden_states += (hidden_states,)
933
+
934
+ next_cache = None
935
+ if use_cache:
936
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
937
+ if not return_dict:
938
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
939
+ return BaseModelOutputWithPast(
940
+ last_hidden_state=hidden_states,
941
+ past_key_values=next_cache,
942
+ hidden_states=all_hidden_states,
943
+ attentions=all_self_attns,
944
+ )
945
+
946
+
947
+ class PhiForCausalLM(PhiPreTrainedModel):
948
+ _tied_weights_keys = ["lm_head.weight"]
949
+
950
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi,bias=False->bias=True
951
+ def __init__(self, config):
952
+ super().__init__(config)
953
+ self.model = PhiModel(config)
954
+ self.vocab_size = config.vocab_size
955
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
956
+
957
+ # Initialize weights and apply final processing
958
+ self.post_init()
959
+
960
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
961
+ def get_input_embeddings(self):
962
+ return self.model.embed_tokens
963
+
964
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
965
+ def set_input_embeddings(self, value):
966
+ self.model.embed_tokens = value
967
+
968
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
969
+ def get_output_embeddings(self):
970
+ return self.lm_head
971
+
972
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
973
+ def set_output_embeddings(self, new_embeddings):
974
+ self.lm_head = new_embeddings
975
+
976
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
977
+ def set_decoder(self, decoder):
978
+ self.model = decoder
979
+
980
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
981
+ def get_decoder(self):
982
+ return self.model
983
+
984
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
985
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
986
+ def forward(
987
+ self,
988
+ input_ids: torch.LongTensor = None,
989
+ attention_mask: Optional[torch.Tensor] = None,
990
+ position_ids: Optional[torch.LongTensor] = None,
991
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
992
+ inputs_embeds: Optional[torch.FloatTensor] = None,
993
+ labels: Optional[torch.LongTensor] = None,
994
+ use_cache: Optional[bool] = None,
995
+ output_attentions: Optional[bool] = None,
996
+ output_hidden_states: Optional[bool] = None,
997
+ return_dict: Optional[bool] = None,
998
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
999
+ r"""
1000
+ Args:
1001
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1002
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1003
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1004
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1005
+ Returns:
1006
+ Example:
1007
+ ```python
1008
+ >>> from transformers import AutoTokenizer, PhiForCausalLM
1009
+ >>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1")
1010
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
1011
+ >>> prompt = "This is an example script ."
1012
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1013
+ >>> # Generate
1014
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1015
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1016
+ 'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str'
1017
+ ```"""
1018
+
1019
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1020
+ output_hidden_states = (
1021
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1022
+ )
1023
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1024
+
1025
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1026
+ outputs = self.model(
1027
+ input_ids=input_ids,
1028
+ attention_mask=attention_mask,
1029
+ position_ids=position_ids,
1030
+ past_key_values=past_key_values,
1031
+ inputs_embeds=inputs_embeds,
1032
+ use_cache=use_cache,
1033
+ output_attentions=output_attentions,
1034
+ output_hidden_states=output_hidden_states,
1035
+ return_dict=return_dict,
1036
+ )
1037
+
1038
+ hidden_states = outputs[0]
1039
+ logits = self.lm_head(hidden_states)
1040
+ logits = logits.float()
1041
+
1042
+ loss = None
1043
+ if labels is not None:
1044
+ # Shift so that tokens < n predict n
1045
+ shift_logits = logits[..., :-1, :].contiguous()
1046
+ shift_labels = labels[..., 1:].contiguous()
1047
1048
+ # Flatten the tokens
1049
+ loss_fct = CrossEntropyLoss()
1050
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1051
+ shift_labels = shift_labels.view(-1)
1052
+ # Enable model parallelism
1053
+ shift_labels = shift_labels.to(shift_logits.device)
1054
+ loss = loss_fct(shift_logits, shift_labels)
1055
+
1056
+ if not return_dict:
1057
+ output = (logits,) + outputs[1:]
1058
+ return (loss,) + output if loss is not None else output
1059
+
1060
+ return CausalLMOutputWithPast(
1061
+ loss=loss,
1062
+ logits=logits,
1063
+ past_key_values=outputs.past_key_values,
1064
+ hidden_states=outputs.hidden_states,
1065
+ attentions=outputs.attentions,
1066
+ )
1067
+
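The loss above is the standard shift-by-one causal LM objective: position t predicts token t+1, so the last logit and the first label are dropped. A self-contained sketch:

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 11
logits = torch.randn(2, 5, vocab_size)              # (batch, seq_len, vocab)
labels = torch.randint(0, vocab_size, (2, 5))

shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)
loss = CrossEntropyLoss()(shift_logits, shift_labels)
print(loss.item())
```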
1068
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
1069
+ def prepare_inputs_for_generation(
1070
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1071
+ ):
1072
+ if past_key_values is not None:
1073
+ if isinstance(past_key_values, Cache):
1074
+ cache_length = past_key_values.get_seq_length()
1075
+ past_length = past_key_values.seen_tokens
1076
+ max_cache_length = past_key_values.get_max_length()
1077
+ else:
1078
+ cache_length = past_length = past_key_values[0][0].shape[2]
1079
+ max_cache_length = None
1080
+
1081
+ # Keep only the unprocessed tokens:
1082
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1083
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1084
+ # input)
1085
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1086
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1087
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1088
+ # input_ids based on the past_length.
1089
+ elif past_length < input_ids.shape[1]:
1090
+ input_ids = input_ids[:, past_length:]
1091
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1092
+
1093
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1094
+ if (
1095
+ max_cache_length is not None
1096
+ and attention_mask is not None
1097
+ and cache_length + input_ids.shape[1] > max_cache_length
1098
+ ):
1099
+ attention_mask = attention_mask[:, -max_cache_length:]
1100
+
1101
+ position_ids = kwargs.get("position_ids", None)
1102
+ if attention_mask is not None and position_ids is None:
1103
+ # create position_ids on the fly for batch generation
1104
+ position_ids = attention_mask.long().cumsum(-1) - 1
1105
+ position_ids.masked_fill_(attention_mask == 0, 1)
1106
+ if past_key_values:
1107
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1108
+
1109
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1110
+ if inputs_embeds is not None and past_key_values is None:
1111
+ model_inputs = {"inputs_embeds": inputs_embeds}
1112
+ else:
1113
+ model_inputs = {"input_ids": input_ids}
1114
+
1115
+ model_inputs.update(
1116
+ {
1117
+ "position_ids": position_ids,
1118
+ "past_key_values": past_key_values,
1119
+ "use_cache": kwargs.get("use_cache"),
1120
+ "attention_mask": attention_mask,
1121
+ }
1122
+ )
1123
+ return model_inputs
1124
+
1125
+ @staticmethod
1126
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
1127
+ def _reorder_cache(past_key_values, beam_idx):
1128
+ reordered_past = ()
1129
+ for layer_past in past_key_values:
1130
+ reordered_past += (
1131
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1132
+ )
1133
+ return reordered_past
1134
+
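The `position_ids` fallback inside `prepare_inputs_for_generation` is worth seeing on a concrete left-padded batch: a cumulative sum of the mask minus one, with padded slots overwritten by a dummy value of 1.

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```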
1135
+
1136
+ @add_start_docstrings(
1137
+ """
1138
+ The PhiModel with a sequence classification head on top (linear layer).
1139
+ [`PhiForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1140
+ (e.g. GPT-2) do.
1141
+ Since it does classification on the last token, it needs to know the position of the last token. If a
1142
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1143
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1144
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1145
+ each row of the batch).
1146
+ """,
1147
+ PHI_START_DOCSTRING,
1148
+ )
1149
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->PHI,Llama->Phi with self.transformer->self.model, transformer_outputs->model_outputs
1150
+ class PhiForSequenceClassification(PhiPreTrainedModel):
1151
+ def __init__(self, config):
1152
+ super().__init__(config)
1153
+ self.num_labels = config.num_labels
1154
+ self.model = PhiModel(config)
1155
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1156
+
1157
+ # Initialize weights and apply final processing
1158
+ self.post_init()
1159
+
1160
+ def get_input_embeddings(self):
1161
+ return self.model.embed_tokens
1162
+
1163
+ def set_input_embeddings(self, value):
1164
+ self.model.embed_tokens = value
1165
+
1166
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1167
+ def forward(
1168
+ self,
1169
+ input_ids: torch.LongTensor = None,
1170
+ attention_mask: Optional[torch.Tensor] = None,
1171
+ position_ids: Optional[torch.LongTensor] = None,
1172
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1173
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1174
+ labels: Optional[torch.LongTensor] = None,
1175
+ use_cache: Optional[bool] = None,
1176
+ output_attentions: Optional[bool] = None,
1177
+ output_hidden_states: Optional[bool] = None,
1178
+ return_dict: Optional[bool] = None,
1179
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1180
+ r"""
1181
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1182
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1183
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1184
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1185
+ """
1186
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1187
+
1188
+ model_outputs = self.model(
1189
+ input_ids,
1190
+ attention_mask=attention_mask,
1191
+ position_ids=position_ids,
1192
+ past_key_values=past_key_values,
1193
+ inputs_embeds=inputs_embeds,
1194
+ use_cache=use_cache,
1195
+ output_attentions=output_attentions,
1196
+ output_hidden_states=output_hidden_states,
1197
+ return_dict=return_dict,
1198
+ )
1199
+ hidden_states = model_outputs[0]
1200
+ logits = self.score(hidden_states)
1201
+
1202
+ if input_ids is not None:
1203
+ batch_size = input_ids.shape[0]
1204
+ else:
1205
+ batch_size = inputs_embeds.shape[0]
1206
+
1207
+ if self.config.pad_token_id is None and batch_size != 1:
1208
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1209
+ if self.config.pad_token_id is None:
1210
+ sequence_lengths = -1
1211
+ else:
1212
+ if input_ids is not None:
1213
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1214
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1215
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1216
+ sequence_lengths = sequence_lengths.to(logits.device)
1217
+ else:
1218
+ sequence_lengths = -1
1219
+
1220
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1221
+
1222
+ loss = None
1223
+ if labels is not None:
1224
+ labels = labels.to(logits.device)
1225
+ if self.config.problem_type is None:
1226
+ if self.num_labels == 1:
1227
+ self.config.problem_type = "regression"
1228
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1229
+ self.config.problem_type = "single_label_classification"
1230
+ else:
1231
+ self.config.problem_type = "multi_label_classification"
1232
+
1233
+ if self.config.problem_type == "regression":
1234
+ loss_fct = MSELoss()
1235
+ if self.num_labels == 1:
1236
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1237
+ else:
1238
+ loss = loss_fct(pooled_logits, labels)
1239
+ elif self.config.problem_type == "single_label_classification":
1240
+ loss_fct = CrossEntropyLoss()
1241
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1242
+ elif self.config.problem_type == "multi_label_classification":
1243
+ loss_fct = BCEWithLogitsLoss()
1244
+ loss = loss_fct(pooled_logits, labels)
1245
+ if not return_dict:
1246
+ output = (pooled_logits,) + model_outputs[1:]
1247
+ return ((loss,) + output) if loss is not None else output
1248
+
1249
+ return SequenceClassifierOutputWithPast(
1250
+ loss=loss,
1251
+ logits=pooled_logits,
1252
+ past_key_values=model_outputs.past_key_values,
1253
+ hidden_states=model_outputs.hidden_states,
1254
+ attentions=model_outputs.attentions,
1255
+ )
1256
+
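The last-token pooling above relies on an argmax-over-padding trick that also survives rows without any padding (ONNX-friendly). A standalone sketch, assuming right padding and a pad id of 0 purely for illustration:

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],
                          [8, 9, 2, 3, 4]])
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)   # tensor([2, 4]) -> index of the last real token per row
```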
1257
+
1258
+ @add_start_docstrings(
1259
+ """
1260
+ PhiModel with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1261
+ Named-Entity-Recognition (NER) tasks.
1262
+ """,
1263
+ PHI_START_DOCSTRING,
1264
+ )
1265
+ # Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with MPT->PHI,Mpt->Phi,self.transformer->self.model,transformer_outputs->model_outputs
1266
+ class PhiForTokenClassification(PhiPreTrainedModel):
1267
+ def __init__(self, config: PhiConfig):
1268
+ super().__init__(config)
1269
+ self.num_labels = config.num_labels
1270
+
1271
+ self.model = PhiModel(config)
1272
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
1273
+ classifier_dropout = config.classifier_dropout
1274
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1275
+ classifier_dropout = config.hidden_dropout
1276
+ else:
1277
+ classifier_dropout = 0.1
1278
+ self.dropout = nn.Dropout(classifier_dropout)
1279
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1280
+
1281
+ # Initialize weights and apply final processing
1282
+ self.post_init()
1283
+
1284
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1285
+ @add_code_sample_docstrings(
1286
+ checkpoint=_CHECKPOINT_FOR_DOC,
1287
+ output_type=TokenClassifierOutput,
1288
+ config_class=_CONFIG_FOR_DOC,
1289
+ )
1290
+ def forward(
1291
+ self,
1292
+ input_ids: Optional[torch.LongTensor] = None,
1293
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1294
+ attention_mask: Optional[torch.Tensor] = None,
1295
+ inputs_embeds: Optional[torch.Tensor] = None,
1296
+ labels: Optional[torch.Tensor] = None,
1297
+ use_cache: Optional[bool] = None,
1298
+ output_attentions: Optional[bool] = None,
1299
+ output_hidden_states: Optional[bool] = None,
1300
+ return_dict: Optional[bool] = None,
1301
+ **deprecated_arguments,
1302
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1303
+ r"""
1304
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1305
+ Labels for computing the token classification loss. Indices should be in `[0, ...,
1306
+ config.num_labels - 1]`. A cross-entropy loss is computed over the labeled tokens; indices set to `-100`
1307
+ are ignored.
1308
+ """
1309
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1310
+
1311
+ model_outputs = self.model(
1312
+ input_ids,
1313
+ past_key_values=past_key_values,
1314
+ attention_mask=attention_mask,
1315
+ inputs_embeds=inputs_embeds,
1316
+ use_cache=use_cache,
1317
+ output_attentions=output_attentions,
1318
+ output_hidden_states=output_hidden_states,
1319
+ return_dict=return_dict,
1320
+ )
1321
+
1322
+ hidden_states = model_outputs[0]
1323
+ hidden_states = self.dropout(hidden_states)
1324
+ logits = self.classifier(hidden_states)
1325
+
1326
+ loss = None
1327
+ if labels is not None:
1328
+ # move labels to correct device to enable model parallelism
1329
+ labels = labels.to(logits.device)
1330
+ batch_size, seq_length = labels.shape
1331
+ loss_fct = CrossEntropyLoss()
1332
+ loss = loss_fct(
1333
+ logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1334
+ )
1335
+
1336
+ if not return_dict:
1337
+ output = (logits,) + model_outputs[2:]
1338
+ return ((loss,) + output) if loss is not None else output
1339
+
1340
+ return TokenClassifierOutput(
1341
+ loss=loss,
1342
+ logits=logits,
1343
+ hidden_states=model_outputs.hidden_states,
1344
+ attentions=model_outputs.attentions,
1345
+ )
tinychart/model/language_model/phi/utils.py ADDED
@@ -0,0 +1,1428 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Import utilities: Utilities related to imports and our lazy inits.
16
+ """
17
+
18
+ import importlib.metadata
19
+ import importlib.util
20
+ import json
21
+ import os
22
+ import shutil
23
+ import subprocess
24
+ import sys
25
+ import warnings
26
+ from collections import OrderedDict
27
+ from functools import lru_cache, wraps
28
+ from itertools import chain
29
+ from types import ModuleType
30
+ from typing import Any, Tuple, Union
31
+
32
+ from packaging import version
33
+
34
+ from transformers import logging
35
+
36
+
37
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
+
39
+
40
+ # TODO: This doesn't work for all packages (`bs4`, `faiss`, etc.). Talk to Sylvain to see how to do it better.
41
+ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
42
+ # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
43
+ package_exists = importlib.util.find_spec(pkg_name) is not None
44
+ package_version = "N/A"
45
+ if package_exists:
46
+ try:
47
+ package_version = importlib.metadata.version(pkg_name)
48
+ package_exists = True
49
+ except importlib.metadata.PackageNotFoundError:
50
+ package_exists = False
51
+ logger.debug(f"Detected {pkg_name} version {package_version}")
52
+ if return_version:
53
+ return package_exists, package_version
54
+ else:
55
+ return package_exists
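+ # Example (illustrative): _is_package_available("torch", return_version=True) returns e.g. (True, "2.1.2") when torch is installed, and (False, "N/A") otherwise.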
56
+
57
+
58
+ ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
59
+ ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
60
+
61
+ USE_TF = os.environ.get("USE_TF", "AUTO").upper()
62
+ USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
63
+ USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
64
+
65
+ FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper()
66
+
67
+ # This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
68
+ TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
69
+
70
+ ACCELERATE_MIN_VERSION = "0.21.0"
71
+ FSDP_MIN_VERSION = "1.12.0"
72
+
73
+
74
+ _accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
75
+ _apex_available = _is_package_available("apex")
76
+ _bitsandbytes_available = _is_package_available("bitsandbytes")
77
+ # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
78
+ _bs4_available = importlib.util.find_spec("bs4") is not None
79
+ _coloredlogs_available = _is_package_available("coloredlogs")
80
+ # `importlib.metadata.version` doesn't work with `opencv-python-headless`.
81
+ _cv2_available = importlib.util.find_spec("cv2") is not None
82
+ _datasets_available = _is_package_available("datasets")
83
+ _decord_available = importlib.util.find_spec("decord") is not None
84
+ _detectron2_available = _is_package_available("detectron2")
85
+ # We need to check both `faiss` and `faiss-cpu`.
86
+ _faiss_available = importlib.util.find_spec("faiss") is not None
87
+ try:
88
+ _faiss_version = importlib.metadata.version("faiss")
89
+ logger.debug(f"Successfully imported faiss version {_faiss_version}")
90
+ except importlib.metadata.PackageNotFoundError:
91
+ try:
92
+ _faiss_version = importlib.metadata.version("faiss-cpu")
93
+ logger.debug(f"Successfully imported faiss version {_faiss_version}")
94
+ except importlib.metadata.PackageNotFoundError:
95
+ _faiss_available = False
96
+ _ftfy_available = _is_package_available("ftfy")
97
+ _g2p_en_available = _is_package_available("g2p_en")
98
+ _ipex_available, _ipex_version = _is_package_available("intel_extension_for_pytorch", return_version=True)
99
+ _jieba_available = _is_package_available("jieba")
100
+ _jinja_available = _is_package_available("jinja2")
101
+ _kenlm_available = _is_package_available("kenlm")
102
+ _keras_nlp_available = _is_package_available("keras_nlp")
103
+ _levenshtein_available = _is_package_available("Levenshtein")
104
+ _librosa_available = _is_package_available("librosa")
105
+ _natten_available = _is_package_available("natten")
106
+ _nltk_available = _is_package_available("nltk")
107
+ _onnx_available = _is_package_available("onnx")
108
+ _openai_available = _is_package_available("openai")
109
+ _optimum_available = _is_package_available("optimum")
110
+ _auto_gptq_available = _is_package_available("auto_gptq")
111
+ # `importlib.metadata.version` doesn't work with `awq`
112
+ _auto_awq_available = importlib.util.find_spec("awq") is not None
113
+ _pandas_available = _is_package_available("pandas")
114
+ _peft_available = _is_package_available("peft")
115
+ _phonemizer_available = _is_package_available("phonemizer")
116
+ _psutil_available = _is_package_available("psutil")
117
+ _py3nvml_available = _is_package_available("py3nvml")
118
+ _pyctcdecode_available = _is_package_available("pyctcdecode")
119
+ _pytesseract_available = _is_package_available("pytesseract")
120
+ _pytest_available = _is_package_available("pytest")
121
+ _pytorch_quantization_available = _is_package_available("pytorch_quantization")
122
+ _rjieba_available = _is_package_available("rjieba")
123
+ _sacremoses_available = _is_package_available("sacremoses")
124
+ _safetensors_available = _is_package_available("safetensors")
125
+ _scipy_available = _is_package_available("scipy")
126
+ _sentencepiece_available = _is_package_available("sentencepiece")
127
+ _is_seqio_available = _is_package_available("seqio")
128
+ _sklearn_available = importlib.util.find_spec("sklearn") is not None
129
+ if _sklearn_available:
130
+ try:
131
+ importlib.metadata.version("scikit-learn")
132
+ except importlib.metadata.PackageNotFoundError:
133
+ _sklearn_available = False
134
+ _smdistributed_available = importlib.util.find_spec("smdistributed") is not None
135
+ _soundfile_available = _is_package_available("soundfile")
136
+ _spacy_available = _is_package_available("spacy")
137
+ _sudachipy_available = _is_package_available("sudachipy")
138
+ _tensorflow_probability_available = _is_package_available("tensorflow_probability")
139
+ _tensorflow_text_available = _is_package_available("tensorflow_text")
140
+ _tf2onnx_available = _is_package_available("tf2onnx")
141
+ _timm_available = _is_package_available("timm")
142
+ _tokenizers_available = _is_package_available("tokenizers")
143
+ _torchaudio_available = _is_package_available("torchaudio")
144
+ _torchdistx_available = _is_package_available("torchdistx")
145
+ _torchvision_available = _is_package_available("torchvision")
146
+
147
+
148
+ _torch_version = "N/A"
149
+ _torch_available = False
150
+ if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
151
+ _torch_available, _torch_version = _is_package_available("torch", return_version=True)
152
+ else:
153
+ logger.info("Disabling PyTorch because USE_TF is set")
154
+ _torch_available = False
155
+
156
+
157
+ _tf_version = "N/A"
158
+ _tf_available = False
159
+ if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES:
160
+ _tf_available = True
161
+ else:
162
+ if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
163
+ # Note: _is_package_available("tensorflow") fails for tensorflow-cpu. Please test any changes to the line below
164
+ # with tensorflow-cpu to make sure it still works!
165
+ _tf_available = importlib.util.find_spec("tensorflow") is not None
166
+ if _tf_available:
167
+ candidates = (
168
+ "tensorflow",
169
+ "tensorflow-cpu",
170
+ "tensorflow-gpu",
171
+ "tf-nightly",
172
+ "tf-nightly-cpu",
173
+ "tf-nightly-gpu",
174
+ "tf-nightly-rocm",
175
+ "intel-tensorflow",
176
+ "intel-tensorflow-avx512",
177
+ "tensorflow-rocm",
178
+ "tensorflow-macos",
179
+ "tensorflow-aarch64",
180
+ )
181
+ _tf_version = None
182
+ # For the metadata, we have to look for both tensorflow and tensorflow-cpu
183
+ for pkg in candidates:
184
+ try:
185
+ _tf_version = importlib.metadata.version(pkg)
186
+ break
187
+ except importlib.metadata.PackageNotFoundError:
188
+ pass
189
+ _tf_available = _tf_version is not None
190
+ if _tf_available:
191
+ if version.parse(_tf_version) < version.parse("2"):
192
+ logger.info(
193
+ f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum."
194
+ )
195
+ _tf_available = False
196
+ else:
197
+ logger.info("Disabling Tensorflow because USE_TORCH is set")
198
+
199
+
200
+ _essentia_available = importlib.util.find_spec("essentia") is not None
201
+ try:
202
+ _essentia_version = importlib.metadata.version("essentia")
203
+ logger.debug(f"Successfully imported essentia version {_essentia_version}")
204
+ except importlib.metadata.PackageNotFoundError:
205
+ _essentia_version = False
206
+
207
+
208
+ _pretty_midi_available = importlib.util.find_spec("pretty_midi") is not None
209
+ try:
210
+ _pretty_midi_version = importlib.metadata.version("pretty_midi")
211
+ logger.debug(f"Successfully imported pretty_midi version {_pretty_midi_version}")
212
+ except importlib.metadata.PackageNotFoundError:
213
+ _pretty_midi_available = False
214
+
215
+
216
+ ccl_version = "N/A"
217
+ _is_ccl_available = (
218
+ importlib.util.find_spec("torch_ccl") is not None
219
+ or importlib.util.find_spec("oneccl_bindings_for_pytorch") is not None
220
+ )
221
+ try:
222
+ ccl_version = importlib.metadata.version("oneccl_bind_pt")
223
+ logger.debug(f"Detected oneccl_bind_pt version {ccl_version}")
224
+ except importlib.metadata.PackageNotFoundError:
225
+ _is_ccl_available = False
226
+
227
+
228
+ _flax_available = False
229
+ if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
230
+ _flax_available, _flax_version = _is_package_available("flax", return_version=True)
231
+ if _flax_available:
232
+ _jax_available, _jax_version = _is_package_available("jax", return_version=True)
233
+ if _jax_available:
234
+ logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
235
+ else:
236
+ _flax_available = _jax_available = False
237
+ _jax_version = _flax_version = "N/A"
238
+
239
+
240
+ _torch_fx_available = False
241
+ if _torch_available:
242
+ torch_version = version.parse(_torch_version)
243
+ _torch_fx_available = (torch_version.major, torch_version.minor) >= (
244
+ TORCH_FX_REQUIRED_VERSION.major,
245
+ TORCH_FX_REQUIRED_VERSION.minor,
246
+ )
247
+
248
+
249
+ def is_kenlm_available():
250
+ return _kenlm_available
251
+
252
+
253
+ def is_cv2_available():
254
+ return _cv2_available
255
+
256
+
257
+ def is_torch_available():
258
+ return _torch_available
259
+
260
+
261
+ def get_torch_version():
262
+ return _torch_version
263
+
264
+
265
+ def is_torch_sdpa_available():
266
+ if not is_torch_available():
267
+ return False
268
+ elif _torch_version == "N/A":
269
+ return False
270
+
271
+ # NOTE: We require torch>=2.1 (and not torch>=2.0) to use SDPA in Transformers for two reasons:
272
+ # - Allow the global use of the `scale` argument introduced in https://github.com/pytorch/pytorch/pull/95259
273
+ # - Memory-efficient attention supports arbitrary attention_mask: https://github.com/pytorch/pytorch/pull/104310
274
+ # NOTE: We require torch>=2.1.1 to avoid a numerical issue in SDPA with non-contiguous inputs: https://github.com/pytorch/pytorch/issues/112577
275
+ return version.parse(_torch_version) >= version.parse("2.1.1")
276
+
277
+
278
+ def is_torchvision_available():
279
+ return _torchvision_available
280
+
281
+
282
+ def is_pyctcdecode_available():
283
+ return _pyctcdecode_available
284
+
285
+
286
+ def is_librosa_available():
287
+ return _librosa_available
288
+
289
+
290
+ def is_essentia_available():
291
+ return _essentia_available
292
+
293
+
294
+ def is_pretty_midi_available():
295
+ return _pretty_midi_available
296
+
297
+
298
+ def is_torch_cuda_available():
299
+ if is_torch_available():
300
+ import torch
301
+
302
+ return torch.cuda.is_available()
303
+ else:
304
+ return False
305
+
306
+
307
+ def is_torch_mps_available():
308
+ if is_torch_available():
309
+ import torch
310
+
311
+ if hasattr(torch.backends, "mps"):
312
+ return torch.backends.mps.is_available()
313
+ return False
314
+
315
+
316
+ def is_torch_bf16_gpu_available():
317
+ if not is_torch_available():
318
+ return False
319
+
320
+ import torch
321
+
322
+ return torch.cuda.is_available() and torch.cuda.is_bf16_supported()
323
+
324
+
325
+ def is_torch_bf16_cpu_available():
326
+ if not is_torch_available():
327
+ return False
328
+
329
+ import torch
330
+
331
+ try:
332
+ # multiple levels of AttributeError depending on the pytorch version so do them all in one check
333
+ _ = torch.cpu.amp.autocast
334
+ except AttributeError:
335
+ return False
336
+
337
+ return True
338
+
339
+
340
+ def is_torch_bf16_available():
341
+ # the original bf16 check was for gpu only, but later a cpu/bf16 combo has emerged so this util
342
+ # has become ambiguous and therefore deprecated
343
+ warnings.warn(
344
+ "The util is_torch_bf16_available is deprecated, please use is_torch_bf16_gpu_available "
345
+ "or is_torch_bf16_cpu_available instead according to whether it's used with cpu or gpu",
346
+ FutureWarning,
347
+ )
348
+ return is_torch_bf16_gpu_available()
349
+
350
+
351
+ @lru_cache()
352
+ def is_torch_fp16_available_on_device(device):
353
+ if not is_torch_available():
354
+ return False
355
+
356
+ import torch
357
+
358
+ try:
359
+ x = torch.zeros(2, 2, dtype=torch.float16).to(device)
360
+ _ = x @ x
361
+ except: # noqa: E722
362
+ # TODO: more precise exception matching, if possible.
363
+ # most backends should return `RuntimeError` however this is not guaranteed.
364
+ return False
365
+
366
+ return True
367
+
368
+
369
+ @lru_cache()
370
+ def is_torch_bf16_available_on_device(device):
371
+ if not is_torch_available():
372
+ return False
373
+
374
+ import torch
375
+
376
+ if device == "cuda":
377
+ return is_torch_bf16_gpu_available()
378
+
379
+ try:
380
+ x = torch.zeros(2, 2, dtype=torch.bfloat16).to(device)
381
+ _ = x @ x
382
+ except: # noqa: E722
383
+ # TODO: more precise exception matching, if possible.
384
+ # most backends should return `RuntimeError` however this is not guaranteed.
385
+ return False
386
+
387
+ return True
388
+
389
+
390
+ def is_torch_tf32_available():
391
+ if not is_torch_available():
392
+ return False
393
+
394
+ import torch
395
+
396
+ if not torch.cuda.is_available() or torch.version.cuda is None:
397
+ return False
398
+ if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
399
+ return False
400
+ if int(torch.version.cuda.split(".")[0]) < 11:
401
+ return False
402
+ if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.7"):
403
+ return False
404
+
405
+ return True
406
+
407
+
408
+ def is_torch_fx_available():
409
+ return _torch_fx_available
410
+
411
+
412
+ def is_peft_available():
413
+ return _peft_available
414
+
415
+
416
+ def is_bs4_available():
417
+ return _bs4_available
418
+
419
+
420
+ def is_tf_available():
421
+ return _tf_available
422
+
423
+
424
+ def is_coloredlogs_available():
425
+ return _coloredlogs_available
426
+
427
+
428
+ def is_tf2onnx_available():
429
+ return _tf2onnx_available
430
+
431
+
432
+ def is_onnx_available():
433
+ return _onnx_available
434
+
435
+
436
+ def is_openai_available():
437
+ return _openai_available
438
+
439
+
440
+ def is_flax_available():
441
+ return _flax_available
442
+
443
+
444
+ def is_ftfy_available():
445
+ return _ftfy_available
446
+
447
+
448
+ def is_g2p_en_available():
449
+ return _g2p_en_available
450
+
451
+
452
+ @lru_cache()
453
+ def is_torch_tpu_available(check_device=True):
454
+ "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
455
+ if not _torch_available:
456
+ return False
457
+ if importlib.util.find_spec("torch_xla") is not None:
458
+ if check_device:
459
+ # We need to check if `xla_device` can be found, will raise a RuntimeError if not
460
+ try:
461
+ import torch_xla.core.xla_model as xm
462
+
463
+ _ = xm.xla_device()
464
+ return True
465
+ except RuntimeError:
466
+ return False
467
+ return True
468
+ return False
469
+
470
+
471
+ @lru_cache()
472
+ def is_torch_neuroncore_available(check_device=True):
473
+ if importlib.util.find_spec("torch_neuronx") is not None:
474
+ return is_torch_tpu_available(check_device)
475
+ return False
476
+
477
+
478
+ @lru_cache()
479
+ def is_torch_npu_available(check_device=False):
480
+ "Checks if `torch_npu` is installed and potentially if a NPU is in the environment"
481
+ if not _torch_available or importlib.util.find_spec("torch_npu") is None:
482
+ return False
483
+
484
+ import torch
485
+ import torch_npu # noqa: F401
486
+
487
+ if check_device:
488
+ try:
489
+ # Will raise a RuntimeError if no NPU is found
490
+ _ = torch.npu.device_count()
491
+ return torch.npu.is_available()
492
+ except RuntimeError:
493
+ return False
494
+ return hasattr(torch, "npu") and torch.npu.is_available()
495
+
496
+
497
+ def is_torchdynamo_available():
498
+ if not is_torch_available():
499
+ return False
500
+ try:
501
+ import torch._dynamo as dynamo # noqa: F401
502
+
503
+ return True
504
+ except Exception:
505
+ return False
506
+
507
+
508
+ def is_torch_compile_available():
509
+ if not is_torch_available():
510
+ return False
511
+
512
+ import torch
513
+
514
+ # We don't do any version check here to support nightlies marked as 1.14. Ultimately needs to check version against
515
+ # 2.0 but let's do it later.
516
+ return hasattr(torch, "compile")
517
+
518
+
519
+ def is_torchdynamo_compiling():
520
+ if not is_torch_available():
521
+ return False
522
+ try:
523
+ import torch._dynamo as dynamo # noqa: F401
524
+
525
+ return dynamo.is_compiling()
526
+ except Exception:
527
+ return False
528
+
529
+
530
+ def is_torch_tensorrt_fx_available():
531
+ if importlib.util.find_spec("torch_tensorrt") is None:
532
+ return False
533
+ return importlib.util.find_spec("torch_tensorrt.fx") is not None
534
+
535
+
536
+ def is_datasets_available():
537
+ return _datasets_available
538
+
539
+
540
+ def is_detectron2_available():
541
+ return _detectron2_available
542
+
543
+
544
+ def is_rjieba_available():
545
+ return _rjieba_available
546
+
547
+
548
+ def is_psutil_available():
549
+ return _psutil_available
550
+
551
+
552
+ def is_py3nvml_available():
553
+ return _py3nvml_available
554
+
555
+
556
+ def is_sacremoses_available():
557
+ return _sacremoses_available
558
+
559
+
560
+ def is_apex_available():
561
+ return _apex_available
562
+
563
+
564
+ def is_ninja_available():
565
+ r"""
566
+ Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the
567
+ [ninja](https://ninja-build.org/) build system is available on the system, `False` otherwise.
568
+ """
569
+ try:
570
+ subprocess.check_output("ninja --version".split())
571
+ except Exception:
572
+ return False
573
+ else:
574
+ return True
575
+
576
+
577
+ def is_ipex_available():
578
+ def get_major_and_minor_from_version(full_version):
579
+ return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
580
+
581
+ if not is_torch_available() or not _ipex_available:
582
+ return False
583
+
584
+ torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
585
+ ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
586
+ if torch_major_and_minor != ipex_major_and_minor:
587
+ logger.warning(
588
+ f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*,"
589
+ f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
590
+ )
591
+ return False
592
+ return True
593
+
594
+
595
+ @lru_cache
596
+ def is_torch_xpu_available(check_device=False):
597
+ "Checks if `intel_extension_for_pytorch` is installed and potentially if a XPU is in the environment"
598
+ if not is_ipex_available():
599
+ return False
600
+
601
+ import intel_extension_for_pytorch # noqa: F401
602
+ import torch
603
+
604
+ if check_device:
605
+ try:
606
+ # Will raise a RuntimeError if no XPU is found
607
+ _ = torch.xpu.device_count()
608
+ return torch.xpu.is_available()
609
+ except RuntimeError:
610
+ return False
611
+ return hasattr(torch, "xpu") and torch.xpu.is_available()
612
+
613
+
614
+ def is_bitsandbytes_available():
615
+ if not is_torch_available():
616
+ return False
617
+
618
+ # bitsandbytes throws an error if cuda is not available
619
+ # let's avoid that by adding a simple check
620
+ import torch
621
+
622
+ return _bitsandbytes_available and torch.cuda.is_available()
623
+
624
+
625
+ def is_flash_attn_2_available():
626
+ if not is_torch_available():
627
+ return False
628
+
629
+ if not _is_package_available("flash_attn"):
630
+ return False
631
+
632
+ # Let's add an extra check to see if cuda is available
633
+ import torch
634
+
635
+ if not torch.cuda.is_available():
636
+ return False
637
+
638
+ if torch.version.cuda:
639
+ return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
640
+ elif torch.version.hip:
641
+ # TODO: Bump the requirement to 2.1.0 once released in https://github.com/ROCmSoftwarePlatform/flash-attention
642
+ return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.0.4")
643
+ else:
644
+ return False
645
+
646
+
647
+ def is_flash_attn_greater_or_equal_2_10():
648
+ if not _is_package_available("flash_attn"):
649
+ return False
650
+
651
+ return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
652
+
653
+
654
+ def is_flash_attn_available():
655
+ logger.warning(
656
+ "Using `is_flash_attn_available` is deprecated and will be removed in v4.38. "
657
+ "Please use `is_flash_attn_2_available` instead."
658
+ )
659
+ return is_flash_attn_2_available()
660
+
661
+
662
+ def is_torchdistx_available():
663
+ return _torchdistx_available
664
+
665
+
666
+ def is_faiss_available():
667
+ return _faiss_available
668
+
669
+
670
+ def is_scipy_available():
671
+ return _scipy_available
672
+
673
+
674
+ def is_sklearn_available():
675
+ return _sklearn_available
676
+
677
+
678
+ def is_sentencepiece_available():
679
+ return _sentencepiece_available
680
+
681
+
682
+ def is_seqio_available():
683
+ return _is_seqio_available
684
+
685
+
686
+ def is_protobuf_available():
687
+ if importlib.util.find_spec("google") is None:
688
+ return False
689
+ return importlib.util.find_spec("google.protobuf") is not None
690
+
691
+
692
+ def is_accelerate_available(min_version: str = ACCELERATE_MIN_VERSION):
693
+ if min_version is not None:
694
+ return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version)
695
+ return _accelerate_available
696
+
697
+
698
+ def is_fsdp_available(min_version: str = FSDP_MIN_VERSION):
699
+ return is_torch_available() and version.parse(_torch_version) >= version.parse(min_version)
700
+
701
+
702
+ def is_optimum_available():
703
+ return _optimum_available
704
+
705
+
706
+ def is_auto_awq_available():
707
+ return _auto_awq_available
708
+
709
+
710
+ def is_auto_gptq_available():
711
+ return _auto_gptq_available
712
+
713
+
714
+ def is_levenshtein_available():
715
+ return _levenshtein_available
716
+
717
+
718
+ def is_optimum_neuron_available():
719
+ return _optimum_available and _is_package_available("optimum.neuron")
720
+
721
+
722
+ def is_safetensors_available():
723
+ return _safetensors_available
724
+
725
+
726
+ def is_tokenizers_available():
727
+ return _tokenizers_available
728
+
729
+
730
+ def is_vision_available():
731
+ _pil_available = importlib.util.find_spec("PIL") is not None
732
+ if _pil_available:
733
+ try:
734
+ package_version = importlib.metadata.version("Pillow")
735
+ except importlib.metadata.PackageNotFoundError:
736
+ try:
737
+ package_version = importlib.metadata.version("Pillow-SIMD")
738
+ except importlib.metadata.PackageNotFoundError:
739
+ return False
740
+ logger.debug(f"Detected PIL version {package_version}")
741
+ return _pil_available
742
+
743
+
744
+ def is_pytesseract_available():
745
+ return _pytesseract_available
746
+
747
+
748
+ def is_pytest_available():
749
+ return _pytest_available
750
+
751
+
752
+ def is_spacy_available():
753
+ return _spacy_available
754
+
755
+
756
+ def is_tensorflow_text_available():
757
+ return is_tf_available() and _tensorflow_text_available
758
+
759
+
760
+ def is_keras_nlp_available():
761
+ return is_tensorflow_text_available() and _keras_nlp_available
762
+
763
+
764
+ def is_in_notebook():
765
+ try:
766
+ # Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
767
+ get_ipython = sys.modules["IPython"].get_ipython
768
+ if "IPKernelApp" not in get_ipython().config:
769
+ raise ImportError("console")
770
+ if "VSCODE_PID" in os.environ:
771
+ raise ImportError("vscode")
772
+ if "DATABRICKS_RUNTIME_VERSION" in os.environ and os.environ["DATABRICKS_RUNTIME_VERSION"] < "11.0":
773
+ # Databricks Runtime 11.0 and above uses IPython kernel by default so it should be compatible with Jupyter notebook
774
+ # https://docs.microsoft.com/en-us/azure/databricks/notebooks/ipython-kernel
775
+ raise ImportError("databricks")
776
+
777
+ return importlib.util.find_spec("IPython") is not None
778
+ except (AttributeError, ImportError, KeyError):
779
+ return False
780
+
781
+
782
+ def is_pytorch_quantization_available():
783
+ return _pytorch_quantization_available
784
+
785
+
786
+ def is_tensorflow_probability_available():
787
+ return _tensorflow_probability_available
788
+
789
+
790
+ def is_pandas_available():
791
+ return _pandas_available
792
+
793
+
794
+ def is_sagemaker_dp_enabled():
795
+ # Get the sagemaker specific env variable.
796
+ sagemaker_params = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
797
+ try:
798
+ # Parse it and check the field "sagemaker_distributed_dataparallel_enabled".
799
+ sagemaker_params = json.loads(sagemaker_params)
800
+ if not sagemaker_params.get("sagemaker_distributed_dataparallel_enabled", False):
801
+ return False
802
+ except json.JSONDecodeError:
803
+ return False
804
+ # Lastly, check if the `smdistributed` module is present.
805
+ return _smdistributed_available
806
+
807
+
808
+ def is_sagemaker_mp_enabled():
809
+ # Get the sagemaker specific mp parameters from smp_options variable.
810
+ smp_options = os.getenv("SM_HP_MP_PARAMETERS", "{}")
811
+ try:
812
+ # Parse it and check the field "partitions" is included, it is required for model parallel.
813
+ smp_options = json.loads(smp_options)
814
+ if "partitions" not in smp_options:
815
+ return False
816
+ except json.JSONDecodeError:
817
+ return False
818
+
819
+ # Get the sagemaker specific framework parameters from mpi_options variable.
820
+ mpi_options = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
821
+ try:
822
+ # Parse it and check the field "sagemaker_distributed_dataparallel_enabled".
823
+ mpi_options = json.loads(mpi_options)
824
+ if not mpi_options.get("sagemaker_mpi_enabled", False):
825
+ return False
826
+ except json.JSONDecodeError:
827
+ return False
828
+ # Lastly, check if the `smdistributed` module is present.
829
+ return _smdistributed_available
830
+
831
+
832
+ def is_training_run_on_sagemaker():
833
+ return "SAGEMAKER_JOB_NAME" in os.environ
834
+
835
+
836
+ def is_soundfile_availble():
837
+ return _soundfile_available
838
+
839
+
840
+ def is_timm_available():
841
+ return _timm_available
842
+
843
+
844
+ def is_natten_available():
845
+ return _natten_available
846
+
847
+
848
+ def is_nltk_available():
849
+ return _nltk_available
850
+
851
+
852
+ def is_torchaudio_available():
853
+ return _torchaudio_available
854
+
855
+
856
+ def is_speech_available():
857
+ # For now this depends on torchaudio but the exact dependency might evolve in the future.
858
+ return _torchaudio_available
859
+
860
+
861
+ def is_phonemizer_available():
862
+ return _phonemizer_available
863
+
864
+
865
+ def torch_only_method(fn):
866
+ def wrapper(*args, **kwargs):
867
+ if not _torch_available:
868
+ raise ImportError(
869
+ "You need to install pytorch to use this method or class, "
870
+ "or activate it with environment variables USE_TORCH=1 and USE_TF=0."
871
+ )
872
+ else:
873
+ return fn(*args, **kwargs)
874
+
875
+ return wrapper
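+ # Illustrative use: decorating a method with @torch_only_method defers the ImportError until the call site, instead of failing at import time.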
876
+
877
+
878
+ def is_ccl_available():
879
+ return _is_ccl_available
880
+
881
+
882
+ def is_decord_available():
883
+ return _decord_available
884
+
885
+
886
+ def is_sudachi_available():
887
+ return _sudachipy_available
888
+
889
+
890
+ def is_jumanpp_available():
891
+ return (importlib.util.find_spec("rhoknp") is not None) and (shutil.which("jumanpp") is not None)
892
+
893
+
894
+ def is_cython_available():
895
+ return importlib.util.find_spec("pyximport") is not None
896
+
897
+
898
+ def is_jieba_available():
899
+ return _jieba_available
900
+
901
+
902
+ def is_jinja_available():
903
+ return _jinja_available
904
+
905
+
906
+ # docstyle-ignore
907
+ CV2_IMPORT_ERROR = """
908
+ {0} requires the OpenCV library but it was not found in your environment. You can install it with:
909
+ ```
910
+ pip install opencv-python
911
+ ```
912
+ Please note that you may need to restart your runtime after installation.
913
+ """
914
+
915
+
916
+ # docstyle-ignore
917
+ DATASETS_IMPORT_ERROR = """
918
+ {0} requires the 🤗 Datasets library but it was not found in your environment. You can install it with:
919
+ ```
920
+ pip install datasets
921
+ ```
922
+ In a notebook or a colab, you can install it by executing a cell with
923
+ ```
924
+ !pip install datasets
925
+ ```
926
+ then restarting your kernel.
927
+
928
+ Note that if you have a local folder named `datasets` or a local python file named `datasets.py` in your current
929
+ working directory, python may try to import this instead of the 🤗 Datasets library. You should rename this folder or
930
+ that python file if that's the case. Please note that you may need to restart your runtime after installation.
931
+ """
932
+
933
+
934
+ # docstyle-ignore
935
+ TOKENIZERS_IMPORT_ERROR = """
936
+ {0} requires the 🤗 Tokenizers library but it was not found in your environment. You can install it with:
937
+ ```
938
+ pip install tokenizers
939
+ ```
940
+ In a notebook or a colab, you can install it by executing a cell with
941
+ ```
942
+ !pip install tokenizers
943
+ ```
944
+ Please note that you may need to restart your runtime after installation.
945
+ """
946
+
947
+
948
+ # docstyle-ignore
949
+ SENTENCEPIECE_IMPORT_ERROR = """
950
+ {0} requires the SentencePiece library but it was not found in your environment. Check out the instructions on the
951
+ installation page of its repo: https://github.com/google/sentencepiece#installation and follow the ones
952
+ that match your environment. Please note that you may need to restart your runtime after installation.
953
+ """
954
+
955
+
956
+ # docstyle-ignore
957
+ PROTOBUF_IMPORT_ERROR = """
958
+ {0} requires the protobuf library but it was not found in your environment. Check out the instructions on the
959
+ installation page of its repo: https://github.com/protocolbuffers/protobuf/tree/master/python#installation and follow the ones
960
+ that match your environment. Please note that you may need to restart your runtime after installation.
961
+ """
962
+
963
+
964
+ # docstyle-ignore
965
+ FAISS_IMPORT_ERROR = """
966
+ {0} requires the faiss library but it was not found in your environment. Check out the instructions on the
967
+ installation page of its repo: https://github.com/facebookresearch/faiss/blob/master/INSTALL.md and follow the ones
968
+ that match your environment. Please note that you may need to restart your runtime after installation.
969
+ """
970
+
971
+
972
+ # docstyle-ignore
973
+ PYTORCH_IMPORT_ERROR = """
974
+ {0} requires the PyTorch library but it was not found in your environment. Check out the instructions on the
975
+ installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
976
+ Please note that you may need to restart your runtime after installation.
977
+ """
978
+
979
+
980
+ # docstyle-ignore
981
+ TORCHVISION_IMPORT_ERROR = """
982
+ {0} requires the Torchvision library but it was not found in your environment. Check out the instructions on the
983
+ installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
984
+ Please note that you may need to restart your runtime after installation.
985
+ """
986
+
987
+ # docstyle-ignore
988
+ PYTORCH_IMPORT_ERROR_WITH_TF = """
989
+ {0} requires the PyTorch library but it was not found in your environment.
990
+ However, we were able to find a TensorFlow installation. TensorFlow classes begin
991
+ with "TF", but are otherwise identically named to our PyTorch classes. This
992
+ means that the TF equivalent of the class you tried to import would be "TF{0}".
993
+ If you want to use TensorFlow, please use TF classes instead!
994
+
995
+ If you really do want to use PyTorch please go to
996
+ https://pytorch.org/get-started/locally/ and follow the instructions that
997
+ match your environment.
998
+ """
999
+
1000
+ # docstyle-ignore
1001
+ TF_IMPORT_ERROR_WITH_PYTORCH = """
1002
+ {0} requires the TensorFlow library but it was not found in your environment.
1003
+ However, we were able to find a PyTorch installation. PyTorch classes do not begin
1004
+ with "TF", but are otherwise identically named to our TF classes.
1005
+ If you want to use PyTorch, please use those classes instead!
1006
+
1007
+ If you really do want to use TensorFlow, please follow the instructions on the
1008
+ installation page https://www.tensorflow.org/install that match your environment.
1009
+ """
1010
+
1011
+ # docstyle-ignore
1012
+ BS4_IMPORT_ERROR = """
1013
+ {0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip:
1014
+ `pip install beautifulsoup4`. Please note that you may need to restart your runtime after installation.
1015
+ """
1016
+
1017
+
1018
+ # docstyle-ignore
1019
+ SKLEARN_IMPORT_ERROR = """
1020
+ {0} requires the scikit-learn library but it was not found in your environment. You can install it with:
1021
+ ```
1022
+ pip install -U scikit-learn
1023
+ ```
1024
+ In a notebook or a colab, you can install it by executing a cell with
1025
+ ```
1026
+ !pip install -U scikit-learn
1027
+ ```
1028
+ Please note that you may need to restart your runtime after installation.
1029
+ """
1030
+
1031
+
1032
+ # docstyle-ignore
1033
+ TENSORFLOW_IMPORT_ERROR = """
1034
+ {0} requires the TensorFlow library but it was not found in your environment. Check out the instructions on the
1035
+ installation page: https://www.tensorflow.org/install and follow the ones that match your environment.
1036
+ Please note that you may need to restart your runtime after installation.
1037
+ """
1038
+
1039
+
1040
+ # docstyle-ignore
1041
+ DETECTRON2_IMPORT_ERROR = """
1042
+ {0} requires the detectron2 library but it was not found in your environment. Check out the instructions on the
1043
+ installation page: https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md and follow the ones
1044
+ that match your environment. Please note that you may need to restart your runtime after installation.
1045
+ """
1046
+
1047
+
1048
+ # docstyle-ignore
1049
+ FLAX_IMPORT_ERROR = """
1050
+ {0} requires the FLAX library but it was not found in your environment. Check out the instructions on the
1051
+ installation page: https://github.com/google/flax and follow the ones that match your environment.
1052
+ Please note that you may need to restart your runtime after installation.
1053
+ """
1054
+
1055
+ # docstyle-ignore
1056
+ FTFY_IMPORT_ERROR = """
1057
+ {0} requires the ftfy library but it was not found in your environment. Check out the instructions on the
1058
+ installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones
1059
+ that match your environment. Please note that you may need to restart your runtime after installation.
1060
+ """
1061
+
1062
+ LEVENSHTEIN_IMPORT_ERROR = """
1063
+ {0} requires the python-Levenshtein library but it was not found in your environment. You can install it with pip: `pip
1064
+ install python-Levenshtein`. Please note that you may need to restart your runtime after installation.
1065
+ """
1066
+
1067
+ # docstyle-ignore
1068
+ G2P_EN_IMPORT_ERROR = """
1069
+ {0} requires the g2p-en library but it was not found in your environment. You can install it with pip:
1070
+ `pip install g2p-en`. Please note that you may need to restart your runtime after installation.
1071
+ """
1072
+
1073
+ # docstyle-ignore
1074
+ PYTORCH_QUANTIZATION_IMPORT_ERROR = """
1075
+ {0} requires the pytorch-quantization library but it was not found in your environment. You can install it with pip:
1076
+ `pip install pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com`
1077
+ Please note that you may need to restart your runtime after installation.
1078
+ """
1079
+
1080
+ # docstyle-ignore
1081
+ TENSORFLOW_PROBABILITY_IMPORT_ERROR = """
1082
+ {0} requires the tensorflow_probability library but it was not found in your environment. You can install it with pip as
1083
+ explained here: https://github.com/tensorflow/probability. Please note that you may need to restart your runtime after installation.
1084
+ """
1085
+
1086
+ # docstyle-ignore
1087
+ TENSORFLOW_TEXT_IMPORT_ERROR = """
1088
+ {0} requires the tensorflow_text library but it was not found in your environment. You can install it with pip as
1089
+ explained here: https://www.tensorflow.org/text/guide/tf_text_intro.
1090
+ Please note that you may need to restart your runtime after installation.
1091
+ """
1092
+
1093
+
1094
+ # docstyle-ignore
1095
+ PANDAS_IMPORT_ERROR = """
1096
+ {0} requires the pandas library but it was not found in your environment. You can install it with pip as
1097
+ explained here: https://pandas.pydata.org/pandas-docs/stable/getting_started/install.html.
1098
+ Please note that you may need to restart your runtime after installation.
1099
+ """
1100
+
1101
+
1102
+ # docstyle-ignore
1103
+ PHONEMIZER_IMPORT_ERROR = """
1104
+ {0} requires the phonemizer library but it was not found in your environment. You can install it with pip:
1105
+ `pip install phonemizer`. Please note that you may need to restart your runtime after installation.
1106
+ """
1107
+
1108
+
1109
+ # docstyle-ignore
1110
+ SACREMOSES_IMPORT_ERROR = """
1111
+ {0} requires the sacremoses library but it was not found in your environment. You can install it with pip:
1112
+ `pip install sacremoses`. Please note that you may need to restart your runtime after installation.
1113
+ """
1114
+
1115
+ # docstyle-ignore
1116
+ SCIPY_IMPORT_ERROR = """
1117
+ {0} requires the scipy library but it was not found in your environment. You can install it with pip:
1118
+ `pip install scipy`. Please note that you may need to restart your runtime after installation.
1119
+ """
1120
+
1121
+
1122
+ # docstyle-ignore
1123
+ SPEECH_IMPORT_ERROR = """
1124
+ {0} requires the torchaudio library but it was not found in your environment. You can install it with pip:
1125
+ `pip install torchaudio`. Please note that you may need to restart your runtime after installation.
1126
+ """
1127
+
1128
+ # docstyle-ignore
1129
+ TIMM_IMPORT_ERROR = """
1130
+ {0} requires the timm library but it was not found in your environment. You can install it with pip:
1131
+ `pip install timm`. Please note that you may need to restart your runtime after installation.
1132
+ """
1133
+
1134
+ # docstyle-ignore
1135
+ NATTEN_IMPORT_ERROR = """
1136
+ {0} requires the natten library but it was not found in your environment. You can install it by referring to:
1137
+ shi-labs.com/natten. You can also install it with pip (may take longer to build):
1138
+ `pip install natten`. Please note that you may need to restart your runtime after installation.
1139
+ """
1140
+
1141
+
1142
+ # docstyle-ignore
1143
+ NLTK_IMPORT_ERROR = """
1144
+ {0} requires the NLTK library but it was not found in your environment. You can install it by referring to:
1145
+ https://www.nltk.org/install.html. Please note that you may need to restart your runtime after installation.
1146
+ """
1147
+
1148
+
1149
+ # docstyle-ignore
1150
+ VISION_IMPORT_ERROR = """
1151
+ {0} requires the PIL library but it was not found in your environment. You can install it with pip:
1152
+ `pip install pillow`. Please note that you may need to restart your runtime after installation.
1153
+ """
1154
+
1155
+
1156
+ # docstyle-ignore
1157
+ PYTESSERACT_IMPORT_ERROR = """
1158
+ {0} requires the PyTesseract library but it was not found in your environment. You can install it with pip:
1159
+ `pip install pytesseract`. Please note that you may need to restart your runtime after installation.
1160
+ """
1161
+
1162
+ # docstyle-ignore
1163
+ PYCTCDECODE_IMPORT_ERROR = """
1164
+ {0} requires the pyctcdecode library but it was not found in your environment. You can install it with pip:
1165
+ `pip install pyctcdecode`. Please note that you may need to restart your runtime after installation.
1166
+ """
1167
+
1168
+ # docstyle-ignore
1169
+ ACCELERATE_IMPORT_ERROR = """
1170
+ {0} requires the accelerate library >= {ACCELERATE_MIN_VERSION} but it was not found in your environment.
1171
+ You can install or update it with pip: `pip install --upgrade accelerate`. Please note that you may need to restart your
1172
+ runtime after installation.
1173
+ """
1174
+
1175
+ # docstyle-ignore
1176
+ CCL_IMPORT_ERROR = """
1177
+ {0} requires the torch ccl library but it was not found in your environment. You can install it with pip:
1178
+ `pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable`
1179
+ Please note that you may need to restart your runtime after installation.
1180
+ """
1181
+
1182
+ # docstyle-ignore
1183
+ ESSENTIA_IMPORT_ERROR = """
1184
+ {0} requires the essentia library but it was not found in your environment. You can install it with pip:
1185
+ `pip install essentia==2.1b6.dev1034`
1186
+ Please note that you may need to restart your runtime after installation.
1187
+ """
1188
+
1189
+ # docstyle-ignore
1190
+ LIBROSA_IMPORT_ERROR = """
1191
+ {0} requires the librosa library but it was not found in your environment. You can install it with pip:
1192
+ `pip install librosa`
1193
+ Please note that you may need to restart your runtime after installation.
1194
+ """
1195
+
1196
+ # docstyle-ignore
1197
+ PRETTY_MIDI_IMPORT_ERROR = """
1198
+ {0} requires the pretty_midi library but it was not found in your environment. You can install it with pip:
1199
+ `pip install pretty_midi`
1200
+ Please note that you may need to restart your runtime after installation.
1201
+ """
1202
+
1203
+ DECORD_IMPORT_ERROR = """
1204
+ {0} requires the decord library but it was not found in your environment. You can install it with pip: `pip install
1205
+ decord`. Please note that you may need to restart your runtime after installation.
1206
+ """
1207
+
1208
+ CYTHON_IMPORT_ERROR = """
1209
+ {0} requires the Cython library but it was not found in your environment. You can install it with pip: `pip install
1210
+ Cython`. Please note that you may need to restart your runtime after installation.
1211
+ """
1212
+
1213
+ JIEBA_IMPORT_ERROR = """
1214
+ {0} requires the jieba library but it was not found in your environment. You can install it with pip: `pip install
1215
+ jieba`. Please note that you may need to restart your runtime after installation.
1216
+ """
1217
+
1218
+ PEFT_IMPORT_ERROR = """
1219
+ {0} requires the peft library but it was not found in your environment. You can install it with pip: `pip install
1220
+ peft`. Please note that you may need to restart your runtime after installation.
1221
+ """
1222
+
1223
+ JINJA_IMPORT_ERROR = """
1224
+ {0} requires the jinja library but it was not found in your environment. You can install it with pip: `pip install
1225
+ jinja2`. Please note that you may need to restart your runtime after installation.
1226
+ """
1227
+
1228
+ BACKENDS_MAPPING = OrderedDict(
1229
+ [
1230
+ ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
1231
+ ("cv2", (is_cv2_available, CV2_IMPORT_ERROR)),
1232
+ ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)),
1233
+ ("detectron2", (is_detectron2_available, DETECTRON2_IMPORT_ERROR)),
1234
+ ("essentia", (is_essentia_available, ESSENTIA_IMPORT_ERROR)),
1235
+ ("faiss", (is_faiss_available, FAISS_IMPORT_ERROR)),
1236
+ ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
1237
+ ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
1238
+ ("g2p_en", (is_g2p_en_available, G2P_EN_IMPORT_ERROR)),
1239
+ ("pandas", (is_pandas_available, PANDAS_IMPORT_ERROR)),
1240
+ ("phonemizer", (is_phonemizer_available, PHONEMIZER_IMPORT_ERROR)),
1241
+ ("pretty_midi", (is_pretty_midi_available, PRETTY_MIDI_IMPORT_ERROR)),
1242
+ ("levenshtein", (is_levenshtein_available, LEVENSHTEIN_IMPORT_ERROR)),
1243
+ ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)),
1244
+ ("protobuf", (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
1245
+ ("pyctcdecode", (is_pyctcdecode_available, PYCTCDECODE_IMPORT_ERROR)),
1246
+ ("pytesseract", (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR)),
1247
+ ("sacremoses", (is_sacremoses_available, SACREMOSES_IMPORT_ERROR)),
1248
+ ("pytorch_quantization", (is_pytorch_quantization_available, PYTORCH_QUANTIZATION_IMPORT_ERROR)),
1249
+ ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
1250
+ ("sklearn", (is_sklearn_available, SKLEARN_IMPORT_ERROR)),
1251
+ ("speech", (is_speech_available, SPEECH_IMPORT_ERROR)),
1252
+ ("tensorflow_probability", (is_tensorflow_probability_available, TENSORFLOW_PROBABILITY_IMPORT_ERROR)),
1253
+ ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
1254
+ ("tensorflow_text", (is_tensorflow_text_available, TENSORFLOW_TEXT_IMPORT_ERROR)),
1255
+ ("timm", (is_timm_available, TIMM_IMPORT_ERROR)),
1256
+ ("natten", (is_natten_available, NATTEN_IMPORT_ERROR)),
1257
+ ("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)),
1258
+ ("tokenizers", (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
1259
+ ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
1260
+ ("torchvision", (is_torchvision_available, TORCHVISION_IMPORT_ERROR)),
1261
+ ("vision", (is_vision_available, VISION_IMPORT_ERROR)),
1262
+ ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
1263
+ ("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),
1264
+ ("oneccl_bind_pt", (is_ccl_available, CCL_IMPORT_ERROR)),
1265
+ ("decord", (is_decord_available, DECORD_IMPORT_ERROR)),
1266
+ ("cython", (is_cython_available, CYTHON_IMPORT_ERROR)),
1267
+ ("jieba", (is_jieba_available, JIEBA_IMPORT_ERROR)),
1268
+ ("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
1269
+ ("jinja", (is_jinja_available, JINJA_IMPORT_ERROR)),
1270
+ ]
1271
+ )
1272
+
1273
+
1274
+ def requires_backends(obj, backends):
1275
+ if not isinstance(backends, (list, tuple)):
1276
+ backends = [backends]
1277
+
1278
+ name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
1279
+
1280
+ # Raise an error for users who might not realize that classes without "TF" are torch-only
1281
+ if "torch" in backends and "tf" not in backends and not is_torch_available() and is_tf_available():
1282
+ raise ImportError(PYTORCH_IMPORT_ERROR_WITH_TF.format(name))
1283
+
1284
+ # Raise the inverse error for PyTorch users trying to load TF classes
1285
+ if "tf" in backends and "torch" not in backends and is_torch_available() and not is_tf_available():
1286
+ raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name))
1287
+
1288
+ checks = (BACKENDS_MAPPING[backend] for backend in backends)
1289
+ failed = [msg.format(name) for available, msg in checks if not available()]
1290
+ if failed:
1291
+ raise ImportError("".join(failed))
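+ # Example (illustrative): requires_backends(self, ["torch", "vision"]) raises a single ImportError that concatenates the message for every backend that is missing.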
1292
+
1293
+
1294
+ class DummyObject(type):
1295
+ """
1296
+ Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
1297
+ `requires_backend` each time a user tries to access any method of that class.
1298
+ """
1299
+
1300
+ def __getattribute__(cls, key):
1301
+ if key.startswith("_") and key != "_from_config":
1302
+ return super().__getattribute__(key)
1303
+ requires_backends(cls, cls._backends)
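+ # Illustrative: a placeholder such as `class SomeTorchOnlyClass(metaclass=DummyObject): _backends = ["torch"]` raises the matching install hint on first attribute access.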
1304
+
1305
+
1306
+ def torch_required(func):
1307
+ warnings.warn(
1308
+ "The method `torch_required` is deprecated and will be removed in v4.36. Use `requires_backends` instead.",
1309
+ FutureWarning,
1310
+ )
1311
+
1312
+ # Chose a different decorator name than in tests so it's clear they are not the same.
1313
+ @wraps(func)
1314
+ def wrapper(*args, **kwargs):
1315
+ if is_torch_available():
1316
+ return func(*args, **kwargs)
1317
+ else:
1318
+ raise ImportError(f"Method `{func.__name__}` requires PyTorch.")
1319
+
1320
+ return wrapper
1321
+
1322
+
1323
+ def tf_required(func):
1324
+ warnings.warn(
1325
+ "The method `tf_required` is deprecated and will be removed in v4.36. Use `requires_backends` instead.",
1326
+ FutureWarning,
1327
+ )
1328
+
1329
+ # Chose a different decorator name than in tests so it's clear they are not the same.
1330
+ @wraps(func)
1331
+ def wrapper(*args, **kwargs):
1332
+ if is_tf_available():
1333
+ return func(*args, **kwargs)
1334
+ else:
1335
+ raise ImportError(f"Method `{func.__name__}` requires TF.")
1336
+
1337
+ return wrapper
1338
+
1339
+
1340
+ def is_torch_fx_proxy(x):
1341
+ if is_torch_fx_available():
1342
+ import torch.fx
1343
+
1344
+ return isinstance(x, torch.fx.Proxy)
1345
+ return False
1346
+
1347
+
1348
+ class _LazyModule(ModuleType):
1349
+ """
1350
+ Module class that surfaces all objects but only performs associated imports when the objects are requested.
1351
+ """
1352
+
1353
+ # Very heavily inspired by optuna.integration._IntegrationModule
1354
+ # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
1355
+ def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
1356
+ super().__init__(name)
1357
+ self._modules = set(import_structure.keys())
1358
+ self._class_to_module = {}
1359
+ for key, values in import_structure.items():
1360
+ for value in values:
1361
+ self._class_to_module[value] = key
1362
+ # Needed for autocompletion in an IDE
1363
+ self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
1364
+ self.__file__ = module_file
1365
+ self.__spec__ = module_spec
1366
+ self.__path__ = [os.path.dirname(module_file)]
1367
+ self._objects = {} if extra_objects is None else extra_objects
1368
+ self._name = name
1369
+ self._import_structure = import_structure
1370
+
1371
+ # Needed for autocompletion in an IDE
1372
+ def __dir__(self):
1373
+ result = super().__dir__()
1374
+ # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
1375
+ # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
1376
+ for attr in self.__all__:
1377
+ if attr not in result:
1378
+ result.append(attr)
1379
+ return result
1380
+
1381
+ def __getattr__(self, name: str) -> Any:
1382
+ if name in self._objects:
1383
+ return self._objects[name]
1384
+ if name in self._modules:
1385
+ value = self._get_module(name)
1386
+ elif name in self._class_to_module.keys():
1387
+ module = self._get_module(self._class_to_module[name])
1388
+ value = getattr(module, name)
1389
+ else:
1390
+ raise AttributeError(f"module {self.__name__} has no attribute {name}")
1391
+
1392
+ setattr(self, name, value)
1393
+ return value
1394
+
1395
+ def _get_module(self, module_name: str):
1396
+ try:
1397
+ return importlib.import_module("." + module_name, self.__name__)
1398
+ except Exception as e:
1399
+ raise RuntimeError(
1400
+ f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its"
1401
+ f" traceback):\n{e}"
1402
+ ) from e
1403
+
1404
+ def __reduce__(self):
1405
+ return (self.__class__, (self._name, self.__file__, self._import_structure))
1406
+
1407
+
1408
+ class OptionalDependencyNotAvailable(BaseException):
1409
+ """Internally used error class for signalling an optional dependency was not found."""
1410
+
1411
+
1412
+ def direct_transformers_import(path: str, file="__init__.py") -> ModuleType:
1413
+ """Imports transformers directly
1414
+
1415
+ Args:
1416
+ path (`str`): The path to the source file
1417
+ file (`str`, optional): The file to join with the path. Defaults to "__init__.py".
1418
+
1419
+ Returns:
1420
+ `ModuleType`: The resulting imported module
1421
+ """
1422
+ name = "transformers"
1423
+ location = os.path.join(path, file)
1424
+ spec = importlib.util.spec_from_file_location(name, location, submodule_search_locations=[path])
1425
+ module = importlib.util.module_from_spec(spec)
1426
+ spec.loader.exec_module(module)
1427
+ module = sys.modules[name]
1428
+ return module
tinychart/model/llava_arch.py ADDED
@@ -0,0 +1,383 @@
1
+ # Copyright 2024 Baichuan Zhou, Junlong Jia, Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from abc import ABC, abstractmethod
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from tinychart.model.multimodal_encoder.builder import build_vision_tower
22
+ from tinychart.model.multimodal_projector.builder import build_vision_projector
23
+
24
+ from tinychart.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
25
+
26
+
27
+ class LlavaMetaModel:
28
+
29
+ def __init__(self, config):
30
+ super(LlavaMetaModel, self).__init__(config)
31
+
32
+ if hasattr(config, "mm_vision_tower"):
33
+ self.vision_tower = build_vision_tower(config, delay_load=True)
34
+ self.mm_projector = build_vision_projector(config)
35
+
36
+ if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
37
+ self.image_newline = nn.Parameter(
38
+ torch.empty(config.hidden_size, dtype=self.dtype)
39
+ )
40
+
41
+ def get_vision_tower(self):
42
+ vision_tower = getattr(self, 'vision_tower', None)
43
+ if type(vision_tower) is list:
44
+ vision_tower = vision_tower[0]
45
+ return vision_tower
46
+
47
+ def initialize_vision_modules(self, model_args, fsdp=None):
48
+ vision_tower = model_args.vision_tower
49
+ mm_vision_select_layer = model_args.mm_vision_select_layer
50
+ mm_vision_select_feature = model_args.mm_vision_select_feature
51
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
52
+ mm_patch_merge_type = model_args.mm_patch_merge_type
53
+
54
+ self.config.mm_vision_tower = vision_tower
55
+
56
+ if self.get_vision_tower() is None:
57
+ vision_tower = build_vision_tower(model_args)
58
+
59
+ if fsdp is not None and len(fsdp) > 0:
60
+ self.vision_tower = [vision_tower]
61
+ else:
62
+ self.vision_tower = vision_tower
63
+
64
+ elif self.get_vision_tower().vision_tower_name != vision_tower:
65
+ print(f"rebuilding vision tower! vision tower initialized from: {vision_tower}")
66
+ vision_tower = build_vision_tower(model_args)
67
+ if fsdp is not None and len(fsdp) > 0:
68
+ self.vision_tower = [vision_tower]
69
+ else:
70
+ self.vision_tower = vision_tower
71
+
72
+ else:
73
+ if fsdp is not None and len(fsdp) > 0:
74
+ vision_tower = self.vision_tower[0]
75
+ else:
76
+ vision_tower = self.vision_tower
77
+ vision_tower.load_model()
78
+
79
+ self.config.use_mm_proj = True
80
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
81
+ self.config.mm_hidden_size = vision_tower.hidden_size
82
+ self.config.mm_vision_select_layer = mm_vision_select_layer
83
+ self.config.mm_vision_select_feature = mm_vision_select_feature
84
+ self.config.mm_patch_merge_type = mm_patch_merge_type
85
+
86
+ if getattr(self, 'mm_projector', None) is None:
87
+ self.mm_projector = build_vision_projector(self.config)
88
+
89
+ if 'unpad' in mm_patch_merge_type:
90
+ embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
91
+ self.image_newline = nn.Parameter(
92
+ torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
93
+ )
94
+ else:
95
+ # In case it is frozen by LoRA
96
+ for p in self.mm_projector.parameters():
97
+ p.requires_grad = True
98
+
99
+ if pretrain_mm_mlp_adapter is not None:
100
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
101
+ def get_w(weights, keyword):
102
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
103
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
104
+
105
+
106
+ def unpad_image(tensor, original_size):
107
+ """
108
+ Unpads a PyTorch tensor of a padded and resized image.
109
+
110
+ Args:
111
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
112
+ original_size (tuple): The original size of the image (width, height), matching the unpacking below.
113
+
114
+ Returns:
115
+ torch.Tensor: The unpadded image tensor.
116
+ """
117
+ original_width, original_height = original_size
118
+ current_height, current_width = tensor.shape[1:]
119
+
120
+ original_aspect_ratio = original_width / original_height
121
+ current_aspect_ratio = current_width / current_height
122
+
123
+ if original_aspect_ratio > current_aspect_ratio:
124
+ scale_factor = current_width / original_width
125
+ new_height = int(original_height * scale_factor)
126
+ padding = (current_height - new_height) // 2
127
+ unpadded_tensor = tensor[:, padding:current_height - padding, :]
128
+ else:
129
+ scale_factor = current_height / original_height
130
+ new_width = int(original_width * scale_factor)
131
+ padding = (current_width - new_width) // 2
132
+ unpadded_tensor = tensor[:, :, padding:current_width - padding]
133
+
134
+ return unpadded_tensor
135
+
136
+
137
+ class LlavaMetaForCausalLM(ABC):
138
+
139
+ @abstractmethod
140
+ def get_model(self):
141
+ pass
142
+
143
+ def get_vision_tower(self):
144
+ return self.get_model().get_vision_tower()
145
+
146
+ def encode_images(self, images):
147
+ image_features = self.get_model().get_vision_tower()(images)
148
+ image_features = self.get_model().mm_projector(image_features)
149
+ return image_features
150
+
151
+ def prepare_inputs_labels_for_multimodal(
152
+ self, input_ids, position_ids, attention_mask, past_key_values, labels,
153
+ images, image_sizes=None
154
+ ):
155
+ vision_tower = self.get_vision_tower()
156
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
157
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
158
+
159
+ if type(images) is list or images.ndim == 5:
160
+ if type(images) is list:
161
+ images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
162
+ concat_images = torch.cat([image for image in images], dim=0)
163
+ image_features = self.encode_images(concat_images)
164
+ split_sizes = [image.shape[0] for image in images]
165
+ image_features = torch.split(image_features, split_sizes, dim=0)
166
+ mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
167
+ image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
168
+ if mm_patch_merge_type == 'flat':
169
+ image_features = [x.flatten(0, 1) for x in image_features]
170
+ elif mm_patch_merge_type.startswith('spatial'):
171
+ new_image_features = []
172
+ for image_idx, image_feature in enumerate(image_features):
173
+ if image_feature.shape[0] > 1:
174
+ base_image_feature = image_feature[0]
175
+ image_feature = image_feature[1:]
176
+ height = width = self.get_vision_tower().num_patches_per_side
177
+ assert height * width == base_image_feature.shape[0]
178
+ if 'unpad' in mm_patch_merge_type:
179
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
180
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
181
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
182
+ image_feature = torch.cat((
183
+ image_feature,
184
+ self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
185
+ ), dim=-1)
186
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
187
+ else:
188
+ image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
189
+ image_feature = image_feature.flatten(0, 3)
190
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
191
+ else:
192
+ image_feature = image_feature[0]
193
+ if 'unpad' in mm_patch_merge_type:
194
+ image_feature = torch.cat((
195
+ image_feature,
196
+ self.model.image_newline[None].to(image_feature.device)
197
+ ), dim=0)
198
+ new_image_features.append(image_feature)
199
+ image_features = new_image_features
200
+ else:
201
+ raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
202
+ else:
203
+ image_features = self.encode_images(images)
204
+
205
+ # TODO: image start / end is not implemented here to support pretraining.
206
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
207
+ raise NotImplementedError
208
+
209
+ # Let's just add dummy tensors if they do not exist,
210
+ # it is a headache to deal with None all the time.
211
+ # But it is not ideal, and if you have a better idea,
212
+ # please open an issue / submit a PR, thanks.
213
+ _labels = labels
214
+ _position_ids = position_ids
215
+ _attention_mask = attention_mask
216
+ if attention_mask is None:
217
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
218
+ else:
219
+ attention_mask = attention_mask.bool()
220
+ if position_ids is None:
221
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
222
+ if labels is None:
223
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
224
+
225
+ # remove the padding using attention_mask -- FIXME
226
+ _input_ids = input_ids
227
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
228
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
229
+
230
+ new_input_embeds = []
231
+ new_labels = []
232
+ cur_image_idx = 0
233
+ for batch_idx, cur_input_ids in enumerate(input_ids):
234
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
235
+ if num_images == 0:
236
+ cur_image_features = image_features[cur_image_idx]
237
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
238
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
239
+ new_input_embeds.append(cur_input_embeds)
240
+ new_labels.append(labels[batch_idx])
241
+ cur_image_idx += 1
242
+ continue
243
+
244
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
245
+ cur_input_ids_noim = []
246
+ cur_labels = labels[batch_idx]
247
+ cur_labels_noim = []
248
+ for i in range(len(image_token_indices) - 1):
249
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
250
+ cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
251
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
252
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
253
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
254
+ cur_new_input_embeds = []
255
+ cur_new_labels = []
256
+
257
+ for i in range(num_images + 1):
258
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
259
+ cur_new_labels.append(cur_labels_noim[i])
260
+ if i < num_images:
261
+ cur_image_features = image_features[cur_image_idx]
262
+ cur_image_idx += 1
263
+ cur_new_input_embeds.append(cur_image_features)
264
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
265
+
266
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
267
+
268
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
269
+ cur_new_labels = torch.cat(cur_new_labels)
270
+
271
+ new_input_embeds.append(cur_new_input_embeds)
272
+ new_labels.append(cur_new_labels)
273
+
274
+ # Truncate sequences to max length as image embeddings can make the sequence longer
275
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
276
+ if tokenizer_model_max_length is not None:
277
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
278
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
279
+
280
+ # Combine them
281
+ max_len = max(x.shape[0] for x in new_input_embeds)
282
+ batch_size = len(new_input_embeds)
283
+
284
+ new_input_embeds_padded = []
285
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
286
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
287
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
288
+
289
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
290
+ cur_len = cur_new_embed.shape[0]
291
+ if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
292
+ new_input_embeds_padded.append(torch.cat((
293
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
294
+ cur_new_embed
295
+ ), dim=0))
296
+ if cur_len > 0:
297
+ new_labels_padded[i, -cur_len:] = cur_new_labels
298
+ attention_mask[i, -cur_len:] = True
299
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
300
+ else:
301
+ new_input_embeds_padded.append(torch.cat((
302
+ cur_new_embed,
303
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
304
+ ), dim=0))
305
+ if cur_len > 0:
306
+ new_labels_padded[i, :cur_len] = cur_new_labels
307
+ attention_mask[i, :cur_len] = True
308
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
309
+
310
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
311
+
312
+ if _labels is None:
313
+ new_labels = None
314
+ else:
315
+ new_labels = new_labels_padded
316
+
317
+ if _attention_mask is None:
318
+ attention_mask = None
319
+ else:
320
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
321
+
322
+ if _position_ids is None:
323
+ position_ids = None
324
+
325
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
326
+
327
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
328
+ if model_args.mm_use_im_patch_token:
329
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
330
+ self.resize_token_embeddings(len(tokenizer))
331
+
332
+ if model_args.mm_use_im_start_end:
333
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
334
+ self.resize_token_embeddings(len(tokenizer))
335
+
336
+ if num_new_tokens > 0:
337
+ input_embeddings = self.get_input_embeddings().weight.data
338
+ output_embeddings = self.get_output_embeddings().weight.data
339
+
340
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
341
+ dim=0, keepdim=True)
342
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
343
+ dim=0, keepdim=True)
344
+
345
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
346
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
347
+
348
+ if model_args.tune_mm_mlp_adapter:
349
+ for p in self.get_input_embeddings().parameters():
350
+ p.requires_grad = True
351
+ for p in self.get_output_embeddings().parameters():
352
+ p.requires_grad = False
353
+
354
+ if model_args.pretrain_mm_mlp_adapter:
355
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
356
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
357
+ assert num_new_tokens == 2
358
+ if input_embeddings.shape == embed_tokens_weight.shape:
359
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
360
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
361
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
362
+ else:
363
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
364
+ elif model_args.tune_embed_tokens:
365
+ for p in self.get_input_embeddings().parameters():
366
+ p.requires_grad = True
367
+ for p in self.get_output_embeddings().parameters():
368
+ p.requires_grad = False
369
+ print("Set input embeddings to trainable")
370
+
371
+ elif model_args.mm_use_im_patch_token:
372
+ if model_args.tune_mm_mlp_adapter:
373
+ for p in self.get_input_embeddings().parameters():
374
+ p.requires_grad = False
375
+ for p in self.get_output_embeddings().parameters():
376
+ p.requires_grad = False
377
+
378
+ if model_args.pretrain_mm_mlp_adapter:
379
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
380
+ if 'model.embed_tokens.weight' in mm_projector_weights.keys():
381
+ def get_w(weights, keyword):
382
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
383
+ self.get_model().embed_tokens.load_state_dict(get_w(mm_projector_weights, 'model.embed_tokens'))
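
prepare_inputs_labels_for_multimodal above does the heavy lifting of turning a token sequence containing image placeholders into a single embedding sequence. The sketch below isolates just that splice step with toy tensors; all sizes, the stand-in embeddings, and the sentinel value are illustrative (real code should use tinychart.constants.IMAGE_TOKEN_INDEX):

import torch

IMAGE_TOKEN_INDEX = -200          # illustrative sentinel; use tinychart.constants in real code
cur_input_ids = torch.tensor([1, 5, IMAGE_TOKEN_INDEX, 7, 8])
text_embeds = torch.randn(5, 4)   # stand-in for embed_tokens(...) output, hidden size 4
image_feats = torch.randn(3, 4)   # stand-in for one image's projected features (3 tokens)

# [-1, positions of image tokens, seq_len] -> text chunks are the open intervals between cuts
cuts = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
chunks = []
for i in range(len(cuts) - 1):
    chunks.append(text_embeds[cuts[i] + 1:cuts[i + 1]])   # text chunk i (placeholder position skipped)
    if i < len(cuts) - 2:
        chunks.append(image_feats)                        # one image feature block per placeholder

merged = torch.cat(chunks, dim=0)
print(merged.shape)  # torch.Size([7, 4]): 4 text tokens + 3 image tokens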
tinychart/model/model_factory.py ADDED
@@ -0,0 +1,64 @@
1
+ import os
2
+ import importlib
3
+
4
+ MODEL_REGISTRY = {}
5
+ TOKENIZER_REGISTRY = {}
6
+
7
+
8
+ def ModelSelect(model_name_or_path):
9
+ model = None
10
+ for name in MODEL_REGISTRY.keys():
11
+ if name in model_name_or_path.lower():
12
+ model = MODEL_REGISTRY[name]
13
+ if model is None:
14
+ model = MODEL_REGISTRY['llama']
15
+ return model
16
+
17
+
18
+ def TokenizerSelect(model_name_or_path):
19
+ tokenizer_init = None
20
+ for name in TOKENIZER_REGISTRY.keys():
21
+ if name in model_name_or_path.lower():
22
+ tokenizer_init = TOKENIZER_REGISTRY[name]
23
+ if tokenizer_init is None:
24
+ tokenizer_init = TOKENIZER_REGISTRY['llama']
25
+ return tokenizer_init
26
+
27
+
28
+ def register_model(name):
29
+ def register_model_cls(cls):
30
+ if name in MODEL_REGISTRY:
31
+ return MODEL_REGISTRY[name]
32
+
33
+ MODEL_REGISTRY[name] = cls
34
+ return cls
35
+
36
+ return register_model_cls
37
+
38
+
39
+ def register_tokenizer(name):
40
+ def register_tokenizer_cls(cls):
41
+ if name in TOKENIZER_REGISTRY:
42
+ return TOKENIZER_REGISTRY[name]
43
+
44
+ TOKENIZER_REGISTRY[name] = cls
45
+ return cls
46
+
47
+ return register_tokenizer_cls
48
+
49
+
50
+ def import_models(models_dir, namespace):
51
+ for file in os.listdir(models_dir):
52
+ path = os.path.join(models_dir, file)
53
+ if (
54
+ not file.startswith("_")
55
+ and not file.startswith(".")
56
+ and file.endswith(".py")
57
+ ):
58
+ model_name = file[: file.find(".py")] if file.endswith(".py") else file
59
+ importlib.import_module(namespace + "." + model_name)
60
+
61
+
62
+ # automatically import any Python files in the models/ directory
63
+ models_dir = os.path.join(os.path.dirname(__file__), 'language_model')
64
+ import_models(models_dir, "tinychart.model.language_model")
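
model_factory.py wires models and tokenizers up through name-keyed registries that are filled as the language_model modules are imported. The standalone snippet below (not repository code; DummyLlama and DummyPhi are placeholders) mirrors the decorator-plus-substring-lookup pattern so the control flow of ModelSelect is easier to follow:

MODEL_REGISTRY = {}

def register_model(name):
    # Same shape as the repository decorator: the first registration for a name wins.
    def register_model_cls(cls):
        MODEL_REGISTRY.setdefault(name, cls)
        return cls
    return register_model_cls

@register_model('llama')
class DummyLlama:            # stands in for the default/fallback entry
    pass

@register_model('phi')
class DummyPhi:              # stands in for the Phi-based TinyChart model class
    pass

def model_select(model_name_or_path):
    # Substring match against the lowercased checkpoint path, falling back to 'llama'.
    model = None
    for name, cls in MODEL_REGISTRY.items():
        if name in model_name_or_path.lower():
            model = cls
    return model if model is not None else MODEL_REGISTRY['llama']

print(model_select("checkpoints/TinyChart-Phi-2"))  # DummyPhi ('phi' is a substring)
print(model_select("some/unknown-model"))           # falls back to DummyLlama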
tinychart/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,7 @@
1
+ import os
2
+ from tinychart.model.multimodal_encoder.siglip_encoder import SigLipVisionTower
3
+
4
+ def build_vision_tower(vision_tower_cfg, **kwargs):
5
+ vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
6
+ is_absolute_path_exists = os.path.exists(vision_tower)
7
+ return SigLipVisionTower(vision_tower, vision_tower_cfg, **kwargs)
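
A usage note: build_vision_tower only reads mm_vision_tower (or vision_tower) off the config object it is given. The sketch below is illustrative, not from the repository; the hub id is a public SigLIP checkpoint used purely as an example (a real TinyChart checkpoint directory would point this at its own vision tower), and delay_load=True keeps it to a config download:

from types import SimpleNamespace
from tinychart.model.multimodal_encoder.builder import build_vision_tower

# Placeholder config object; only the mm_vision_tower attribute is consulted.
cfg = SimpleNamespace(mm_vision_tower="google/siglip-so400m-patch14-384")
tower = build_vision_tower(cfg, delay_load=True)   # config only, no weights are loaded yet
print(tower.hidden_size, tower.num_patches)        # e.g. 1152 and 729 for a 384px, 14px-patch tower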
tinychart/model/multimodal_encoder/merge.py ADDED
@@ -0,0 +1,239 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+
8
+ import math
9
+ from typing import Callable, Tuple
10
+
11
+ import torch
12
+
13
+
14
+ def do_nothing(x, mode=None):
15
+ return x
16
+
17
+
18
+ def bipartite_soft_matching(
19
+ metric: torch.Tensor,
20
+ r: int,
21
+ class_token: bool = False,
22
+ distill_token: bool = False,
23
+ ) -> Tuple[Callable, Callable]:
24
+ """
25
+ Applies ToMe with a balanced matching set (50%, 50%).
26
+
27
+ Input size is [batch, tokens, channels].
28
+ r indicates the number of tokens to remove (max 50% of tokens).
29
+
30
+ Extra args:
31
+ - class_token: Whether or not there's a class token.
32
+ - distill_token: Whether or not there's also a distillation token.
33
+
34
+ When enabled, the class token and distillation tokens won't get merged.
35
+ """
36
+ protected = 0
37
+ if class_token:
38
+ protected += 1
39
+ if distill_token:
40
+ protected += 1
41
+
42
+ # We can only reduce by a maximum of 50% tokens
43
+ t = metric.shape[1]
44
+ r = min(r, (t - protected) // 2)
45
+
46
+ if r <= 0:
47
+ return do_nothing, do_nothing
48
+
49
+ with torch.no_grad():
50
+ metric = metric / metric.norm(dim=-1, keepdim=True)
51
+ a, b = metric[..., ::2, :], metric[..., 1::2, :]
52
+ scores = a @ b.transpose(-1, -2)
53
+
54
+ if class_token:
55
+ scores[..., 0, :] = -math.inf
56
+ if distill_token:
57
+ scores[..., :, 0] = -math.inf
58
+
59
+ node_max, node_idx = scores.max(dim=-1)
60
+ edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
61
+
62
+ unm_idx = edge_idx[..., r:, :] # Unmerged Tokens
63
+ src_idx = edge_idx[..., :r, :] # Merged Tokens
64
+ dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)
65
+
66
+ if class_token:
67
+ # Sort to ensure the class token is at the start
68
+ unm_idx = unm_idx.sort(dim=1)[0]
69
+
70
+ def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
71
+ src, dst = x[..., ::2, :], x[..., 1::2, :]
72
+ n, t1, c = src.shape
73
+ unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
74
+ src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
75
+ dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
76
+
77
+ if distill_token:
78
+ return torch.cat([unm[:, :1], dst[:, :1], unm[:, 1:], dst[:, 1:]], dim=1)
79
+ else:
80
+ return torch.cat([unm, dst], dim=1)
81
+
82
+ def unmerge(x: torch.Tensor) -> torch.Tensor:
83
+ unm_len = unm_idx.shape[1]
84
+ unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
85
+ n, _, c = unm.shape
86
+
87
+ src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c))
88
+
89
+ out = torch.zeros(n, metric.shape[1], c, device=x.device, dtype=x.dtype)
90
+
91
+ out[..., 1::2, :] = dst
92
+ out.scatter_(dim=-2, index=(2 * unm_idx).expand(n, unm_len, c), src=unm)
93
+ out.scatter_(dim=-2, index=(2 * src_idx).expand(n, r, c), src=src)
94
+
95
+ return out
96
+
97
+ return merge, unmerge
98
+
99
+
100
+ def kth_bipartite_soft_matching(
101
+ metric: torch.Tensor, k: int
102
+ ) -> Tuple[Callable, Callable]:
103
+ """
104
+ Applies ToMe with the two sets as (every kth element, the rest).
105
+ If n is the number of tokens, the resulting number of tokens will be n // k.
106
+
107
+ Input size is [batch, tokens, channels].
108
+ k indicates the stride for the first set.
109
+ k = 2 is equivalent to regular bipartite_soft_matching with r = 0.5 * N
110
+ """
111
+ if k <= 1:
112
+ return do_nothing, do_nothing
113
+
114
+ def split(x):
115
+ t_rnd = (x.shape[1] // k) * k
116
+ x = x[:, :t_rnd, :].view(x.shape[0], -1, k, x.shape[2])
117
+ a, b = (
118
+ x[:, :, : (k - 1), :].contiguous().view(x.shape[0], -1, x.shape[-1]),
119
+ x[:, :, (k - 1), :],
120
+ )
121
+ return a, b
122
+
123
+ with torch.no_grad():
124
+ metric = metric / metric.norm(dim=-1, keepdim=True)
125
+ a, b = split(metric)
126
+ r = a.shape[1]
127
+ scores = a @ b.transpose(-1, -2)
128
+
129
+ _, dst_idx = scores.max(dim=-1)
130
+ dst_idx = dst_idx[..., None]
131
+
132
+ def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
133
+ src, dst = split(x)
134
+ n, _, c = src.shape
135
+ dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
136
+
137
+ return dst
138
+
139
+ def unmerge(x: torch.Tensor) -> torch.Tensor:
140
+ n, _, c = x.shape
141
+ dst = x
142
+
143
+ src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c)).to(x.dtype)
144
+
145
+ src = src.view(n, -1, (k - 1), c)
146
+ dst = dst.view(n, -1, 1, c)
147
+
148
+ out = torch.cat([src, dst], dim=-2)
149
+ out = out.contiguous().view(n, -1, c)
150
+
151
+ return out
152
+
153
+ return merge, unmerge
154
+
155
+
156
+ def random_bipartite_soft_matching(
157
+ metric: torch.Tensor, r: int
158
+ ) -> Tuple[Callable, Callable]:
159
+ """
160
+ Applies ToMe with the two sets as (r chosen randomly, the rest).
161
+ Input size is [batch, tokens, channels].
162
+
163
+ This will reduce the number of tokens by r.
164
+ """
165
+ if r <= 0:
166
+ return do_nothing, do_nothing
167
+
168
+ with torch.no_grad():
169
+ B, N, _ = metric.shape
170
+ rand_idx = torch.rand(B, N, 1, device=metric.device).argsort(dim=1)
171
+
172
+ a_idx = rand_idx[:, :r, :]
173
+ b_idx = rand_idx[:, r:, :]
174
+
175
+ def split(x):
176
+ C = x.shape[-1]
177
+ a = x.gather(dim=1, index=a_idx.expand(B, r, C))
178
+ b = x.gather(dim=1, index=b_idx.expand(B, N - r, C))
179
+ return a, b
180
+
181
+ metric = metric / metric.norm(dim=-1, keepdim=True)
182
+ a, b = split(metric)
183
+ scores = a @ b.transpose(-1, -2)
184
+
185
+ _, dst_idx = scores.max(dim=-1)
186
+ dst_idx = dst_idx[..., None]
187
+
188
+ def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
189
+ src, dst = split(x)
190
+ C = src.shape[-1]
191
+ dst = dst.scatter_reduce(-2, dst_idx.expand(B, r, C), src, reduce=mode)
192
+
193
+ return dst
194
+
195
+ def unmerge(x: torch.Tensor) -> torch.Tensor:
196
+ C = x.shape[-1]
197
+ dst = x
198
+ src = dst.gather(dim=-2, index=dst_idx.expand(B, r, C))
199
+
200
+ out = torch.zeros(B, N, C, device=x.device, dtype=x.dtype)
201
+
202
+ out.scatter_(dim=-2, index=a_idx.expand(B, r, C), src=src)
203
+ out.scatter_(dim=-2, index=b_idx.expand(B, N - r, C), src=dst)
204
+
205
+ return out
206
+
207
+ return merge, unmerge
208
+
209
+
210
+ def merge_wavg(
211
+ merge: Callable, x: torch.Tensor, size: torch.Tensor = None
212
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
213
+ """
214
+ Applies the merge function by taking a weighted average based on token size.
215
+ Returns the merged tensor and the new token sizes.
216
+ """
217
+ if size is None:
218
+ size = torch.ones_like(x[..., 0, None])
219
+
220
+ x = merge(x * size, mode="sum")
221
+ size = merge(size, mode="sum")
222
+
223
+ x = x / size
224
+ return x, size
225
+
226
+
227
+ def merge_source(
228
+ merge: Callable, x: torch.Tensor, source: torch.Tensor = None
229
+ ) -> torch.Tensor:
230
+ """
231
+ For source tracking. Source is an adjacency matrix between the initial tokens and final merged groups.
232
+ x is used to find out how many tokens there are in case the source is None.
233
+ """
234
+ if source is None:
235
+ n, t, _ = x.shape
236
+ source = torch.eye(t, device=x.device)[None, ...].expand(n, t, t)
237
+
238
+ source = merge(source, mode="amax")
239
+ return source
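
Taken together, the helpers above implement ToMe-style token merging: a similarity metric (the mean-pooled attention keys in SigLipAttentionToMe) feeds bipartite_soft_matching, and merge_wavg applies the resulting merge as a size-weighted average so repeated merging across layers stays unbiased. A shape-level sketch, assuming the tinychart package is importable (the batch size, token count, and r below are illustrative, not the released settings):

import torch
from tinychart.model.multimodal_encoder.merge import bipartite_soft_matching, merge_wavg

batch, tokens, dim = 2, 730, 64
x = torch.randn(batch, tokens, dim)        # token features entering a block
metric = torch.randn(batch, tokens, dim)   # stand-in for the attention keys used as the metric

merge, _unmerge = bipartite_soft_matching(metric, r=84, class_token=False, distill_token=False)
x_merged, size = merge_wavg(merge, x)      # size tracks how many original tokens each slot absorbed

print(x.shape, '->', x_merged.shape)       # torch.Size([2, 730, 64]) -> torch.Size([2, 646, 64])
print(size.sum(dim=1).squeeze(-1))         # tensor([730., 730.]): token mass is conserved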
tinychart/model/multimodal_encoder/siglip_encoder.py ADDED
@@ -0,0 +1,751 @@
1
+ '''
2
+ # Adapted from https://huggingface.co/MILVLG/imp-v1-3b/blob/main/vision_encoder.py
3
+ '''
4
+
5
+ from typing import Optional, Tuple, Union, Dict
6
+ from dataclasses import dataclass
7
+ from functools import partial, reduce
8
+ from PIL import Image
9
+ import torch
10
+ import torch.utils.checkpoint
11
+ from torch import nn
12
+ import torch.nn.functional as F
13
+ import os
14
+ import numpy as np
15
+ from transformers.image_processing_utils import BatchFeature, get_size_dict
16
+ from transformers.image_transforms import (convert_to_rgb, normalize, rescale, resize, to_channel_dimension_format, )
17
+ from transformers.image_utils import (ChannelDimension, PILImageResampling, to_numpy_array, )
18
+ from transformers.activations import ACT2FN
19
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
20
+ from transformers.modeling_utils import PreTrainedModel
21
+ from transformers import PretrainedConfig
22
+ from transformers.utils import ModelOutput, logging
23
+ from tinychart.model.multimodal_encoder.merge import bipartite_soft_matching, merge_source, merge_wavg
24
+ logger = logging.get_logger(__name__)  # defined here so SigLipVisionConfig.from_pretrained can warn
25
+
26
+ class SigLipImageProcessor:
27
+ def __init__(self,
28
+ image_mean=(0.5, 0.5, 0.5),
29
+ image_std=(0.5, 0.5, 0.5),
30
+ size=(384, 384),
31
+ crop_size: Dict[str, int] = None,
32
+ resample=PILImageResampling.BICUBIC,
33
+ rescale_factor=1 / 255,
34
+ data_format=ChannelDimension.FIRST):
35
+ crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
36
+ crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
37
+
38
+ self.image_mean = image_mean
39
+ self.image_std = image_std
40
+ self.size = size
41
+ self.resample = resample
42
+ self.rescale_factor = rescale_factor
43
+ self.data_format = data_format
44
+ self.crop_size = crop_size
45
+
46
+ def preprocess(self, images, return_tensors):
47
+ if isinstance(images, Image.Image):
48
+ images = [images]
49
+ else:
50
+ assert isinstance(images, list)
51
+
52
+ transforms = [
53
+ convert_to_rgb,
54
+ to_numpy_array,
55
+ partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
56
+ partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
57
+ partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
58
+ partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
59
+ ]
60
+
61
+ images = reduce(lambda x, f: [*map(f, x)], transforms, images)
62
+ data = {"pixel_values": images}
63
+
64
+ return BatchFeature(data=data, tensor_type=return_tensors)
65
+
66
+
67
+ class SigLipVisionConfig(PretrainedConfig):
68
+ model_type = "siglip_vision_model"
69
+
70
+ def __init__(
71
+ self,
72
+ hidden_size=1152,
73
+ image_mean=(0.5, 0.5, 0.5),
74
+ intermediate_size=4304,
75
+ num_hidden_layers=27,
76
+ num_attention_heads=16,
77
+ num_channels=3,
78
+ image_size=384,
79
+ patch_size=14,
80
+ hidden_act="gelu_pytorch_tanh",
81
+ layer_norm_eps=1e-6,
82
+ attention_dropout=0.0,
83
+ **kwargs,
84
+ ):
85
+ super().__init__(**kwargs)
86
+
87
+ self.hidden_size = hidden_size
88
+ self.intermediate_size = intermediate_size
89
+ self.num_hidden_layers = num_hidden_layers
90
+ self.num_attention_heads = num_attention_heads
91
+ self.num_channels = num_channels
92
+ self.patch_size = patch_size
93
+ self.image_size = image_size
94
+ self.attention_dropout = attention_dropout
95
+ self.layer_norm_eps = layer_norm_eps
96
+ self.hidden_act = hidden_act
97
+ self.image_mean = image_mean
98
+
99
+ for key, value in kwargs.items():
100
+ setattr(self, key, value)
101
+
102
+ @classmethod
103
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
104
+ cls._set_token_in_kwargs(kwargs)
105
+
106
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
107
+
108
+ # get the vision config dict if we are loading from SigLipConfig
109
+ if config_dict.get("model_type") == "siglip":
110
+ config_dict = config_dict["vision_config"]
111
+
112
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
113
+ logger.warning(
114
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
115
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
116
+ )
117
+
118
+ return cls.from_dict(config_dict, **kwargs)
119
+
120
+
121
+ @dataclass
122
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->SigLip
123
+ class SigLipVisionModelOutput(ModelOutput):
124
+ """
125
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
126
+
127
+ Args:
128
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
129
+ The image embeddings obtained by applying the projection layer to the pooler_output.
130
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
131
+ Sequence of hidden-states at the output of the last layer of the model.
132
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
133
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
134
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
135
+
136
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
137
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
138
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
139
+ sequence_length)`.
140
+
141
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
142
+ heads.
143
+ """
144
+
145
+ image_embeds: Optional[torch.FloatTensor] = None
146
+ last_hidden_state: torch.FloatTensor = None
147
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
148
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
149
+
150
+
151
+ class SigLipVisionEmbeddings(nn.Module):
152
+ def __init__(self, config: SigLipVisionConfig):
153
+ super().__init__()
154
+ self.config = config
155
+ self.embed_dim = config.hidden_size
156
+ self.image_size = config.image_size
157
+ self.patch_size = config.patch_size
158
+
159
+ self.patch_embedding = nn.Conv2d(
160
+ in_channels=config.num_channels,
161
+ out_channels=self.embed_dim,
162
+ kernel_size=self.patch_size,
163
+ stride=self.patch_size,
164
+ padding="valid",
165
+ )
166
+
167
+ self.num_patches = (self.image_size // self.patch_size) ** 2
168
+ self.num_positions = self.num_patches
169
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
170
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
171
+
172
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
173
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
174
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
175
+
176
+ embeddings = embeddings + self.position_embedding(self.position_ids)
177
+ return embeddings
178
+
179
+
180
+ class SigLipAttentionToMe(nn.Module):
181
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
182
+
183
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
184
+ def __init__(self, config):
185
+ super().__init__()
186
+ self.config = config
187
+ self.embed_dim = config.hidden_size
188
+ self.num_heads = config.num_attention_heads
189
+ self.head_dim = self.embed_dim // self.num_heads
190
+ if self.head_dim * self.num_heads != self.embed_dim:
191
+ raise ValueError(
192
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
193
+ f" {self.num_heads})."
194
+ )
195
+ self.scale = self.head_dim ** -0.5
196
+ self.dropout = config.attention_dropout
197
+
198
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
199
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
200
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
201
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
202
+
203
+ def forward(
204
+ self,
205
+ hidden_states: torch.Tensor,
206
+ attention_mask: Optional[torch.Tensor] = None,
207
+ output_attentions: Optional[bool] = False,
208
+ size: torch.Tensor = None,
209
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
210
+ """Input shape: Batch x Time x Channel"""
211
+
212
+ batch_size, q_len, _ = hidden_states.size()
213
+
214
+ query_states = self.q_proj(hidden_states)
215
+ key_states = self.k_proj(hidden_states)
216
+ value_states = self.v_proj(hidden_states)
217
+
218
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
219
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
220
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
221
+
222
+ k_v_seq_len = key_states.shape[-2]
223
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
224
+
225
+ if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
226
+ raise ValueError(
227
+ f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
228
+ f" {attn_weights.size()}"
229
+ )
230
+
231
+ if attention_mask is not None:
232
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
233
+ raise ValueError(
234
+ f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
235
+ )
236
+ attn_weights = attn_weights + attention_mask
237
+
238
+ # upcast attention to fp32
239
+ if size is not None:
240
+ attn_weights += size.log()[:, None, None, :, 0]
241
+
242
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
243
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
244
+ attn_output = torch.matmul(attn_weights, value_states)
245
+
246
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
247
+ raise ValueError(
248
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
249
+ f" {attn_output.size()}"
250
+ )
251
+
252
+ attn_output = attn_output.transpose(1, 2).contiguous()
253
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
254
+
255
+ attn_output = self.out_proj(attn_output)
256
+ return attn_output, attn_weights, key_states.mean(dim=1)
257
+
258
+
259
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->SigLip
260
+ class SigLipMLP(nn.Module):
261
+ def __init__(self, config):
262
+ super().__init__()
263
+ self.config = config
264
+ self.activation_fn = ACT2FN[config.hidden_act]
265
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
266
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
267
+
268
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
269
+ hidden_states = self.fc1(hidden_states)
270
+ hidden_states = self.activation_fn(hidden_states)
271
+ hidden_states = self.fc2(hidden_states)
272
+ return hidden_states
273
+
274
+
275
+ class SigLipEncoderLayerToMe(nn.Module):
276
+ def __init__(self, config: SigLipVisionConfig, layer_id=None):
277
+ super().__init__()
278
+ self.embed_dim = config.hidden_size
279
+ self.self_attn = SigLipAttentionToMe(config)
280
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
281
+ self.mlp = SigLipMLP(config)
282
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
283
+ self.r = config.tome_r
284
+ self.layer_id = layer_id
285
+
286
+ # Ignore copy
287
+ def forward(
288
+ self,
289
+ hidden_states: torch.Tensor,
290
+ attention_mask: torch.Tensor,
291
+ output_attentions: Optional[bool] = False,
292
+ attention_size=None,
293
+ source=None,
294
+ trace_source=False
295
+ ) -> Tuple[torch.FloatTensor]:
296
+ """
297
+ Args:
298
+ hidden_states (`torch.FloatTensor`):
299
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
300
+ attention_mask (`torch.FloatTensor`):
301
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
302
+ output_attentions (`bool`, *optional*, defaults to `False`):
303
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
304
+ returned tensors for more detail.
305
+ """
306
+ residual = hidden_states
307
+
308
+ hidden_states = self.layer_norm1(hidden_states)
309
+ hidden_states, attn_weights, metric = self.self_attn(
310
+ hidden_states=hidden_states,
311
+ attention_mask=attention_mask,
312
+ output_attentions=output_attentions,
313
+ size=attention_size
314
+ )
315
+ hidden_states = residual + hidden_states
316
+
317
+ if self.r > 0:
318
+ merge, unmerge = bipartite_soft_matching(
319
+ metric,
320
+ r=self.r,
321
+ class_token=False,
322
+ distill_token=False
323
+ )
324
+ if trace_source:
325
+ source = merge_source(merge, hidden_states, source)
326
+ hidden_states, attention_size = merge_wavg(merge, hidden_states, attention_size)
327
+ residual = hidden_states
328
+ hidden_states = self.layer_norm2(hidden_states)
329
+ hidden_states = self.mlp(hidden_states)
330
+ hidden_states = residual + hidden_states
331
+
332
+ outputs = (hidden_states,)
333
+
334
+ if output_attentions:
335
+ outputs += (attn_weights,)
336
+
337
+ if trace_source:
338
+ outputs += (source,)
339
+
340
+ outputs += (attention_size,)
341
+
342
+ return outputs
343
+
344
+
345
+ class SigLipPreTrainedModel(PreTrainedModel):
346
+ """
347
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
348
+ models.
349
+ """
350
+
351
+ config_class = SigLipVisionConfig
352
+ base_model_prefix = "siglip"
353
+ supports_gradient_checkpointing = True
354
+
355
+ def _init_weights(self, module):
356
+ """Initialize the weights"""
357
+ pass
358
+
359
+
360
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->SigLip
361
+ class SigLipEncoder(nn.Module):
362
+ """
363
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
364
+ [`SigLipEncoderLayer`].
365
+
366
+ Args:
367
+ config: SigLipVisionConfig
368
+ """
369
+
370
+ def __init__(self, config: SigLipVisionConfig):
371
+ super().__init__()
372
+ self.config = config
373
+ self.layers = nn.ModuleList([SigLipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
374
+ self.gradient_checkpointing = False
375
+
376
+ # Ignore copy
377
+ def forward(
378
+ self,
379
+ inputs_embeds,
380
+ attention_mask: Optional[torch.Tensor] = None,
381
+ output_attentions: Optional[bool] = None,
382
+ output_hidden_states: Optional[bool] = None,
383
+ return_dict: Optional[bool] = None,
384
+ ) -> Union[Tuple, BaseModelOutput]:
385
+ r"""
386
+ Args:
387
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
388
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
389
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
390
+ than the model's internal embedding lookup matrix.
391
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
392
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
393
+
394
+ - 1 for tokens that are **not masked**,
395
+ - 0 for tokens that are **masked**.
396
+
397
+ [What are attention masks?](../glossary#attention-mask)
398
+ output_attentions (`bool`, *optional*):
399
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
400
+ returned tensors for more detail.
401
+ output_hidden_states (`bool`, *optional*):
402
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
403
+ for more detail.
404
+ return_dict (`bool`, *optional*):
405
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
406
+ """
407
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
408
+ output_hidden_states = (
409
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
410
+ )
411
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
412
+
413
+ encoder_states = () if output_hidden_states else None
414
+ all_attentions = () if output_attentions else None
415
+
416
+ hidden_states = inputs_embeds
417
+ for encoder_layer in self.layers:
418
+ if output_hidden_states:
419
+ encoder_states = encoder_states + (hidden_states,)
420
+ if self.gradient_checkpointing and self.training:
421
+ layer_outputs = self._gradient_checkpointing_func(
422
+ encoder_layer.__call__,
423
+ hidden_states,
424
+ attention_mask,
425
+ output_attentions,
426
+ )
427
+ else:
428
+ layer_outputs = encoder_layer(
429
+ hidden_states,
430
+ attention_mask,
431
+ output_attentions=output_attentions,
432
+ )
433
+
434
+ hidden_states = layer_outputs[0]
435
+
436
+ if output_attentions:
437
+ all_attentions = all_attentions + (layer_outputs[1],)
438
+
439
+ if output_hidden_states:
440
+ encoder_states = encoder_states + (hidden_states,)
441
+
442
+ if not return_dict:
443
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
444
+ return BaseModelOutput(
445
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
446
+ )
447
+
448
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->SigLip
449
+ class SigLipEncoderToMe(nn.Module):
450
+ """
451
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
452
+ [`SigLipEncoderLayerToMe`].
453
+
454
+ Args:
455
+ config: SigLipVisionConfig
456
+ """
457
+
458
+ def __init__(self, config: SigLipVisionConfig):
459
+ super().__init__()
460
+ self.config = config
461
+ self.layers = nn.ModuleList([SigLipEncoderLayerToMe(config, layer_id=layer_id) for layer_id in range(config.num_hidden_layers)])
462
+ self.gradient_checkpointing = False
463
+ self.trace_source = getattr(config, 'trace_source', False)
464
+
465
+ # Ignore copy
466
+ def forward(
467
+ self,
468
+ inputs_embeds,
469
+ attention_mask: Optional[torch.Tensor] = None,
470
+ output_attentions: Optional[bool] = None,
471
+ output_hidden_states: Optional[bool] = None,
472
+ return_dict: Optional[bool] = None,
473
+ ) -> Union[Tuple, BaseModelOutput]:
474
+ r"""
475
+ Args:
476
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
477
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
478
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
479
+ than the model's internal embedding lookup matrix.
480
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
481
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
482
+
483
+ - 1 for tokens that are **not masked**,
484
+ - 0 for tokens that are **masked**.
485
+
486
+ [What are attention masks?](../glossary#attention-mask)
487
+ output_attentions (`bool`, *optional*):
488
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
489
+ returned tensors for more detail.
490
+ output_hidden_states (`bool`, *optional*):
491
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
492
+ for more detail.
493
+ return_dict (`bool`, *optional*):
494
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
495
+ """
496
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
497
+ output_hidden_states = (
498
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
499
+ )
500
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
501
+
502
+ encoder_states = () if output_hidden_states else None
503
+ all_attentions = () if output_attentions else None
504
+
505
+ hidden_states = inputs_embeds
506
+ attention_size = None
507
+ source = None
508
+ for encoder_layer in self.layers:
509
+ if output_hidden_states:
510
+ encoder_states = encoder_states + (hidden_states,)
511
+ if self.gradient_checkpointing and self.training:
512
+
513
+ layer_outputs = self._gradient_checkpointing_func(
514
+ encoder_layer.__call__,
515
+ hidden_states,
516
+ attention_mask,
517
+ output_attentions,
518
+ attention_size,
519
+ source if self.trace_source else None,
520
+ self.trace_source
521
+ )
522
+ else:
523
+ layer_outputs = encoder_layer(
524
+ hidden_states,
525
+ attention_mask,
526
+ output_attentions=output_attentions,
527
+ attention_size=attention_size,
528
+ source=source if self.trace_source else None,
529
+ trace_source=self.trace_source
530
+ )
531
+
532
+ hidden_states = layer_outputs[0]
533
+ if self.trace_source:
534
+ source = layer_outputs[-2]
535
+
536
+ attention_size = layer_outputs[-1]
537
+
538
+ if output_attentions:
539
+ all_attentions = all_attentions + (layer_outputs[1],)
540
+
541
+ if output_hidden_states:
542
+ encoder_states = encoder_states + (hidden_states,)
543
+
544
+ if not return_dict:
545
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
546
+ return BaseModelOutput(
547
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
548
+ )
549
+
550
+
551
+ class SigLipVisionTransformer(nn.Module):
552
+ def __init__(self, config: SigLipVisionConfig):
553
+ super().__init__()
554
+ self.config = config
555
+ embed_dim = config.hidden_size
556
+
557
+ self.embeddings = SigLipVisionEmbeddings(config)
558
+ self.encoder = SigLipEncoderToMe(config)
559
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
560
+ self.head = SigLipMultiheadAttentionPoolingHead(config)
561
+
562
+ def forward(
563
+ self,
564
+ pixel_values,
565
+ output_attentions: Optional[bool] = None,
566
+ output_hidden_states: Optional[bool] = None,
567
+ return_dict: Optional[bool] = None,
568
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
569
+ r"""
570
+ Returns:
571
+
572
+ """
573
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
574
+ output_hidden_states = (
575
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
576
+ )
577
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
578
+
579
+ hidden_states = self.embeddings(pixel_values)
580
+
581
+ encoder_outputs = self.encoder(
582
+ inputs_embeds=hidden_states,
583
+ output_attentions=output_attentions,
584
+ output_hidden_states=output_hidden_states,
585
+ return_dict=return_dict,
586
+ )
587
+
588
+ last_hidden_state = encoder_outputs[0]
589
+ last_hidden_state = self.post_layernorm(last_hidden_state)
590
+
591
+ pooled_output = self.head(last_hidden_state)
592
+
593
+ if not return_dict:
594
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
595
+
596
+ return BaseModelOutputWithPooling(
597
+ last_hidden_state=last_hidden_state,
598
+ pooler_output=pooled_output,
599
+ hidden_states=encoder_outputs.hidden_states,
600
+ attentions=encoder_outputs.attentions,
601
+ )
602
+
603
+
604
+ class SigLipMultiheadAttentionPoolingHead(nn.Module):
605
+ """Multihead Attention Pooling."""
606
+
607
+ def __init__(self, config: SigLipVisionConfig):
608
+ super().__init__()
609
+
610
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
611
+ self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
612
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
613
+ self.mlp = SigLipMLP(config)
614
+
615
+ def forward(self, hidden_state):
616
+ batch_size = hidden_state.shape[0]
617
+ probe = self.probe.repeat(batch_size, 1, 1)
618
+
619
+ hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
620
+
621
+ residual = hidden_state
622
+ hidden_state = self.layernorm(hidden_state)
623
+ hidden_state = residual + self.mlp(hidden_state)
624
+
625
+ return hidden_state[:, 0]
626
+
627
+
628
+ class SigLipVisionModel(SigLipPreTrainedModel):
629
+ config_class = SigLipVisionConfig
630
+ main_input_name = "pixel_values"
631
+ _no_split_modules = ["SigLipEncoderLayerToMe"]
632
+
633
+ def __init__(self, config: SigLipVisionConfig):
634
+ super().__init__(config)
635
+
636
+ self.vision_model = SigLipVisionTransformer(config)
637
+ del self.vision_model.encoder.layers[-1:]
638
+ self.vision_model.head = nn.Identity()
639
+ # Initialize weights and apply final processing
640
+ self.post_init()
641
+
642
+ def get_input_embeddings(self) -> nn.Module:
643
+ return self.vision_model.embeddings.patch_embedding
644
+
645
+ def forward(
646
+ self,
647
+ pixel_values,
648
+ output_attentions: Optional[bool] = None,
649
+ output_hidden_states: Optional[bool] = None,
650
+ return_dict: Optional[bool] = None,
651
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
652
+ r"""
653
+ Returns:
654
+
655
+ Examples:
656
+
657
+ ```python
658
+ >>> from PIL import Image
659
+ >>> import requests
660
+ >>> from transformers import AutoProcessor, SigLipVisionModel
661
+
662
+ >>> model = SigLipVisionModel.from_pretrained("google/siglip-base-patch16-224")
663
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
664
+
665
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
666
+ >>> image = Image.open(requests.get(url, stream=True).raw)
667
+
668
+ >>> inputs = processor(images=image, return_tensors="pt")
669
+
670
+ >>> outputs = model(**inputs)
671
+ >>> last_hidden_state = outputs.last_hidden_state
672
+ >>> pooled_output = outputs.pooler_output # pooled features
673
+ ```"""
674
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
675
+
676
+ return self.vision_model(
677
+ pixel_values=pixel_values,
678
+ output_attentions=output_attentions,
679
+ output_hidden_states=output_hidden_states,
680
+ return_dict=return_dict,
681
+ )
682
+
683
+ class SigLipVisionTower(nn.Module):
684
+ def __init__(self, vision_tower, vision_tower_cfg, delay_load=False):
685
+ super().__init__()
686
+
687
+ self.is_loaded = False
688
+
689
+ if vision_tower is not None:
690
+ self.config = SigLipVisionConfig.from_pretrained(vision_tower)
691
+ else:
692
+ self.config = SigLipVisionConfig()
693
+
694
+ self.vision_tower_name = vision_tower
695
+
696
+ self.image_processor = SigLipImageProcessor(size=(self.config.image_size, self.config.image_size), image_mean=self.config.image_mean)
697
+
698
+ if not delay_load:
699
+ self.load_model()
700
+ else:
701
+ self.cfg_only = self.config
702
+
703
+ def load_model(self):
704
+ if self.is_loaded:
705
+ return
706
+
707
+ self.vision_tower = SigLipVisionModel.from_pretrained(self.vision_tower_name)
708
+
709
+ self.vision_tower.requires_grad_(False)
710
+ self.vision_tower.eval()
711
+
712
+ self.is_loaded = True
713
+
714
+ # @torch.no_grad()
715
+ def forward(self, images):
716
+ if type(images) is list:
717
+ image_features = []
718
+ for image in images:
719
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
720
+ output_hidden_states=True)
721
+ image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
722
+
723
+ image_features.append(image_feature)
724
+ else:
725
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
726
+ output_hidden_states=True)
727
+ image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
728
+
729
+ return image_features
730
+
731
+ @property
732
+ def dummy_feature(self):
733
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
734
+
735
+ @property
736
+ def dtype(self):
737
+ for p in self.vision_tower.parameters():
738
+ return p.dtype
739
+
740
+ @property
741
+ def device(self):
742
+ for p in self.vision_tower.parameters():
743
+ return p.device
744
+
745
+ @property
746
+ def hidden_size(self):
747
+ return self.config.hidden_size
748
+
749
+ @property
750
+ def num_patches(self):
751
+ return (self.config.image_size // self.config.patch_size) ** 2
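For orientation, here is a minimal usage sketch of the `SigLipVisionTower` defined above. Everything below is an assumption-laden example, not part of the commit: it presumes network access and that the checkpoint named in the docstring earlier in this file loads cleanly through the custom `SigLipVisionConfig`/`SigLipVisionModel` classes.

```python
import torch
from tinychart.model.multimodal_encoder.siglip_encoder import SigLipVisionTower

# vision_tower_cfg is not used on the code path shown above, so None suffices here.
tower = SigLipVisionTower("google/siglip-base-patch16-224", vision_tower_cfg=None)

pixel_values = torch.randn(1, 3, tower.config.image_size, tower.config.image_size)
with torch.no_grad():
    feats = tower(pixel_values.to(device=tower.device, dtype=tower.dtype))

# Roughly (1, num_tokens, hidden_size); the token count can differ from
# tower.num_patches because the encoder stacks ToMe-style merging layers.
print(feats.shape, tower.num_patches, tower.hidden_size)
```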
tinychart/model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,215 @@
1
+ from typing import Union
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import re
6
+
7
+ from einops import rearrange, repeat
8
+
9
+
10
+ class IdentityMap(nn.Module):
11
+ def __init__(self):
12
+ super().__init__()
13
+
14
+ def forward(self, x, *args, **kwargs):
15
+ return x
16
+
17
+ @property
18
+ def config(self):
19
+ return {"mm_projector_type": 'identity'}
20
+
21
+
22
+ class SimpleResBlock(nn.Module):
23
+ def __init__(self, channels):
24
+ super().__init__()
25
+ self.pre_norm = nn.LayerNorm(channels)
26
+
27
+ self.proj = nn.Sequential(
28
+ nn.Linear(channels, channels),
29
+ nn.GELU(),
30
+ nn.Linear(channels, channels)
31
+ )
32
+ def forward(self, x):
33
+ x = self.pre_norm(x)
34
+ return x + self.proj(x)
35
+
36
+
37
+ class ResamplerBlock(nn.Module):
38
+ def __init__(
39
+ self,
40
+ hidden_size: int = 768,
41
+ image_hidden_size: int = 1024,
42
+ num_heads: int = 12,
43
+ intermediate_size: int = None
44
+ ):
45
+ super().__init__()
46
+ assert hidden_size % num_heads == 0, "For MHSA, hidden_size must be divisible by num_heads"
47
+ intermediate_size = hidden_size * 4 if intermediate_size is None else intermediate_size
48
+ # intermediate_size = hidden_size * 4
49
+ self.scale = 1 / math.sqrt(hidden_size // num_heads)
50
+ self.num_heads = num_heads
51
+ self.to_q = nn.Linear(hidden_size, hidden_size, bias=False)
52
+ self.to_k = nn.Linear(image_hidden_size, hidden_size, bias=False)
53
+ self.to_v = nn.Linear(image_hidden_size, hidden_size, bias=False)
54
+
55
+ self.to_out = nn.Linear(hidden_size, hidden_size, bias=False)
56
+
57
+ self.feed_forward = nn.Sequential(
58
+ *[
59
+ nn.LayerNorm(hidden_size),
60
+ nn.Linear(hidden_size, intermediate_size, bias=False),
61
+ nn.GELU(),
62
+ nn.Linear(intermediate_size, hidden_size, bias=False),
63
+ ]
64
+ )
65
+ # prenorm for image features
66
+ self.norm_image = nn.LayerNorm(image_hidden_size)
67
+ self.norm_hidden = nn.LayerNorm(hidden_size)
68
+
69
+ def forward(self, hidden_states: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
70
+ # prenorm
71
+ x = self.norm_image(x)
72
+ residual_hidden_states = hidden_states
73
+ hidden_states = self.norm_hidden(hidden_states)
74
+ # compute Q, K, V
75
+ queries = self.to_q(hidden_states)
76
+ keys = self.to_k(x)
77
+ values = self.to_v(x)
78
+ # rearrange them into multi-head format
79
+ queries = rearrange(queries, "b n (h d) -> b h n d", h=self.num_heads)
80
+ keys = rearrange(keys, "b n (h d) -> b h n d", h=self.num_heads)
81
+ values = rearrange(values, "b n (h d) -> b h n d", h=self.num_heads)
82
+ # rescale
83
+ queries = self.scale * queries
84
+ # compute QK^T
85
+ scores = torch.einsum("... i d, ... j d -> ... i j", queries, keys)
86
+ # for stability
87
+ scores = scores - scores.amax(dim=-1, keepdim=True).detach()
88
+ # softmax
89
+ attention_scores = scores.softmax(dim=-1) # b h i j (i: number of queries, j: number of keys)
90
+ # dot product with V
91
+ out = torch.einsum("... i j, ... j d -> ... i d", attention_scores, values)
92
+ out = rearrange(out, "b h n d -> b n (h d)", h=self.num_heads)
93
+ out = self.to_out(out) + residual_hidden_states
94
+ residual_out = out
95
+ out = self.feed_forward(out)
96
+ return out + residual_out
97
+
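The einsum chain in `ResamplerBlock.forward` above is ordinary scaled-dot-product cross-attention; subtracting the row-wise `amax` only stabilises the softmax and leaves its output unchanged. A small self-contained check against PyTorch's built-in kernel (assumes PyTorch >= 2.0 for `scaled_dot_product_attention`; shapes are illustrative):

```python
import torch
import torch.nn.functional as F

b, h, n_q, n_kv, d = 2, 12, 128, 729, 64
q = torch.randn(b, h, n_q, d)
k = torch.randn(b, h, n_kv, d)
v = torch.randn(b, h, n_kv, d)

# Manual path, mirroring ResamplerBlock: pre-scaled queries, QK^T,
# stabilised softmax, then the weighted sum of V.
scores = torch.einsum("... i d, ... j d -> ... i j", q / d ** 0.5, k)
scores = scores - scores.amax(dim=-1, keepdim=True).detach()
manual = torch.einsum("... i j, ... j d -> ... i d", scores.softmax(dim=-1), v)

# The built-in kernel applies the same 1/sqrt(d) scaling internally.
builtin = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(manual, builtin, atol=1e-5))  # True
```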
98
+
99
+ class Resampler(nn.Module):
100
+ def __init__(
101
+ self,
102
+ hidden_size: int = 768,
103
+ image_hidden_size: int = 1024,
104
+ final_hidden_size: int = 4096,
105
+ num_heads: int = 12,
106
+ intermediate_size: int = None,
107
+ num_queries: int = 128,
108
+ num_layers: int = 3,
109
+ initializer_range: float = 0.02
110
+ ):
111
+ super().__init__()
112
+ self.resampler_blocks = nn.ModuleList(
113
+ [
114
+ ResamplerBlock(
115
+ hidden_size, image_hidden_size, num_heads, intermediate_size
116
+ ) for _ in range(num_layers)
117
+ ]
118
+ )
119
+ self.queries = nn.Parameter(torch.randn(num_queries, hidden_size))
120
+ self.post_norm = nn.LayerNorm(hidden_size)
121
+
122
+ self.final_proj = nn.Linear(hidden_size, final_hidden_size, bias=False)
123
+
124
+ # self.initializer_range = initializer_range
125
+ # for module in self.modules():
126
+ # if isinstance(module, (nn.Linear, nn.LayerNorm, nn.Conv2d)):
127
+ # self._init_weights(module)
128
+ #
129
+ # def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
130
+ # """Initialize the weights"""
131
+ # if isinstance(module, (nn.Linear, nn.Conv2d)):
132
+ # # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
133
+ # # `trunc_normal_cpu` not implemented in `half` issues
134
+ # module.weight.data = nn.init.trunc_normal_(
135
+ # module.weight.data.to(torch.float32), mean=0.0, std=self.initializer_range
136
+ # ).to(module.weight.dtype)
137
+ # if module.bias is not None:
138
+ # module.bias.data.zero_()
139
+ # elif isinstance(module, nn.LayerNorm):
140
+ # module.bias.data.zero_()
141
+ # module.weight.data.fill_(1.0)
142
+
143
+ def forward(self, image_hidden_states: torch.Tensor) -> torch.Tensor:
144
+ b = image_hidden_states.size(0)
145
+ queries = repeat(self.queries, 'n d -> b n d', b=b)
146
+ for resampler_block in self.resampler_blocks:
147
+ queries = resampler_block(queries, image_hidden_states)
148
+
149
+ # post norm
150
+ queries = self.post_norm(queries)
151
+ return self.final_proj(queries)
152
+
153
+
154
+ def build_vision_projector(config, delay_load=False, **kwargs):
155
+ projector_type = getattr(config, 'mm_projector_type', 'linear')
156
+
157
+ if projector_type == 'linear':
158
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
159
+
160
+ if projector_type == 'resampler':
161
+ hidden_size = getattr(config, 'resampler_hidden_size', 768)
162
+ image_hidden_size = config.mm_hidden_size
163
+ num_queries = getattr(config, 'num_queries', 128)
164
+ final_hidden_size = config.hidden_size
165
+ num_heads = 12
166
+ if hidden_size == 512:
167
+ num_heads = 8
168
+ num_layers = getattr(config, 'num_resampler_layers', 3)
169
+
170
+ initializer_range = getattr(config, 'initializer_range', 0.02)
171
+ print(
172
+ f"resampler config: resampler hidden size: {hidden_size}, num_queries: {num_queries}, "
173
+ f"num_resampler_layers: {num_layers}"
174
+ )
175
+ return Resampler(
176
+ hidden_size=hidden_size,
177
+ image_hidden_size=image_hidden_size,
178
+ num_queries=num_queries,
179
+ final_hidden_size=final_hidden_size,
180
+ num_layers=num_layers,
181
+ num_heads=num_heads,
182
+ initializer_range=initializer_range
183
+ )
184
+
185
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
186
+ if mlp_gelu_match:
187
+ mlp_depth = int(mlp_gelu_match.group(1))
188
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
189
+ for _ in range(1, mlp_depth):
190
+ modules.append(nn.GELU())
191
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
192
+ mlp = nn.Sequential(*modules)
193
+ if getattr(config, 'load_moe_mm_projector', False):
194
+ from deepspeed.moe.layer import MoE
195
+ mlp = MoE(
196
+ config.mm_hidden_size,
197
+ expert=mlp,
198
+ num_experts=4,
199
+ ep_size=1,
200
+ k=2,
201
+ capacity_factor=1.,
202
+ eval_capacity_factor=1.,
203
+ min_capacity=4,
204
+ use_residual=False,
205
+ )
206
+
207
+ def moe_forward_wrapper(forward_func):
208
+ return lambda *args, **kwargs: forward_func(*args, **kwargs)[0]
209
+ mlp.forward = moe_forward_wrapper(mlp.forward)
210
+ return mlp
211
+
212
+ if projector_type == 'identity':
213
+ return IdentityMap()
214
+
215
+ raise ValueError(f'Unknown projector type: {projector_type}')
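To make the projector wiring concrete, here is a hypothetical config sketch exercising the `resampler` branch above (all field values are illustrative, not taken from the repository's shipped configs):

```python
import torch
from types import SimpleNamespace
from tinychart.model.multimodal_projector.builder import build_vision_projector

cfg = SimpleNamespace(
    mm_projector_type="resampler",
    mm_hidden_size=1152,        # width of the vision-tower features (illustrative)
    hidden_size=2560,           # width of the language model (illustrative)
    resampler_hidden_size=768,  # 768 -> 12 attention heads in the branch above
    num_queries=128,
    num_resampler_layers=3,
)
projector = build_vision_projector(cfg)

image_features = torch.randn(2, 729, cfg.mm_hidden_size)
# 729 image tokens are compressed into 128 query tokens at the LLM's width.
print(projector(image_features).shape)  # torch.Size([2, 128, 2560])
```

Swapping `mm_projector_type` to `"mlp2x_gelu"` instead returns the plain two-layer GELU MLP matched by the regex above, which keeps the token count unchanged.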
tinychart/utils.py ADDED
@@ -0,0 +1,134 @@
1
+ import datetime
2
+ import logging
3
+ import logging.handlers
4
+ import os
5
+ import sys
6
+
7
+ import requests
8
+
9
+ from tinychart.constants import LOGDIR
10
+
11
+ server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12
+ moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
13
+
14
+ handler = None
15
+
16
+
17
+ def build_logger(logger_name, logger_filename):
18
+ global handler
19
+
20
+ formatter = logging.Formatter(
21
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
22
+ datefmt="%Y-%m-%d %H:%M:%S",
23
+ )
24
+
25
+ # Set the format of root handlers
26
+ if not logging.getLogger().handlers:
27
+ logging.basicConfig(level=logging.INFO)
28
+ logging.getLogger().handlers[0].setFormatter(formatter)
29
+
30
+ # Redirect stdout and stderr to loggers
31
+ stdout_logger = logging.getLogger("stdout")
32
+ stdout_logger.setLevel(logging.INFO)
33
+ sl = StreamToLogger(stdout_logger, logging.INFO)
34
+ sys.stdout = sl
35
+
36
+ stderr_logger = logging.getLogger("stderr")
37
+ stderr_logger.setLevel(logging.ERROR)
38
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
39
+ sys.stderr = sl
40
+
41
+ # Get logger
42
+ logger = logging.getLogger(logger_name)
43
+ logger.setLevel(logging.INFO)
44
+
45
+ # Add a file handler for all loggers
46
+ if handler is None:
47
+ os.makedirs(LOGDIR, exist_ok=True)
48
+ filename = os.path.join(LOGDIR, logger_filename)
49
+ handler = logging.handlers.TimedRotatingFileHandler(
50
+ filename, when='D', utc=True, encoding='UTF-8')
51
+ handler.setFormatter(formatter)
52
+
53
+ for name, item in logging.root.manager.loggerDict.items():
54
+ if isinstance(item, logging.Logger):
55
+ item.addHandler(handler)
56
+
57
+ return logger
58
+
59
+
60
+ class StreamToLogger(object):
61
+ """
62
+ Fake file-like stream object that redirects writes to a logger instance.
63
+ """
64
+ def __init__(self, logger, log_level=logging.INFO):
65
+ self.terminal = sys.stdout
66
+ self.logger = logger
67
+ self.log_level = log_level
68
+ self.linebuf = ''
69
+
70
+ def __getattr__(self, attr):
71
+ return getattr(self.terminal, attr)
72
+
73
+ def write(self, buf):
74
+ temp_linebuf = self.linebuf + buf
75
+ self.linebuf = ''
76
+ for line in temp_linebuf.splitlines(True):
77
+ # From the io.TextIOWrapper LOGS:
78
+ # On output, if newline is None, any '\n' characters written
79
+ # are translated to the system default line separator.
80
+ # By default sys.stdout.write() expects '\n' newlines and then
81
+ # translates them so this is still cross platform.
82
+ if line[-1] == '\n':
83
+ self.logger.log(self.log_level, line.rstrip())
84
+ else:
85
+ self.linebuf += line
86
+
87
+ def flush(self):
88
+ if self.linebuf != '':
89
+ self.logger.log(self.log_level, self.linebuf.rstrip())
90
+ self.linebuf = ''
91
+
92
+
93
+ def disable_torch_init():
94
+ """
95
+ Disable the redundant torch default initialization to accelerate model creation.
96
+ """
97
+ import torch
98
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
99
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
100
+
101
+
102
+ def violates_moderation(text):
103
+ """
104
+ Check whether the text violates OpenAI moderation API.
105
+ """
106
+ url = "https://api.openai.com/v1/moderations"
107
+ headers = {"Content-Type": "application/json",
108
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
109
+ text = text.replace("\n", "")
110
+ data = {"input": text}
111
+ # the json= argument below handles serialization and escaping of the payload
112
+ try:
113
+ ret = requests.post(url, headers=headers, json=data, timeout=5)
114
+ flagged = ret.json()["results"][0]["flagged"]
115
+ except requests.exceptions.RequestException as e:
116
+ flagged = False
117
+ except KeyError as e:
118
+ flagged = False
119
+
120
+ return flagged
121
+
122
+
123
+ def pretty_print_semaphore(semaphore):
124
+ if semaphore is None:
125
+ return "None"
126
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
127
+
128
+
129
+
130
+ local_rank = None
131
+
132
+ def rank0_print(*args):
133
+ if os.environ.get("RANK", "0") == '0':
134
+ print(*args)
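
A short usage sketch of the helpers above (logger and file names are arbitrary; note that `build_logger` also redirects `sys.stdout`/`sys.stderr` through loggers, and that the file handler rotates daily under `LOGDIR`):

```python
from tinychart.utils import build_logger, disable_torch_init, rank0_print

disable_torch_init()                       # skip redundant default init before loading checkpoints
logger = build_logger("demo", "demo.log")  # hypothetical names; log file lands in LOGDIR/demo.log
logger.info("model loading started")

rank0_print("printed only on rank 0")      # consults the RANK environment variable
```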