init
- app.py +227 -0
- chat.py +215 -0
- omnilmm/__init__.py +0 -0
- omnilmm/constants.py +4 -0
- omnilmm/conversation.py +320 -0
- omnilmm/model/__init__.py +1 -0
- omnilmm/model/omnilmm.py +457 -0
- omnilmm/model/resampler.py +171 -0
- omnilmm/model/utils.py +555 -0
- omnilmm/train/train_utils.py +153 -0
- omnilmm/utils.py +127 -0
- requirements.txt +7 -0
app.py
ADDED
@@ -0,0 +1,227 @@
#!/usr/bin/env python
# encoding: utf-8
import spaces
import gradio as gr
from PIL import Image
import traceback
import re
import torch
import argparse
from transformers import AutoModel, AutoTokenizer
from chat import OmniLMM12B


# Load model
model_path = 'openbmb/RLAIF-V-12B'
model = OmniLMM12B(model_path)


ERROR_MSG = "Error, please retry"
model_name = 'RLAIF-V-12B'

form_radio = {
    'choices': ['Beam Search', 'Sampling'],
    #'value': 'Beam Search',
    'value': 'Sampling',
    'interactive': True,
    'label': 'Decode Type'
}
# Beam Form
num_beams_slider = {
    'minimum': 0,
    'maximum': 5,
    'value': 3,
    'step': 1,
    'interactive': True,
    'label': 'Num Beams'
}
repetition_penalty_slider = {
    'minimum': 0,
    'maximum': 3,
    'value': 1.2,
    'step': 0.01,
    'interactive': True,
    'label': 'Repetition Penalty'
}
repetition_penalty_slider2 = {
    'minimum': 0,
    'maximum': 3,
    'value': 1.05,
    'step': 0.01,
    'interactive': True,
    'label': 'Repetition Penalty'
}
max_new_tokens_slider = {
    'minimum': 1,
    'maximum': 4096,
    'value': 1024,
    'step': 1,
    'interactive': True,
    'label': 'Max New Tokens'
}

top_p_slider = {
    'minimum': 0,
    'maximum': 1,
    'value': 0.8,
    'step': 0.05,
    'interactive': True,
    'label': 'Top P'
}
top_k_slider = {
    'minimum': 0,
    'maximum': 200,
    'value': 100,
    'step': 1,
    'interactive': True,
    'label': 'Top K'
}
temperature_slider = {
    'minimum': 0,
    'maximum': 2,
    'value': 0.7,
    'step': 0.05,
    'interactive': True,
    'label': 'Temperature'
}


def create_component(params, comp='Slider'):
    if comp == 'Slider':
        return gr.Slider(
            minimum=params['minimum'],
            maximum=params['maximum'],
            value=params['value'],
            step=params['step'],
            interactive=params['interactive'],
            label=params['label']
        )
    elif comp == 'Radio':
        return gr.Radio(
            choices=params['choices'],
            value=params['value'],
            interactive=params['interactive'],
            label=params['label']
        )
    elif comp == 'Button':
        return gr.Button(
            value=params['value'],
            interactive=True
        )


@spaces.GPU(duration=120)
def chat(img, msgs, ctx, params=None, vision_hidden_states=None):
    if img is None:
        return -1, "Error, invalid image, please upload a new image", None, None
    try:
        image = img.convert('RGB')
        answer = model.chat(
            image=image,
            msgs=msgs,
        )
        return 0, answer, None, None
    except Exception as err:
        print(err)
        traceback.print_exc()
        return -1, ERROR_MSG, None, None


def upload_img(image, _chatbot, _app_session):
    image = Image.fromarray(image)

    _app_session['sts'] = None
    _app_session['ctx'] = []
    _app_session['img'] = image
    _chatbot.append(('', 'Image uploaded successfully, you can talk to me now'))
    return _chatbot, _app_session


def respond(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
    if _app_cfg.get('ctx', None) is None:
        _chat_bot.append((_question, 'Please upload an image to start'))
        return '', _chat_bot, _app_cfg

    _context = _app_cfg['ctx'].copy()
    if _context:
        _context.append({"role": "user", "content": _question})
    else:
        _context = [{"role": "user", "content": _question}]
    print('<User>:', _question)

    if params_form == 'Beam Search':
        params = {
            'sampling': False,
            'num_beams': num_beams,
            'repetition_penalty': repetition_penalty,
            "max_new_tokens": 896
        }
    else:
        params = {
            'sampling': True,
            'top_p': top_p,
            'top_k': top_k,
            'temperature': temperature,
            'repetition_penalty': repetition_penalty_2,
            "max_new_tokens": 896
        }
    code, _answer, _, sts = chat(_app_cfg['img'], _context, None, params)
    print('<Assistant>:', _answer)

    _context.append({"role": "assistant", "content": _answer})
    _chat_bot.append((_question, _answer))
    if code == 0:
        _app_cfg['ctx'] = _context
        _app_cfg['sts'] = sts
    return '', _chat_bot, _app_cfg


def regenerate_button_clicked(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
    if len(_chat_bot) <= 1:
        _chat_bot.append(('Regenerate', 'No question for regeneration.'))
        return '', _chat_bot, _app_cfg
    elif _chat_bot[-1][0] == 'Regenerate':
        return '', _chat_bot, _app_cfg
    else:
        _question = _chat_bot[-1][0]
        _chat_bot = _chat_bot[:-1]
        _app_cfg['ctx'] = _app_cfg['ctx'][:-2]
    return respond(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature)


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            params_form = create_component(form_radio, comp='Radio')
            with gr.Accordion("Beam Search") as beams_according:
                num_beams = create_component(num_beams_slider)
                repetition_penalty = create_component(repetition_penalty_slider)
            with gr.Accordion("Sampling") as sampling_according:
                top_p = create_component(top_p_slider)
                top_k = create_component(top_k_slider)
                temperature = create_component(temperature_slider)
                repetition_penalty_2 = create_component(repetition_penalty_slider2)
            regenerate = create_component({'value': 'Regenerate'}, comp='Button')
        with gr.Column(scale=3, min_width=500):
            app_session = gr.State({'sts': None, 'ctx': None, 'img': None})
            bt_pic = gr.Image(label="Upload an image to start")
            chat_bot = gr.Chatbot(label=f"Chat with {model_name}")
            txt_message = gr.Textbox(label="Input text")

            regenerate.click(
                regenerate_button_clicked,
                [txt_message, chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature],
                [txt_message, chat_bot, app_session]
            )
            txt_message.submit(
                respond,
                [txt_message, chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature],
                [txt_message, chat_bot, app_session]
            )
            bt_pic.upload(lambda: None, None, chat_bot, queue=False).then(upload_img, inputs=[bt_pic, chat_bot, app_session], outputs=[chat_bot, app_session])

# launch
#demo.launch(share=False, debug=True, show_api=False, server_port=8080, server_name="0.0.0.0")
demo.launch()
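The per-session state kept in `gr.State` is a plain dict: `upload_img` initialises it, and `respond` grows `ctx` as alternating user/assistant turns, which is the same message format `OmniLMM12B.chat` consumes. A small sketch of its shape (the message texts below are illustrative, not produced by the model):

# Illustrative shape of the session dict stored in app_session (gr.State);
# the message contents are made up for the example.
app_cfg = {
    'img': None,   # PIL.Image set by upload_img
    'sts': None,   # vision-hidden-states slot; chat() in this demo always returns None here
    'ctx': [
        {"role": "user", "content": "What is in this picture?"},
        {"role": "assistant", "content": "A world map with several regions highlighted."},
    ],
}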
chat.py
ADDED
@@ -0,0 +1,215 @@
import os
import torch
import json
from PIL import Image
import base64
import io
#from accelerate import load_checkpoint_and_dispatch, init_empty_weights
from transformers import AutoTokenizer, AutoModel

from omnilmm.utils import disable_torch_init
from omnilmm.model.omnilmm import OmniLMMForCausalLM
from omnilmm.model.utils import build_transform
from omnilmm.train.train_utils import omni_preprocess

DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


def init_omni_lmm(model_path):
    torch.backends.cuda.matmul.allow_tf32 = True
    disable_torch_init()
    model_name = os.path.expanduser(model_path)
    print(f'Load omni_lmm model and tokenizer from {model_name}')
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, model_max_length=4096)

    if False:
        # model on multiple devices for small size gpu memory (Nvidia 3090 24G x2)
        with init_empty_weights():
            model = OmniLMMForCausalLM.from_pretrained(model_name, tune_clip=True, torch_dtype=torch.bfloat16)
        model = load_checkpoint_and_dispatch(model, model_name, dtype=torch.bfloat16,
                    device_map="auto", no_split_module_classes=['Eva', 'MistralDecoderLayer', 'ModuleList', 'Resampler']
                )
    else:
        model = OmniLMMForCausalLM.from_pretrained(
            model_name, tune_clip=True, torch_dtype=torch.bfloat16
        ).to(device='cuda', dtype=torch.bfloat16)

    image_processor = build_transform(
        is_train=False, input_size=model.model.config.image_size, std_mode='OPENAI_CLIP')

    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    assert mm_use_im_start_end

    tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN,
                          DEFAULT_IM_END_TOKEN], special_tokens=True)

    vision_config = model.model.vision_config
    vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
        [DEFAULT_IMAGE_PATCH_TOKEN])[0]
    vision_config.use_im_start_end = mm_use_im_start_end
    vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids(
        [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
    image_token_len = model.model.config.num_query

    return model, image_processor, image_token_len, tokenizer


def expand_question_into_multimodal(question_text, image_token_len, im_st_token, im_ed_token, im_patch_token):
    if '<image>' in question_text[0]['content']:
        question_text[0]['content'] = question_text[0]['content'].replace(
            '<image>', im_st_token + im_patch_token * image_token_len + im_ed_token)
    else:
        question_text[0]['content'] = im_st_token + im_patch_token * \
            image_token_len + im_ed_token + '\n' + question_text[0]['content']
    return question_text


def wrap_question_for_omni_lmm(question, image_token_len, tokenizer):
    question = expand_question_into_multimodal(
        question, image_token_len, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN)

    conversation = question
    data_dict = omni_preprocess(sources=[conversation],
                                tokenizer=tokenizer,
                                generation=True)

    data_dict = dict(input_ids=data_dict["input_ids"][0],
                     labels=data_dict["labels"][0])
    return data_dict


class OmniLMM12B:
    def __init__(self, model_path) -> None:
        model, img_processor, image_token_len, tokenizer = init_omni_lmm(model_path)
        self.model = model
        self.image_token_len = image_token_len
        self.image_transform = img_processor
        self.tokenizer = tokenizer
        self.model.eval()

    def decode(self, image, input_ids):
        with torch.inference_mode():
            output = self.model.generate_vllm(
                input_ids=input_ids.unsqueeze(0).cuda(),
                images=image.unsqueeze(0).half().cuda(),
                temperature=0.6,
                max_new_tokens=1024,
                # num_beams=num_beams,
                do_sample=True,
                output_scores=True,
                return_dict_in_generate=True,
                repetition_penalty=1.1,
                top_k=30,
                top_p=0.9,
            )

            response = self.tokenizer.decode(
                output.sequences[0], skip_special_tokens=True)
            response = response.strip()
            return response

    def chat(self, image, msgs):
        #image = input['image']
        #msgs = json.loads(input['question'])
        input_ids = wrap_question_for_omni_lmm(
            msgs, self.image_token_len, self.tokenizer)['input_ids']
        input_ids = torch.as_tensor(input_ids)
        #print('input_ids', input_ids)
        image = self.image_transform(image)

        out = self.decode(image, input_ids)

        return out


def img2base64(file_name):
    with open(file_name, 'rb') as f:
        encoded_string = base64.b64encode(f.read())
        return encoded_string


class MiniCPMV:
    def __init__(self, model_path) -> None:
        self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.bfloat16)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model.eval().cuda()

    def chat(self, input):
        try:
            image = Image.open(io.BytesIO(base64.b64decode(input['image']))).convert('RGB')
        except Exception as e:
            return "Image decode error"

        msgs = json.loads(input['question'])

        answer, context, _ = self.model.chat(
            image=image,
            msgs=msgs,
            context=None,
            tokenizer=self.tokenizer,
            sampling=True,
            temperature=0.7
        )
        return answer


class MiniCPMV2_5:
    def __init__(self, model_path) -> None:
        self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.float16)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model.eval().cuda()

    def chat(self, input):
        try:
            image = Image.open(io.BytesIO(base64.b64decode(input['image']))).convert('RGB')
        except Exception as e:
            return "Image decode error"

        msgs = json.loads(input['question'])

        answer = self.model.chat(
            image=image,
            msgs=msgs,
            tokenizer=self.tokenizer,
            sampling=True,
            temperature=0.7
        )
        return answer


class MiniCPMVChat:
    def __init__(self, model_path) -> None:
        if '12B' in model_path:
            self.model = OmniLMM12B(model_path)
        elif 'MiniCPM-Llama3-V' in model_path:
            self.model = MiniCPMV2_5(model_path)
        else:
            self.model = MiniCPMV(model_path)

    def chat(self, input):
        return self.model.chat(input)


if __name__ == '__main__':

    model_path = 'openbmb/OmniLMM-12B'
    chat_model = MiniCPMVChat(model_path)

    im_64 = img2base64('./assets/worldmap_ck.jpg')

    # first round chat
    msgs = [{"role": "user", "content": "What is interesting about this image?"}]
    input = {"image": im_64, "question": json.dumps(msgs, ensure_ascii=True)}
    answer = chat_model.chat(input)
    print(msgs[-1]["content"]+'\n', answer)

    # second round chat
    msgs.append({"role": "assistant", "content": answer})
    msgs.append({"role": "user", "content": "Where is China in the image"})
    input = {"image": im_64, "question": json.dumps(msgs, ensure_ascii=True)}
    answer = chat_model.chat(input)
    print(msgs[-1]["content"]+'\n', answer)
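`OmniLMM12B` is the wrapper that app.py drives with a PIL image plus a list of role/content messages. A minimal sketch of calling it directly, assuming a CUDA GPU, the `openbmb/RLAIF-V-12B` checkpoint, and a placeholder image path:

# Sketch: driving OmniLMM12B directly, without the base64/JSON wrapper used by MiniCPMVChat.
# 'example.jpg' is a placeholder path; any RGB image works.
from PIL import Image
from chat import OmniLMM12B

model = OmniLMM12B('openbmb/RLAIF-V-12B')          # same checkpoint app.py loads
image = Image.open('example.jpg').convert('RGB')   # placeholder image
msgs = [{"role": "user", "content": "Describe this image."}]
print(model.chat(image=image, msgs=msgs))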
omnilmm/__init__.py
ADDED
File without changes
omnilmm/constants.py
ADDED
@@ -0,0 +1,4 @@
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "."
omnilmm/conversation.py
ADDED
@@ -0,0 +1,320 @@
import dataclasses
from enum import auto, Enum
from typing import List, Tuple


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    from PIL import Image
                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":
                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(
                                    pil_img.mode, (width, width), background_color)
                                result.paste(
                                    pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(
                                    pil_img.mode, (height, height), background_color)
                                result.paste(
                                    pil_img, ((height - width) // 2, 0))
                                return result
                        image = expand2square(image)
                    elif image_process_mode == "Crop":
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((224, 224))
                    else:
                        raise ValueError(
                            f"Invalid image_process_mode: {image_process_mode}")
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(
                        min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="JPEG")
                        img_b64_str = base64.b64encode(
                            buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    msg, image, image_process_mode = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(
                        min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    # image = image.resize((224, 224))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(
                        buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    msg = msg.replace('<image>', img_str)
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2)

    def dict(self):
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Give three tips for staying healthy."),
        ("Assistant",
            "Sure, here are three tips for staying healthy:\n"
            "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
            "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
            "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
            "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
            "activities at least two days per week.\n"
            "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
            "vegetables, whole grains, lean proteins, and healthy fats can help support "
            "your overall health. Try to limit your intake of processed and high-sugar foods, "
            "and aim to drink plenty of water throughout the day.\n"
            "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
            "and mental health. Adults should aim for seven to nine hours of sleep per night. "
            "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
            "help improve the quality of your sleep.")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_v1_2 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
        ("Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_vicuna_v1_1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_bair_v1 = Conversation(
    system="BEGINNING OF CONVERSATION:",
    roles=("USER", "GPT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

simple_conv = Conversation(
    system="You are LLaVA, a large language model trained by UW Madison WAIV Lab, based on LLaMA architecture."
           "You are designed to assist human with a variety of tasks using natural language."
           "Follow the instructions carefully.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Hi!"),
        ("Assistant", "Hi there! How can I help you today?\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

simple_conv_multimodal = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("Human", "Assistant"),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

simple_conv_legacy = Conversation(
    system="You are LLaVA, a large language model trained by UW Madison WAIV Lab."
           "You are designed to assist human with a variety of tasks using natural language."
           "Follow the instructions carefully.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Hi!\n\n### Response:"),
        ("Assistant", "Hi there! How can I help you today?\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_llava_v1 = Conversation(
    system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
           "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "Follow the instructions carefully and explain your answers in detail.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

default_conversation = conv_v1_2
conv_templates = {
    "default": conv_v1_2,
    "simple": simple_conv,
    "simple_legacy": simple_conv_legacy,
    "multimodal": simple_conv_multimodal,
    "llava_v1": conv_llava_v1,

    # fastchat
    "v1": conv_v1_2,
    "bair_v1": conv_bair_v1,
    "vicuna_v1_1": conv_vicuna_v1_1,
}


if __name__ == "__main__":
    print(default_conversation.get_prompt())
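These templates are used by copying one, appending turns, and rendering the prompt string; a short sketch (the question text is illustrative):

# Sketch of the intended template usage; the question is made up.
from omnilmm.conversation import conv_templates

conv = conv_templates["vicuna_v1_1"].copy()
conv.append_message(conv.roles[0], "What is shown in the image?")
conv.append_message(conv.roles[1], None)   # leave the assistant turn open
prompt = conv.get_prompt()
# -> "... USER: What is shown in the image? ASSISTANT:"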
omnilmm/model/__init__.py
ADDED
@@ -0,0 +1 @@
from .omnilmm import OmniLMMForCausalLM
omnilmm/model/omnilmm.py
ADDED
@@ -0,0 +1,457 @@
import gc
import math
import timm
import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from typing import List, Optional, Tuple, Union

from transformers import AutoConfig, AutoModelForCausalLM
from transformers import MistralForCausalLM, MistralModel, MistralConfig
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

from omnilmm.model.utils import build_transform
from omnilmm.model.resampler import Resampler

DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


class OmniLMMConfig(MistralConfig):
    model_type = "omnilmm"


class Identity(torch.nn.Identity):
    def forward(self, input: Tensor, **kwargs) -> Tensor:
        return super().forward(input)


def create_vision_module(config):
    vision_tower = timm.create_model('eva02_enormous_patch14_clip_224.laion2b_plus',
                                     pretrained=False,
                                     num_classes=0,
                                     dynamic_img_size=True,
                                     dynamic_img_pad=True)

    if isinstance(vision_tower, timm.models.VisionTransformer):
        if vision_tower.attn_pool is not None:
            vision_tower.attn_pool = Identity()

    # use 2nd last layer's output
    vision_tower.blocks[-1] = Identity()

    embed_dim = config.hidden_size
    resampler = Resampler(
        grid_size=int(math.sqrt(config.num_query)),
        embed_dim=embed_dim,
        num_heads=embed_dim // 128,
        kv_dim=vision_tower.embed_dim,
    )
    return vision_tower, resampler


class OmniLMMModel(MistralModel):
    config_class = OmniLMMConfig

    def __init__(self, config: OmniLMMConfig, mm_vision_tower=None, mm_hidden_size=None, tune_clip=True):
        super(OmniLMMModel, self).__init__(config)

        if hasattr(config, "mm_vision_tower"):
            vision_tower, resampler = create_vision_module(config)

            # print(__file__, 'skip loading vision tower weights')

            # HACK: for FSDP
            self.vision_tower = [vision_tower]
            self.resampler = resampler
            if tune_clip:
                self.vision_tower = self.vision_tower[0]

        self.vision_config = lambda x: None

    def initialize_vision_modules(self, vision_tower, no_randaug, num_query, image_size, tune_clip=False):
        self.config.mm_vision_tower = vision_tower
        self.config.use_mm_proj = True
        self.config.num_query = num_query
        self.config.image_size = image_size

        if not hasattr(self, 'vision_tower'):
            vision_tower, resampler = create_vision_module(self.config)
            state_dict = torch.load(
                '/tt/data/public/multimodal/multimodal_model_ckpts/timm/eva02_enormous_patch14_clip_224.laion2b_plus.pt')
            vision_tower.load_state_dict(state_dict, strict=False)
            del state_dict
            gc.collect()
        else:
            if isinstance(self.vision_tower, list):
                vision_tower = self.vision_tower[0]
            else:
                vision_tower = self.vision_tower
            resampler = self.resampler
        self.vision_tower = vision_tower if tune_clip else [vision_tower]
        self.resampler = resampler

        train_img_transform = build_transform(
            is_train=True, randaug=not no_randaug, input_size=self.config.image_size, std_mode='OPENAI_CLIP')
        eval_img_transform = build_transform(
            is_train=False, input_size=self.config.image_size, std_mode='OPENAI_CLIP')

        return dict(
            image_processor=(train_img_transform, eval_img_transform),
            image_token_len=num_query,
            vision_config=self.vision_config
        )

    def get_vision_embedding(self, pixel_values):
        if isinstance(self.vision_tower, list):
            vision_tower = self.vision_tower[0]  # HACK: for FSDP
        else:
            vision_tower = self.vision_tower

        dtype = vision_tower.pos_embed.data.dtype
        vision_embedding = vision_tower.forward_features(
            pixel_values.type(dtype))
        if hasattr(vision_tower, 'num_prefix_tokens') and vision_tower.num_prefix_tokens > 0:
            vision_embedding = vision_embedding[:,
                                                vision_tower.num_prefix_tokens:]
        res = self.resampler(vision_embedding)
        return res

    def get_vllm_embedding(self, data):

        if 'vision_hidden_states' not in data:
            pixel_values_list = data['pixel_values']
            vision_hidden_states = []
            for pixel_values in pixel_values_list:
                if len(pixel_values) > 0:
                    vision_hidden_states.append(self.get_vision_embedding(pixel_values.unsqueeze(0))[0])
                else:
                    vision_hidden_states.append([])
        else:
            vision_hidden_states = data['vision_hidden_states']

        #vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
        inputs_embeds = self.embed_tokens(data['input_ids'])
        vision_hidden_states = [i.type(inputs_embeds.dtype)
                                if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
                                ]

        # HACK: replace back original embeddings for LLaVA pretraining
        orig_embeds_params = getattr(self, 'orig_embeds_params', None)

        new_input_embeds = []
        cur_image_idx = 0
        for cur_input_ids, cur_input_embeds in zip(data['input_ids'], inputs_embeds):
            if (cur_input_ids == self.vision_config.im_patch_token).sum() == 0:
                # multimodal LLM, but the current sample is not multimodal
                cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
                new_input_embeds.append(cur_input_embeds)
                continue

            if self.vision_config.use_im_start_end:
                cur_image_features = vision_hidden_states[cur_image_idx]
                num_patches = cur_image_features.shape[0]
                if (cur_input_ids == self.vision_config.im_start_token).sum() != (cur_input_ids == self.vision_config.im_end_token).sum():
                    raise ValueError(
                        "The number of image start tokens and image end tokens should be the same.")
                image_start_tokens = torch.where(
                    cur_input_ids == self.vision_config.im_start_token)[0]
                for image_start_token_pos in image_start_tokens:
                    cur_image_features = vision_hidden_states[cur_image_idx].to(
                        device=cur_input_embeds.device)
                    num_patches = cur_image_features.shape[0]
                    if cur_input_ids[image_start_token_pos + num_patches + 1] != self.vision_config.im_end_token:
                        raise ValueError(
                            "The image end token should follow the image start token.")
                    if orig_embeds_params is not None:
                        cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features,
                                                          cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
                    else:
                        cur_new_input_embeds = torch.cat(
                            (cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
                    cur_image_idx += 1
                new_input_embeds.append(cur_new_input_embeds)
            else:
                raise NotImplementedError
        inputs_embeds = torch.stack(new_input_embeds, dim=0)

        return inputs_embeds, vision_hidden_states

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, BaseModelOutputWithPast]:

        # HACK: replace back original embeddings for LLaVA pretraining
        orig_embeds_params = getattr(self, 'orig_embeds_params', None)

        if inputs_embeds is None and past_key_values is None:
            inputs_embeds = self.embed_tokens(input_ids)

        vision_tower = getattr(self, 'vision_tower', None)
        if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:

            if type(images) is list:
                image_features = []
                for image in images:
                    image_forward_out = self.get_vision_embedding(image.unsqueeze(0))[
                        0]
                    image_features.append(image_forward_out)
            else:
                image_features = self.get_vision_embedding(images)

            dummy_image_features = torch.zeros(
                self.config.num_query,
                self.config.hidden_size,
                device=inputs_embeds.device,
                dtype=inputs_embeds.dtype)

            new_input_embeds = []
            cur_image_idx = 0
            for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
                if (cur_input_ids == self.vision_config.im_patch_token).sum() == 0:
                    # multimodal LLM, but the current sample is not multimodal
                    cur_input_embeds = cur_input_embeds + \
                        (0. * dummy_image_features).sum()
                    new_input_embeds.append(cur_input_embeds)
                    continue

                if self.vision_config.use_im_start_end:
                    cur_image_features = image_features[cur_image_idx]
                    num_patches = cur_image_features.shape[0]
                    if (cur_input_ids == self.vision_config.im_start_token).sum() != (cur_input_ids == self.vision_config.im_end_token).sum():
                        raise ValueError(
                            "The number of image start tokens and image end tokens should be the same.")
                    image_start_tokens = torch.where(
                        cur_input_ids == self.vision_config.im_start_token)[0]
                    for image_start_token_pos in image_start_tokens:
                        cur_image_features = image_features[cur_image_idx].to(
                            device=cur_input_embeds.device)
                        num_patches = cur_image_features.shape[0]
                        if cur_input_ids[image_start_token_pos + num_patches + 1] != self.vision_config.im_end_token:
                            raise ValueError(
                                "The image end token should follow the image start token.")
                        if orig_embeds_params is not None:
                            cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features,
                                                              cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
                        else:
                            cur_new_input_embeds = torch.cat(
                                (cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
                        cur_image_idx += 1
                    new_input_embeds.append(cur_new_input_embeds)
                else:
                    raise NotImplementedError
            inputs_embeds = torch.stack(new_input_embeds, dim=0)
            input_ids = None

        return super(OmniLMMModel, self).forward(
            input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )


class OmniLMMForCausalLM(MistralForCausalLM):
    config_class = OmniLMMConfig

    def __init__(self, config, mm_vision_tower=None, tune_clip=True):
        super(MistralForCausalLM, self).__init__(config)
        self.model = OmniLMMModel(
            config, mm_vision_tower=mm_vision_tower, tune_clip=tune_clip)

        self.lm_head = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # print(f'@@@ At forward, labels: {labels.shape}-{labels}', flush=True)
        # print(f'@@@ At forward, input_ids: {input_ids.shape}-{input_ids}', flush=True)
        # print(f'@@@ At forward, input_ids: {attention_mask.shape}-{attention_mask}', flush=True)

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            images=images,
            **kwargs
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model/pipeline parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    # TODO could be removed for generate_vllm()
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs

    def generate_vllm(
        self,
        input_ids: torch.LongTensor = None,
        images: Optional[torch.FloatTensor] = None,
        vision_hidden_states=None,
        return_vision_hidden_states=False,
        **kwargs
    ):
        model_inputs = {'input_ids': input_ids}
        if vision_hidden_states is None:
            model_inputs['pixel_values'] = images
        else:
            model_inputs['vision_hidden_states'] = vision_hidden_states

        with torch.inference_mode():
            inputs_embeds, vision_hidden_states = self.model.get_vllm_embedding(model_inputs)

            result = self.generate(
                inputs_embeds=inputs_embeds,
                **kwargs
            )

            if return_vision_hidden_states:
                return result, vision_hidden_states

            return result

    def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, device,
                                    tune_mm_mlp_adapter=False):
        self.model.vision_config.use_im_start_end = mm_use_im_start_end
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        self.resize_token_embeddings(len(tokenizer))

        if mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens(
                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))
            self.model.vision_config.im_start_token, self.model.vision_config.im_end_token = tokenizer.convert_tokens_to_ids(
                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

        # for new sft data
        num_new_tokens = tokenizer.add_tokens(
            ['<box>', '</box>', '<ref>', '</ref>', '<quad>', '</quad>'], special_tokens=True)
        self.resize_token_embeddings(len(tokenizer))

        if num_new_tokens > 0:
            input_embeddings = self.get_input_embeddings().weight.data
            output_embeddings = self.get_output_embeddings().weight.data

            input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                dim=0, keepdim=True)
            output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                dim=0, keepdim=True)

            input_embeddings[-num_new_tokens:] = input_embeddings_avg
            output_embeddings[-num_new_tokens:] = output_embeddings_avg

        if tune_mm_mlp_adapter:
            self.model.orig_embeds_params = [
                self.get_input_embeddings().weight.data.clone().to(device=device)]
            for p in self.get_input_embeddings().parameters():
                p.requires_grad = True
            for p in self.get_output_embeddings().parameters():
                p.requires_grad = False

        self.model.vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
            [DEFAULT_IMAGE_PATCH_TOKEN])[0]
        print(f'Tokenizer: {tokenizer}\n patch_token_id: {self.model.vision_config.im_patch_token}, vision_config: {self.model.vision_config}', flush=True)
        # exit()


AutoConfig.register("omnilmm", OmniLMMConfig)
AutoModelForCausalLM.register(OmniLMMConfig, OmniLMMForCausalLM)
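The two `register` calls at the end hook `OmniLMMConfig` and `OmniLMMForCausalLM` into the transformers Auto registry, so a checkpoint whose config declares `model_type: "omnilmm"` can also be loaded through the Auto classes once this module has been imported. A minimal sketch, assuming the `openbmb/RLAIF-V-12B` weights are available:

# Sketch: loading through the Auto registry set up above. Importing this module
# runs the register() calls, after which "omnilmm" resolves to OmniLMMForCausalLM.
import torch
from transformers import AutoModelForCausalLM
import omnilmm.model.omnilmm  # noqa: F401  (side effect: registers the model type)

model = AutoModelForCausalLM.from_pretrained(
    'openbmb/RLAIF-V-12B', torch_dtype=torch.bfloat16)
print(type(model).__name__)  # OmniLMMForCausalLM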
omnilmm/model/resampler.py
ADDED
@@ -0,0 +1,171 @@
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict
import math
import requests
from io import BytesIO
from functools import partial
from PIL import Image
from typing import Callable, Optional, Sequence, Tuple, List, Union
import numpy as np

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import trunc_normal_
from torchvision import transforms
from torchvision.transforms import InterpolationMode


def get_abs_pos(abs_pos, tgt_size):
    # abs_pos: L, C
    # tgt_size: M
    # return: M, C
    src_size = int(math.sqrt(abs_pos.size(0)))
    tgt_size = int(math.sqrt(tgt_size))
    dtype = abs_pos.dtype

    if src_size != tgt_size:
        return F.interpolate(
            abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
            size=(tgt_size, tgt_size),
            mode="bicubic",
            align_corners=False,
        ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
    else:
        return abs_pos


# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate(
            [np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(
        embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(
        embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


class Resampler(nn.Module):
    """
    A 2D perceiver-resampler network with one cross-attention layer between
    (grid_size**2) learnable queries and the input features, using a 2D sincos pos_embed.
    Outputs:
        A tensor with the shape of (grid_size**2, embed_dim)
    """

    def __init__(
        self,
        grid_size,
        embed_dim,
        num_heads,
        kv_dim=None,
        norm_layer=partial(nn.LayerNorm, eps=1e-6)
    ):
        super().__init__()
        self.num_queries = grid_size ** 2
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.pos_embed = nn.Parameter(
            torch.from_numpy(get_2d_sincos_pos_embed(
                embed_dim, grid_size)).float()
        ).requires_grad_(False)

        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
        trunc_normal_(self.query, std=.02)

        if kv_dim is not None and kv_dim != embed_dim:
            self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
        else:
            self.kv_proj = nn.Identity()

        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.ln_q = norm_layer(embed_dim)
        self.ln_kv = norm_layer(embed_dim)

        self.ln_post = norm_layer(embed_dim)
        self.proj = nn.Parameter(
            (embed_dim ** -0.5) * torch.randn(embed_dim, embed_dim))

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, attn_mask=None):

        pos_embed = get_abs_pos(self.pos_embed, x.size(1))

        x = self.kv_proj(x)
        x = self.ln_kv(x).permute(1, 0, 2)

        N = x.shape[1]
        q = self.ln_q(self.query)
        # print((self._repeat(q, N) + self.pos_embed.unsqueeze(1)).dtype, (x + pos_embed.unsqueeze(1)).dtype, x.dtype)
        out = self.attn(
            self._repeat(q, N) + self.pos_embed.unsqueeze(1),
            x + pos_embed.unsqueeze(1),
            x,
            attn_mask=attn_mask)[0]
        x = out.permute(1, 0, 2)

        x = self.ln_post(x)
        x = x @ self.proj
        return x

    def _repeat(self, query, N: int):
        return query.unsqueeze(1).repeat(1, N, 1)
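Editor's note: a hedged usage sketch of the Resampler above, for orientation only. It compresses a variable-length sequence of vision features into a fixed set of grid_size**2 query outputs; the dimensions below are invented for the example and do not come from the model config.

import torch

# Illustrative dimensions only; the real model wires this module up internally.
resampler = Resampler(grid_size=8, embed_dim=512, num_heads=8, kv_dim=1024)

vision_feats = torch.randn(2, 256, 1024)   # (batch, num_patches, kv_dim)
queries = resampler(vision_feats)
print(queries.shape)                       # torch.Size([2, 64, 512]), i.e. grid_size**2 queries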
omnilmm/model/utils.py
ADDED
@@ -0,0 +1,555 @@
from torchvision import transforms
from timm.data.transforms import RandomResizedCropAndInterpolation
from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from transformers import AutoConfig
from PIL import Image
from io import BytesIO
import torch.distributed as dist
import numpy as np
import pickle
import base64
import cv2
import os
import torch
from transformers import AutoConfig, StoppingCriteria

try:
    from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
except ImportError:
    OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
    OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)


def auto_upgrade(config):
    cfg = AutoConfig.from_pretrained(config)
    if 'llava' in config and cfg.model_type != 'llava':
        print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
        print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
        confirm = input(
            "Please confirm that you want to upgrade the checkpoint. [Y/N]")
        if confirm.lower() in ["y", "yes"]:
            print("Upgrading checkpoint...")
            assert len(cfg.architectures) == 1
            setattr(cfg.__class__, "model_type", "llava")
            cfg.architectures[0] = 'LlavaLlamaForCausalLM'
            cfg.save_pretrained(config)
            print("Checkpoint upgraded.")
        else:
            print("Checkpoint upgrade aborted.")
            exit(1)


class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.tokenizer = tokenizer
        self.start_len = None
        self.input_ids = input_ids

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if self.start_len is None:
            self.start_len = self.input_ids.shape[1]
        else:
            outputs = self.tokenizer.batch_decode(
                output_ids[:, self.start_len:], skip_special_tokens=True)[0]
            for keyword in self.keywords:
                if keyword in outputs:
                    return True
        return False


def auto_upgrade(config):
    cfg = AutoConfig.from_pretrained(config)
    if 'llava' in config and cfg.model_type != 'llava':
        print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
        print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
        confirm = input(
            "Please confirm that you want to upgrade the checkpoint. [Y/N]")
        if confirm.lower() in ["y", "yes"]:
            print("Upgrading checkpoint...")
            assert len(cfg.architectures) == 1
            setattr(cfg.__class__, "model_type", "llava")
            cfg.architectures[0] = 'LlavaLlamaForCausalLM'
            cfg.save_pretrained(config)
            print("Checkpoint upgraded.")
        else:
            print("Checkpoint upgrade aborted.")
            exit(1)

# aug functions


def identity_func(img):
    return img


def autocontrast_func(img, cutoff=0):
    '''
        same output as PIL.ImageOps.autocontrast
    '''
    n_bins = 256

    def tune_channel(ch):
        n = ch.size
        cut = cutoff * n // 100
        if cut == 0:
            high, low = ch.max(), ch.min()
        else:
            hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
            low = np.argwhere(np.cumsum(hist) > cut)
            low = 0 if low.shape[0] == 0 else low[0]
            high = np.argwhere(np.cumsum(hist[::-1]) > cut)
            high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
        if high <= low:
            table = np.arange(n_bins)
        else:
            scale = (n_bins - 1) / (high - low)
            table = np.arange(n_bins) * scale - low * scale
            table[table < 0] = 0
            table[table > n_bins - 1] = n_bins - 1
        table = table.clip(0, 255).astype(np.uint8)
        return table[ch]

    channels = [tune_channel(ch) for ch in cv2.split(img)]
    out = cv2.merge(channels)
    return out


def equalize_func(img):
    '''
        same output as PIL.ImageOps.equalize
        PIL's implementation is different from cv2.equalize
    '''
    n_bins = 256

    def tune_channel(ch):
        hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
        non_zero_hist = hist[hist != 0].reshape(-1)
        step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
        if step == 0:
            return ch
        n = np.empty_like(hist)
        n[0] = step // 2
        n[1:] = hist[:-1]
        table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
        return table[ch]

    channels = [tune_channel(ch) for ch in cv2.split(img)]
    out = cv2.merge(channels)
    return out


def rotate_func(img, degree, fill=(0, 0, 0)):
    '''
    like PIL, rotate by degree, not radians
    '''
    H, W = img.shape[0], img.shape[1]
    center = W / 2, H / 2
    M = cv2.getRotationMatrix2D(center, degree, 1)
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
    return out


def solarize_func(img, thresh=128):
    '''
        same output as PIL.ImageOps.solarize
    '''
    table = np.array([el if el < thresh else 255 - el for el in range(256)])
    table = table.clip(0, 255).astype(np.uint8)
    out = table[img]
    return out


def color_func(img, factor):
    '''
        same output as PIL.ImageEnhance.Color
    '''
    # implementation according to PIL definition, quite slow
    # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
    # out = blend(degenerate, img, factor)
    # M = (
    #     np.eye(3) * factor
    #     + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
    # )[np.newaxis, np.newaxis, :]
    M = (
        np.float32([
            [0.886, -0.114, -0.114],
            [-0.587, 0.413, -0.587],
            [-0.299, -0.299, 0.701]]) * factor
        + np.float32([[0.114], [0.587], [0.299]])
    )
    out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
    return out


def contrast_func(img, factor):
    """
        same output as PIL.ImageEnhance.Contrast
    """
    mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
    table = np.array([(
        el - mean) * factor + mean
        for el in range(256)
    ]).clip(0, 255).astype(np.uint8)
    out = table[img]
    return out


def brightness_func(img, factor):
    '''
        same output as PIL.ImageEnhance.Brightness
    '''
    table = (np.arange(256, dtype=np.float32) *
             factor).clip(0, 255).astype(np.uint8)
    out = table[img]
    return out


def sharpness_func(img, factor):
    '''
    The differences between this result and PIL are all on the 4 boundaries; the center
    areas are the same
    '''
    kernel = np.ones((3, 3), dtype=np.float32)
    kernel[1][1] = 5
    kernel /= 13
    degenerate = cv2.filter2D(img, -1, kernel)
    if factor == 0.0:
        out = degenerate
    elif factor == 1.0:
        out = img
    else:
        out = img.astype(np.float32)
        degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
        out[1:-1, 1:-1, :] = degenerate + factor * \
            (out[1:-1, 1:-1, :] - degenerate)
        out = out.astype(np.uint8)
    return out


def shear_x_func(img, factor, fill=(0, 0, 0)):
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, factor, 0], [0, 1, 0]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
                         flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def translate_x_func(img, offset, fill=(0, 0, 0)):
    '''
        same output as PIL.Image.transform
    '''
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, 0, -offset], [0, 1, 0]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
                         flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def translate_y_func(img, offset, fill=(0, 0, 0)):
    '''
        same output as PIL.Image.transform
    '''
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, 0, 0], [0, 1, -offset]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
                         flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def posterize_func(img, bits):
    '''
        same output as PIL.ImageOps.posterize
    '''
    out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
    return out


def shear_y_func(img, factor, fill=(0, 0, 0)):
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, 0, 0], [factor, 1, 0]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
                         flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def cutout_func(img, pad_size, replace=(0, 0, 0)):
    replace = np.array(replace, dtype=np.uint8)
    H, W = img.shape[0], img.shape[1]
    rh, rw = np.random.random(2)
    pad_size = pad_size // 2
    ch, cw = int(rh * H), int(rw * W)
    x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
    y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
    out = img.copy()
    out[x1:x2, y1:y2, :] = replace
    return out


# level to args
def enhance_level_to_args(MAX_LEVEL):
    def level_to_args(level):
        return ((level / MAX_LEVEL) * 1.8 + 0.1,)
    return level_to_args


def shear_level_to_args(MAX_LEVEL, replace_value):
    def level_to_args(level):
        level = (level / MAX_LEVEL) * 0.3
        if np.random.random() > 0.5:
            level = -level
        return (level, replace_value)

    return level_to_args


def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
    def level_to_args(level):
        level = (level / MAX_LEVEL) * float(translate_const)
        if np.random.random() > 0.5:
            level = -level
        return (level, replace_value)

    return level_to_args


def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
    def level_to_args(level):
        level = int((level / MAX_LEVEL) * cutout_const)
        return (level, replace_value)

    return level_to_args


def solarize_level_to_args(MAX_LEVEL):
    def level_to_args(level):
        level = int((level / MAX_LEVEL) * 256)
        return (level, )
    return level_to_args


def none_level_to_args(level):
    return ()


def posterize_level_to_args(MAX_LEVEL):
    def level_to_args(level):
        level = int((level / MAX_LEVEL) * 4)
        return (level, )
    return level_to_args


def rotate_level_to_args(MAX_LEVEL, replace_value):
    def level_to_args(level):
        level = (level / MAX_LEVEL) * 30
        if np.random.random() < 0.5:
            level = -level
        return (level, replace_value)

    return level_to_args


func_dict = {
    'Identity': identity_func,
    'AutoContrast': autocontrast_func,
    'Equalize': equalize_func,
    'Rotate': rotate_func,
    'Solarize': solarize_func,
    'Color': color_func,
    'Contrast': contrast_func,
    'Brightness': brightness_func,
    'Sharpness': sharpness_func,
    'ShearX': shear_x_func,
    'TranslateX': translate_x_func,
    'TranslateY': translate_y_func,
    'Posterize': posterize_func,
    'ShearY': shear_y_func,
}

translate_const = 10
MAX_LEVEL = 10
replace_value = (128, 128, 128)
arg_dict = {
    'Identity': none_level_to_args,
    'AutoContrast': none_level_to_args,
    'Equalize': none_level_to_args,
    'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
    'Solarize': solarize_level_to_args(MAX_LEVEL),
    'Color': enhance_level_to_args(MAX_LEVEL),
    'Contrast': enhance_level_to_args(MAX_LEVEL),
    'Brightness': enhance_level_to_args(MAX_LEVEL),
    'Sharpness': enhance_level_to_args(MAX_LEVEL),
    'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
    'TranslateX': translate_level_to_args(
        translate_const, MAX_LEVEL, replace_value
    ),
    'TranslateY': translate_level_to_args(
        translate_const, MAX_LEVEL, replace_value
    ),
    'Posterize': posterize_level_to_args(MAX_LEVEL),
    'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
}


class RandomAugment(object):

    def __init__(self, N=2, M=10, isPIL=False, augs=[]):
        self.N = N
        self.M = M
        self.isPIL = isPIL
        if augs:
            self.augs = augs
        else:
            self.augs = list(arg_dict.keys())

    def get_random_ops(self):
        sampled_ops = np.random.choice(self.augs, self.N)
        return [(op, 0.5, self.M) for op in sampled_ops]

    def __call__(self, img):
        if self.isPIL:
            img = np.array(img)
        ops = self.get_random_ops()
        for name, prob, level in ops:
            if np.random.random() > prob:
                continue
            args = arg_dict[name](level)
            img = func_dict[name](img, *args)
        return img


def build_transform(is_train, randaug=True, input_size=224, interpolation='bicubic', std_mode='IMAGENET_INCEPTION'):
    if std_mode == 'IMAGENET_INCEPTION':
        mean = IMAGENET_INCEPTION_MEAN
        std = IMAGENET_INCEPTION_STD
    elif std_mode == 'OPENAI_CLIP':
        mean = OPENAI_CLIP_MEAN
        std = OPENAI_CLIP_STD
    else:
        raise NotImplementedError

    if is_train:
        crop_scale = float(os.environ.get('TRAIN_CROP_SCALE', 0.9999))
        t = [
            RandomResizedCropAndInterpolation(
                input_size, scale=(crop_scale, 1.0), interpolation='bicubic'),
            # transforms.RandomHorizontalFlip(),
        ]
        if randaug and os.environ.get('TRAIN_DO_AUG', 'False') == 'True':
            print(f'@@@@@ Do random aug during training', flush=True)
            t.append(
                RandomAugment(
                    2, 7, isPIL=True,
                    augs=[
                        'Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
                        'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate',
                    ]))
        else:
            print(f'@@@@@ Skip random aug during training', flush=True)
        t += [
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
        t = transforms.Compose(t)
    else:
        t = transforms.Compose([
            transforms.Resize((input_size, input_size),
                              interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ])

    return t


def img2b64(img_path):
    img = Image.open(img_path)  # path to file
    img_buffer = BytesIO()
    img.save(img_buffer, format=img.format)
    byte_data = img_buffer.getvalue()
    base64_str = base64.b64encode(byte_data)  # bytes
    base64_str = base64_str.decode("utf-8")  # str
    return base64_str


def str2b64(str):
    return base64.b64encode(str.encode('utf-8')).decode('utf-8')


def b642str(b64):
    return base64.b64decode(b64).decode('utf-8')


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain Tensor size of each rank
    local_size = torch.LongTensor([tensor.numel()]).to("cuda")
    size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


def mean(lst):
    return sum(lst) / len(lst)


def stop_gradient_by_name(name: str):
    def apply_fn(module):
        if hasattr(module, name):
            getattr(module, name).requires_grad_(False)

    return apply_fn
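Editor's note: a short usage sketch for the preprocessing utilities above, assuming a 224x224 CLIP-style normalization is wanted; the image path is a placeholder, not a file shipped with the Space.

from PIL import Image

# Evaluation-time transform: resize, ToTensor, normalize (no random augmentation).
transform = build_transform(is_train=False, input_size=224, std_mode='OPENAI_CLIP')

img = Image.open('example.jpg').convert('RGB')   # placeholder path
pixel_values = transform(img)                    # tensor of shape (3, 224, 224)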
omnilmm/train/train_utils.py
ADDED
@@ -0,0 +1,153 @@
import os
import gc
import copy
import time

import torch
import warnings
import transformers

import numpy as np

from typing import Dict, Optional, Sequence
from omnilmm import conversation as conversation_lib

IGNORE_INDEX = -100
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


def _tokenize_fn(strings: Sequence[str],
                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ) for text in strings
    ]
    input_ids = labels = [
        tokenized.input_ids[0] for tokenized in tokenized_list
    ]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
        for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def omni_preprocess(sources,
                    tokenizer: transformers.PreTrainedTokenizer,
                    generation=False):
    system_content = 'You are an artificial intelligence assistant, which gives helpful, detailed, and polite answers to the human\'s questions.'
    ignore_index = -100

    response_template = '\n<|assistant|>\n'
    instruction_template = '\n<|user|>\n'
    response_token_ids = tokenizer.encode(
        response_template, add_special_tokens=False)
    instruction_token_ids = tokenizer.encode(
        instruction_template, add_special_tokens=False)

    batch_input_ids = []
    batch_labels = []
    for i in range(len(sources)):
        new_source = []
        prev_role = 'unexpect'
        for conv_turn in sources[i]:
            role = conv_turn['from'] if 'from' in conv_turn else conv_turn['role']
            content = conv_turn['value'] if 'value' in conv_turn else conv_turn['content']

            role = 'user' if role == 'human' else role
            role = 'assistant' if role == 'gpt' else role

            assert role in ['user', 'assistant']
            assert role != prev_role, f'role={role}, prev_role={prev_role}'
            prev_role = role

            new_turn = {
                'role': role,
                'content': content
            }
            new_source.append(new_turn)
        if new_source[0]['role'] != 'system':
            new_source.insert(0, {'role': 'system', 'content': system_content})

        # TODO: this automatically adds '\n' to the end
        res_text = tokenizer.apply_chat_template(
            new_source, tokenize=False, add_generation_prompt=generation)
        if not generation:
            res_text = res_text.strip()

        conversations_tokenized = _tokenize_fn([res_text], tokenizer)
        res_input_ids = conversations_tokenized["input_ids"][0]

        # labels and input_ids reference the same object, so deep-copy before masking
        res_labels = copy.deepcopy(conversations_tokenized["labels"][0])

        response_token_ids_idxs = []
        human_token_ids_idxs = []

        for assistant_idx in np.where(res_labels == response_token_ids[0])[0]:
            # find the indexes of the start of a response.
            if (response_token_ids == res_labels[assistant_idx: assistant_idx + len(
                    response_token_ids)].tolist()
                ):
                response_token_ids_idxs.append(
                    assistant_idx + len(response_token_ids))

        if len(response_token_ids_idxs) == 0:
            warnings.warn(
                f"Could not find response key `{response_template}` in the "
                f'following instance: @===>{tokenizer.decode(res_input_ids)}<===@ '
                f'Raw text is @===>{res_text}<===@'
                f'Raw source is @===>{new_source}<===@'
                f"This instance will be ignored in loss calculation. "
                f"Note, if this happens often, consider increasing the `max_seq_length`."
            )
            res_labels[:] = ignore_index

        human_token_ids = instruction_token_ids
        for human_idx in np.where(res_labels == human_token_ids[0])[0]:
            # find the indexes of the start of a human (user) turn.
            if human_token_ids == res_labels[human_idx: human_idx + len(human_token_ids)].tolist():
                human_token_ids_idxs.append(human_idx)

        if len(human_token_ids_idxs) == 0:
            warnings.warn(
                f"Could not find instruction key `{instruction_template}` in the "
                f'following instance: @===>{tokenizer.decode(res_input_ids)}<===@ '
                f'Raw text is @===>{res_text}<===@'
                f'Raw source is @===>{new_source}<===@'
                f"This instance will be ignored in loss calculation. "
                f"Note, if this happens often, consider increasing the `max_seq_length`."
            )
            res_labels[:] = ignore_index

        for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
            # Make pytorch loss function ignore all non response tokens
            if idx != 0:
                res_labels[start:end] = ignore_index
            else:
                res_labels[:end] = ignore_index

        if len(response_token_ids_idxs) < len(human_token_ids_idxs):
            res_labels[human_token_ids_idxs[-1]:] = ignore_index

        batch_input_ids.append(res_input_ids)
        batch_labels.append(res_labels)

    return dict(input_ids=batch_input_ids, labels=batch_labels)
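Editor's note: a hedged sketch of how omni_preprocess above expects its input: a batch of chat-style conversations whose assistant turns are the only tokens kept in the labels. It assumes the tokenizer loaded from the Space's model path ships a chat template that emits the `\n<|user|>\n` / `\n<|assistant|>\n` markers the function searches for; otherwise every label is masked out. Treat it as an illustration, not a guaranteed recipe.

from transformers import AutoTokenizer

# Assumption: this tokenizer's chat template uses the <|user|>/<|assistant|> markers.
tokenizer = AutoTokenizer.from_pretrained('openbmb/RLAIF-V-12B')

sources = [[
    {'from': 'human', 'value': '<image>\nWhat is in the picture?'},
    {'from': 'gpt', 'value': 'A dog playing in the snow.'},
]]

batch = omni_preprocess(sources, tokenizer, generation=False)
# batch['input_ids'][0]: full prompt + answer tokens
# batch['labels'][0]:    -100 everywhere except the assistant answer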
omnilmm/utils.py
ADDED
@@ -0,0 +1,127 @@
import datetime
import logging
import logging.handlers
import os
import sys

import requests

from omnilmm.constants import LOGDIR

server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."

handler = None


def build_logger(logger_name, logger_filename):
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Set the format of root handlers
    if not logging.getLogger().handlers:
        logging.basicConfig(level=logging.INFO)
    logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers
    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.INFO)
    sl = StreamToLogger(stdout_logger, logging.INFO)
    sys.stdout = sl

    stderr_logger = logging.getLogger("stderr")
    stderr_logger.setLevel(logging.ERROR)
    sl = StreamToLogger(stderr_logger, logging.ERROR)
    sys.stderr = sl

    # Get logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # Add a file handler for all loggers
    if handler is None:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when='D', utc=True)
        handler.setFormatter(formatter)

        for name, item in logging.root.manager.loggerDict.items():
            if isinstance(item, logging.Logger):
                item.addHandler(handler)

    return logger


class StreamToLogger(object):
    """
    Fake file-like stream object that redirects writes to a logger instance.
    """

    def __init__(self, logger, log_level=logging.INFO):
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ''

    def __getattr__(self, attr):
        return getattr(self.terminal, attr)

    def write(self, buf):
        temp_linebuf = self.linebuf + buf
        self.linebuf = ''
        for line in temp_linebuf.splitlines(True):
            # From the io.TextIOWrapper docs:
            #   On output, if newline is None, any '\n' characters written
            #   are translated to the system default line separator.
            # By default sys.stdout.write() expects '\n' newlines and then
            # translates them so this is still cross platform.
            if line[-1] == '\n':
                self.logger.log(self.log_level, line.rstrip())
            else:
                self.linebuf += line

    def flush(self):
        if self.linebuf != '':
            self.logger.log(self.log_level, self.linebuf.rstrip())
        self.linebuf = ''


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def violates_moderation(text):
    """
    Check whether the text violates OpenAI moderation API.
    """
    url = "https://api.openai.com/v1/moderations"
    headers = {"Content-Type": "application/json",
               "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
    text = text.replace("\n", "")
    data = "{" + '"input": ' + f'"{text}"' + "}"
    data = data.encode("utf-8")
    try:
        ret = requests.post(url, headers=headers, data=data, timeout=5)
        flagged = ret.json()["results"][0]["flagged"]
    except requests.exceptions.RequestException as e:
        flagged = False
    except KeyError as e:
        flagged = False

    return flagged


def pretty_print_semaphore(semaphore):
    if semaphore is None:
        return "None"
    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
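Editor's note: a brief usage sketch for the logging helpers above; the logger name and file name are arbitrary examples, not values used by the Space.

# Writes to stdout and to a daily-rotating file under LOGDIR,
# and redirects print()/stderr output into the same loggers.
logger = build_logger('gradio_web_server', 'gradio_web_server.log')
logger.info('Server started')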
requirements.txt
ADDED
@@ -0,0 +1,7 @@
Pillow==10.1.0
torch==2.1.2
torchvision==0.16.2
transformers==4.40.0
sentencepiece==0.1.99
opencv-python
gradio