RLAIF-V-12B / chat.py
finalf0's picture
init
5491537
raw
history blame
No virus
7.76 kB
import os
import torch
import json
from PIL import Image
import base64
import io
#from accelerate import load_checkpoint_and_dispatch, init_empty_weights
from transformers import AutoTokenizer, AutoModel
from omnilmm.utils import disable_torch_init
from omnilmm.model.omnilmm import OmniLMMForCausalLM
from omnilmm.model.utils import build_transform
from omnilmm.train.train_utils import omni_preprocess
# Placeholder text for the image slot in a prompt. It has the same text as the
# '<image>' literal that expand_question_into_multimodal searches for in the
# first message.
DEFAULT_IMAGE_TOKEN = "<image>"
# Token that is repeated `image_token_len` times between the start/end markers
# when the image placeholder is expanded.
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
# Markers placed immediately before and after the run of patch tokens.
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
def init_omni_lmm(model_path):
    """Load the OmniLMM model, tokenizer and image transform for inference.

    The model is loaded in bfloat16 and moved to the ``cuda`` device. The
    image special tokens are registered with the tokenizer and their ids are
    written onto ``model.model.vision_config``.

    Args:
        model_path: Checkpoint location passed to ``from_pretrained``; a
            leading ``~`` is expanded.

    Returns:
        A tuple ``(model, image_processor, image_token_len, tokenizer)`` where
        ``image_token_len`` is ``model.model.config.num_query``.

    Raises:
        AssertionError: If the model config does not enable
            ``mm_use_im_start_end``.
    """
    torch.backends.cuda.matmul.allow_tf32 = True
    disable_torch_init()
    model_name = os.path.expanduser(model_path)
    print(f'Load omni_lmm model and tokenizer from {model_name}')
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, model_max_length=4096)

    # NOTE: the former `if False:` multi-GPU branch was unreachable and called
    # accelerate helpers whose import is commented out, so it was dropped.
    model = OmniLMMForCausalLM.from_pretrained(
        model_name, tune_clip=True, torch_dtype=torch.bfloat16
    ).to(device='cuda', dtype=torch.bfloat16)

    image_processor = build_transform(
        is_train=False,
        input_size=model.model.config.image_size,
        std_mode='OPENAI_CLIP',
    )

    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    if not mm_use_im_start_end:
        # Raised explicitly (instead of via `assert`) so the check survives
        # `python -O`; the exception type is kept for existing callers.
        raise AssertionError(
            "model config must enable mm_use_im_start_end")

    tokenizer.add_tokens(
        [DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN,
         DEFAULT_IM_END_TOKEN],
        special_tokens=True,
    )

    # Publish the ids of the image special tokens on the vision config.
    vision_config = model.model.vision_config
    vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
        [DEFAULT_IMAGE_PATCH_TOKEN])[0]
    vision_config.use_im_start_end = mm_use_im_start_end
    vision_config.im_start_token, vision_config.im_end_token = (
        tokenizer.convert_tokens_to_ids(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]))

    image_token_len = model.model.config.num_query
    return model, image_processor, image_token_len, tokenizer
def expand_question_into_multimodal(question_text, image_token_len, im_st_token, im_ed_token, im_patch_token):
    """Insert the expanded image token span into the first message.

    If the first message's content contains ``'<image>'``, every occurrence is
    replaced by ``im_st_token + im_patch_token * image_token_len +
    im_ed_token``. Otherwise that span, followed by a newline, is prepended to
    the content.

    The caller's messages are left untouched: the first message is copied
    before it is edited. Mutating it in place meant that re-sending the same
    conversation (e.g. a later chat round) prepended a second image span,
    because the placeholder was already gone.

    Args:
        question_text: Sequence of message dicts; the first must have a
            ``'content'`` string.
        image_token_len: Number of patch tokens to emit.
        im_st_token: Text placed before the patch tokens.
        im_ed_token: Text placed after the patch tokens.
        im_patch_token: Text repeated ``image_token_len`` times.

    Returns:
        A new list whose first element is the edited copy of the first
        message and whose remaining elements are the original messages.

    Raises:
        IndexError: If ``question_text`` is empty.
        KeyError: If the first message has no ``'content'`` key.
    """
    image_span = im_st_token + im_patch_token * image_token_len + im_ed_token

    first_message = dict(question_text[0])
    content = first_message['content']
    if '<image>' in content:
        first_message['content'] = content.replace('<image>', image_span)
    else:
        first_message['content'] = image_span + '\n' + content

    return [first_message, *question_text[1:]]
def wrap_question_for_omni_lmm(question, image_token_len, tokenizer):
    """Turn a conversation into model inputs for generation.

    The image placeholder in the first message is expanded with the module's
    image start/patch/end tokens, then the conversation is run through
    ``omni_preprocess`` in generation mode.

    Args:
        question: Sequence of message dicts making up the conversation.
        image_token_len: Number of patch tokens to emit for the image.
        tokenizer: Tokenizer handed to ``omni_preprocess``.

    Returns:
        A dict with ``input_ids`` and ``labels``, each taken from the first
        entry of the corresponding ``omni_preprocess`` output.
    """
    conversation = expand_question_into_multimodal(
        question,
        image_token_len,
        DEFAULT_IM_START_TOKEN,
        DEFAULT_IM_END_TOKEN,
        DEFAULT_IMAGE_PATCH_TOKEN,
    )
    preprocessed = omni_preprocess(
        sources=[conversation], tokenizer=tokenizer, generation=True)

    # Exactly one conversation was submitted, so keep only its entry.
    return {
        "input_ids": preprocessed["input_ids"][0],
        "labels": preprocessed["labels"][0],
    }
class OmniLMM12B:
    """Inference wrapper around the model returned by ``init_omni_lmm``."""

    def __init__(self, model_path) -> None:
        """Load model, image transform and tokenizer; put the model in eval mode."""
        (self.model,
         self.image_transform,
         self.image_token_len,
         self.tokenizer) = init_omni_lmm(model_path)
        self.model.eval()

    def decode(self, image, input_ids):
        """Sample a response for one transformed image and one prompt.

        Args:
            image: Transformed image; a batch dimension is added and it is
                cast to half precision and moved to CUDA.
            input_ids: Prompt token ids; a batch dimension is added and they
                are moved to CUDA.

        Returns:
            The decoded first output sequence, special tokens skipped and
            surrounding whitespace stripped.
        """
        sampling_options = dict(
            temperature=0.6,
            max_new_tokens=1024,
            do_sample=True,
            output_scores=True,
            return_dict_in_generate=True,
            repetition_penalty=1.1,
            top_k=30,
            top_p=0.9,
        )
        with torch.inference_mode():
            generated = self.model.generate_vllm(
                input_ids=input_ids.unsqueeze(0).cuda(),
                images=image.unsqueeze(0).half().cuda(),
                **sampling_options,
            )
            text = self.tokenizer.decode(
                generated.sequences[0], skip_special_tokens=True)
        return text.strip()

    def chat(self, image, msgs):
        """Answer a conversation about ``image``.

        Args:
            image: Input accepted by the image transform (presumably a PIL
                image -- confirm against the caller).
            msgs: Sequence of message dicts making up the conversation.

        Returns:
            The generated response text.
        """
        prepared = wrap_question_for_omni_lmm(
            msgs, self.image_token_len, self.tokenizer)
        prompt_ids = torch.as_tensor(prepared['input_ids'])
        image_tensor = self.image_transform(image)
        return self.decode(image_tensor, prompt_ids)
def img2base64(file_name):
    """Read a file in binary mode and return its contents base64-encoded.

    Args:
        file_name: Path of the file to read.

    Returns:
        The base64 encoding of the file contents, as ``bytes``.
    """
    with open(file_name, 'rb') as image_file:
        raw_bytes = image_file.read()
    return base64.b64encode(raw_bytes)
class MiniCPMV:
    """Chat wrapper around an ``AutoModel`` checkpoint run in bfloat16."""

    def __init__(self, model_path) -> None:
        """Load model and tokenizer with remote code enabled; eval mode on CUDA."""
        loaded = AutoModel.from_pretrained(model_path, trust_remote_code=True)
        self.model = loaded.to(dtype=torch.bfloat16)
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True)
        self.model.eval().cuda()

    def chat(self, input):
        """Answer a request carrying a base64 image and a JSON conversation.

        Args:
            input: Mapping with ``'image'`` (base64-encoded image data) and
                ``'question'`` (JSON-encoded message list).

        Returns:
            The model's answer, or the string ``"Image decode error"`` when
            the image cannot be decoded.
        """
        try:
            image_bytes = base64.b64decode(input['image'])
            image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except Exception:
            # Any failure while decoding is reported to the caller as text.
            return "Image decode error"

        conversation = json.loads(input['question'])
        # The model returns a 3-tuple; only the answer is used here.
        answer, _context, _unused = self.model.chat(
            image=image,
            msgs=conversation,
            context=None,
            tokenizer=self.tokenizer,
            sampling=True,
            temperature=0.7
        )
        return answer
class MiniCPMV2_5:
    """Chat wrapper around an ``AutoModel`` checkpoint run in float16."""

    def __init__(self, model_path) -> None:
        """Load model and tokenizer with remote code enabled; eval mode on CUDA."""
        loaded = AutoModel.from_pretrained(model_path, trust_remote_code=True)
        self.model = loaded.to(dtype=torch.float16)
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True)
        self.model.eval().cuda()

    def chat(self, input):
        """Answer a request carrying a base64 image and a JSON conversation.

        Args:
            input: Mapping with ``'image'`` (base64-encoded image data) and
                ``'question'`` (JSON-encoded message list).

        Returns:
            Whatever the underlying model's ``chat`` returns, or the string
            ``"Image decode error"`` when the image cannot be decoded.
        """
        try:
            image_bytes = base64.b64decode(input['image'])
            image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except Exception:
            # Any failure while decoding is reported to the caller as text.
            return "Image decode error"

        conversation = json.loads(input['question'])
        return self.model.chat(
            image=image,
            msgs=conversation,
            tokenizer=self.tokenizer,
            sampling=True,
            temperature=0.7
        )
class MiniCPMVChat:
    """Pick a chat backend from the model path and forward requests to it.

    Backend selection by substring of ``model_path``:
    ``'12B'`` -> ``OmniLMM12B``; ``'MiniCPM-Llama3-V'`` -> ``MiniCPMV2_5``;
    anything else -> ``MiniCPMV``.
    """

    def __init__(self, model_path) -> None:
        """Instantiate the backend matching ``model_path``."""
        # True when the backend's chat() takes (image, msgs) rather than the
        # raw request mapping.
        self._backend_takes_decoded_input = False
        if '12B' in model_path:
            self.model = OmniLMM12B(model_path)
            self._backend_takes_decoded_input = True
        elif 'MiniCPM-Llama3-V' in model_path:
            self.model = MiniCPMV2_5(model_path)
        else:
            self.model = MiniCPMV(model_path)

    def chat(self, input):
        """Answer a request carrying a base64 image and a JSON conversation.

        Args:
            input: Mapping with ``'image'`` (base64-encoded image data) and
                ``'question'`` (JSON-encoded message list).

        Returns:
            The backend's answer, or the string ``"Image decode error"`` when
            the image cannot be decoded.
        """
        if not self._backend_takes_decoded_input:
            return self.model.chat(input)

        # OmniLMM12B.chat takes (image, msgs); forwarding the raw mapping as
        # a single argument raised TypeError. Decode here the same way the
        # other backends do, including their error string.
        try:
            image_bytes = base64.b64decode(input['image'])
            image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except Exception:
            return "Image decode error"
        msgs = json.loads(input['question'])
        return self.model.chat(image, msgs)
if __name__ == '__main__':
    # Demo: hold a two-round conversation about one image, printing each
    # question followed by the model's answer.
    model_path = 'openbmb/OmniLMM-12B'
    chat_model = MiniCPMVChat(model_path)
    im_64 = img2base64('./assets/worldmap_ck.jpg')

    questions = (
        "What is interesting about this image?",
        "Where is China in the image",
    )
    msgs = []
    answer = None
    for round_index, question in enumerate(questions):
        if round_index > 0:
            # Feed the previous answer back as the assistant turn.
            msgs.append({"role": "assistant", "content": answer})
        msgs.append({"role": "user", "content": question})
        payload = {"image": im_64,
                   "question": json.dumps(msgs, ensure_ascii=True)}
        answer = chat_model.chat(payload)
        print(msgs[-1]["content"]+'\n', answer)