fffiloni's picture
Upload 164 files
2ada650 verified
raw
history blame
7.89 kB
import argparse
import time
from PIL import Image
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList
import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any
from minigpt4.common.registry import registry
class SeparatorStyle(Enum):
    """Selects how ``Conversation.get_prompt`` joins messages.

    SINGLE: messages are concatenated with no separator after each turn.
    TWO: two separators alternate, one after each non-empty message.
    """

    # Explicit values match what ``enum.auto()`` assigns (1, 2, ...).
    SINGLE = 1
    TWO = 2
@dataclasses.dataclass
class Conversation:
    """Holds the state of one conversation and renders it as a prompt.

    ``messages`` is a list of ``[role, text]`` pairs; ``text`` may be ``None``
    (or empty) to mark a turn whose content has not been produced yet.
    """

    # Text placed at the very start of the prompt.
    system: str
    # Role markers; callers index this as roles[0] / roles[1].
    roles: List[str]
    # Conversation history as [role, text] pairs.
    messages: List[List[str]]
    # Number of leading messages hidden from ``to_gradio_chatbot``.
    offset: int
    # Prompt layout used by ``get_prompt``.
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    # Separators; only the TWO layout reads them (alternating sep / sep2).
    sep: str = "<s>"
    sep2: str = "</s>"

    # Not read by any method of this class.
    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        """Render the system text and all messages into one prompt string.

        The system text is always followed by a literal ``"<s>"`` (the ``sep``
        field is not consulted for that). A message with falsy text contributes
        only its role marker.

        Raises:
            ValueError: if ``sep_style`` is not a known ``SeparatorStyle``.
        """
        layout = self.sep_style

        if layout == SeparatorStyle.SINGLE:
            pieces = [self.system + "<s>"]
            for speaker, text in self.messages:
                pieces.append(speaker + text if text else speaker)
            return "".join(pieces)

        if layout == SeparatorStyle.TWO:
            alternating = [self.sep, self.sep2]
            pieces = [self.system + "<s>"]
            for turn, (speaker, text) in enumerate(self.messages):
                if text:
                    # Even turns are closed with sep, odd turns with sep2.
                    pieces.append(speaker + text + alternating[turn % 2])
                else:
                    pieces.append(speaker)
            return "".join(pieces)

        raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        """Add a new ``[role, message]`` entry to the end of the history."""
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        """Return the history after ``offset`` as ``[first, second]`` pairs.

        Messages at even positions open a new pair (second slot ``None``);
        messages at odd positions fill the second slot of the latest pair.
        """
        pairs = []
        visible = self.messages[self.offset:]
        for position, (_, text) in enumerate(visible):
            if position % 2:
                pairs[-1][-1] = text
            else:
                pairs.append([text, None])
        return pairs

    def copy(self):
        """Return a new ``Conversation`` with an independent message list.

        Each ``[role, text]`` pair is rebuilt so edits to the copy's history
        do not touch this instance. ``skip_next`` is not carried over and takes
        its default value in the copy.
        """
        history = [[speaker, text] for speaker, text in self.messages]
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=history,
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id,
        )

    def dict(self):
        """Return selected fields as a plain dictionary.

        ``sep_style`` and ``skip_next`` are not included. The ``messages`` and
        ``roles`` values are the same objects held by this instance.
        """
        snapshot = {}
        snapshot["system"] = self.system
        snapshot["roles"] = self.roles
        snapshot["messages"] = self.messages
        snapshot["offset"] = self.offset
        snapshot["sep"] = self.sep
        snapshot["sep2"] = self.sep2
        snapshot["conv_id"] = self.conv_id
        return snapshot
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation when the first sequence ends with a stop pattern.

    Args:
        stops: iterable of token-id sequences (e.g. 1-D tensors). Generation
            stops as soon as the tail of ``input_ids[0]`` equals one of them.
            Defaults to no stop patterns.
        encounters: accepted for backward compatibility; currently unused.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # Avoid a shared mutable default: each instance gets its own list.
        self.stops = [] if stops is None else stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        """Return True if the first sequence in the batch ends with a stop pattern.

        Only ``input_ids[0]`` is inspected. Extra keyword arguments are
        accepted and ignored so the call signature matches the base class.
        """
        generated = input_ids[0]
        total = len(generated)
        for stop in self.stops:
            width = len(stop)
            # An empty pattern would slice the whole sequence ([-0:]) and a
            # pattern longer than the sequence cannot match; without this
            # guard the comparison would broadcast or raise on mismatched
            # shapes.
            if width == 0 or width > total:
                continue
            if torch.all((stop == generated[-width:])).item():
                return True
        return False
# Conversation template copied by ``Chat.__init__`` for every new session.
# It has no system text; the two roles are the "[INST] " / " [/INST]" markers
# that wrap each user turn, and the first two messages are hidden from
# ``to_gradio_chatbot`` (offset=2).
CONV_VISION = Conversation(
    system="",
    roles=("[INST] ", " [/INST]"),
    messages=[],
    offset=2,
    sep="<s>",
    sep_style=SeparatorStyle.SINGLE,
)
class Chat:
    """Drives an image-grounded chat session on top of a wrapped model.

    The wrapped ``model`` must expose ``encode_img``, ``embed_tokens``,
    ``llama_model`` (with ``generate``) and ``llama_tokenizer``; these are the
    only attributes this class touches.

    Args:
        model: the multimodal model wrapper described above.
        vis_processor: callable turning a PIL image into a tensor.
        device: device that images, token ids and stop tokens are moved to.
    """

    def __init__(self, model, vis_processor, device='cuda:0'):
        self.device = device
        self.model = model
        self.vis_processor = vis_processor
        self.conv = CONV_VISION.copy()
        self.img_list = []
        self.raw_answers = []
        # Generation halts when token id 2 is produced
        # (presumably the tokenizer's `</s>` id -- TODO confirm).
        stop_words_ids = [torch.tensor([2]).to(self.device)]
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def reset(self):
        """Clear the session's messages, image embeddings and raw answers."""
        self.conv.messages = []
        self.img_list = []
        self.raw_answers = []

    def ask(self, text, conv):
        """Add the user's ``text`` to ``conv``.

        If the latest message is a user message ending with ``</Img>`` (an
        image placeholder), the text is appended to that same message,
        separated by a space; otherwise a new user message is created.
        """
        last = conv.messages[-1] if conv.messages else None
        follows_image = (
            last is not None
            and last[0] == conv.roles[0]
            # Guard against a pending (None) message before testing the suffix.
            and isinstance(last[1], str)
            and last[1].endswith('</Img>')
        )
        if follows_image:
            last[1] = ' '.join([last[1], text])
        else:
            conv.append_message(conv.roles[0], text)

    def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
        """Generate the assistant reply for ``conv`` and store it in ``conv``.

        Args:
            conv: conversation to answer; an empty assistant message is
                appended and then filled with the generated text.
            img_list: image embeddings matching the ``<ImageHere>``
                placeholders in ``conv``.
            max_new_tokens, num_beams, min_length, top_p, repetition_penalty,
            length_penalty, temperature: forwarded to ``generate``. Sampling
                is disabled (``do_sample=False``).
            max_length: budget for context plus new tokens; older context
                embeddings are dropped from the front when exceeded.

        Returns:
            Tuple ``(output_text, output_token)`` where ``output_token`` is a
            NumPy array of the generated token ids.
        """
        conv.append_message(conv.roles[1], None)
        embs = self.get_context_emb(conv, img_list)

        current_max_len = embs.shape[1] + max_new_tokens
        if current_max_len - max_length > 0:
            print('Warning: The number of tokens in current conversation exceeds the max length. '
                  'The model will not see the contexts outside the range.')
        # Keep only the most recent embeddings that fit in the budget.
        begin_idx = max(0, current_max_len - max_length)
        embs = embs[:, begin_idx:]

        outputs = self.model.llama_model.generate(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=self.stopping_criteria,
            num_beams=num_beams,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
            do_sample=False,
        )
        output_token = outputs[0]
        # Drop a leading token id 0 if present; guard against empty output.
        if output_token.numel() > 0 and output_token[0] == 0:
            output_token = output_token[1:]
        output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
        self.raw_answers.append(output_text)

        output_text = output_text.split('</s>')[0]  # cut at the end-of-sequence marker
        output_text = output_text.replace("<s>", "")
        output_text = output_text.split(r'[/INST]')[-1].strip()

        # Fill the slot appended above. This must be the conversation that
        # was passed in, not ``self.conv``, which may be a different object.
        conv.messages[-1][1] = output_text
        return output_text, output_token.cpu().numpy()

    def upload_img(self, image):
        """Encode ``image`` and register it in this session.

        Args:
            image: a file path, a PIL image, or a tensor (a 3-D tensor gets a
                leading batch dimension). Any other value is passed to
                ``model.encode_img`` unchanged.

        Returns:
            The acknowledgement string ``"Received."``.
        """
        if isinstance(image, str):  # is a image path
            # Close the file handle once the RGB copy has been made.
            with Image.open(image) as opened:
                image = opened.convert('RGB')

        if isinstance(image, Image.Image):
            image = self.vis_processor(image).unsqueeze(0).to(self.device)
        elif isinstance(image, torch.Tensor):
            if len(image.shape) == 3:
                image = image.unsqueeze(0)
            image = image.to(self.device)

        image_emb, _ = self.model.encode_img(image)
        self.img_list.append(image_emb)
        self.conv.append_message(self.conv.roles[0], "<Img><ImageHere></Img>")
        msg = "Received."
        return msg

    def get_context_emb(self, conv, img_list):
        """Build the input embeddings for ``conv`` with images spliced in.

        The prompt is split at every ``<ImageHere>`` placeholder, each text
        segment is tokenized and embedded, and the image embeddings are
        interleaved between the segments along dimension 1.

        Raises:
            AssertionError: if the number of placeholders differs from
                ``len(img_list)``.
        """
        prompt = conv.get_prompt()
        prompt_segs = prompt.split('<ImageHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.model.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens]
        # text0, img0, text1, img1, ..., textN
        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs