# OneChart / modeling_OneChart.py
# kppkkp's picture
# Upload modeling_OneChart.py
# 49c6f36 verified
from transformers import OPTConfig, OPTModel, OPTForCausalLM, StoppingCriteria, TextStreamer
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from typing import List, Optional, Tuple, Union
import requests
from PIL import Image
from io import BytesIO
import json
import re
import torch
import numpy as np
import torch.nn as nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from .sam_vision_b import build_SAM_vit_b
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import dataclasses
# Special-token strings for image handling in prompts.
# Generic image placeholder string.
DEFAULT_IMAGE_TOKEN = "<image>"
# Placeholder repeated once per image feature slot between the start/end
# markers when ``chat`` builds its prompt.
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
# Marker placed immediately before the run of image patch placeholders.
DEFAULT_IM_START_TOKEN = '<img>'
# Marker placed immediately after the run of image patch placeholders.
DEFAULT_IM_END_TOKEN = '</img>'
from enum import auto, Enum
class SeparatorStyle(Enum):
    """Prompt layouts understood by ``Conversation.get_prompt``."""

    # Values are spelled out; they equal what consecutive ``auto()`` calls yield.
    SINGLE = 1
    TWO = 2
    MPT = 3
@dataclasses.dataclass
class Conversation:
    """Keep the history of one conversation and render it as a prompt string.

    ``messages`` holds ``[role, message]`` pairs in order. A falsy message
    (``None`` or ``""``) is rendered as the bare role, which leaves the prompt
    open for the model to continue. A tuple message is treated as
    ``(text, _, _)`` and only ``text`` is rendered.
    """

    # Text placed at the very start of the prompt.
    system: str
    # Role labels, e.g. ("USER", "ASSISTANT").
    roles: List[str]
    # Ordered [role, message] pairs.
    messages: List[List[str]]
    # Not read by any method of this class.
    offset: int
    # Selects the layout produced by get_prompt().
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    # Separator appended after the system text and after messages.
    sep: str = "<|im_end|>"
    # Second separator; only used by SeparatorStyle.TWO (odd-indexed messages).
    sep2: Optional[str] = None
    # Free-form label; not read by any method of this class.
    version: str = "Unknown"
    # Not read by any method of this class.
    skip_next: bool = False

    @staticmethod
    def _message_text(message):
        """Return the text of *message*, unwrapping a ``(text, _, _)`` tuple."""
        if isinstance(message, tuple):
            message, _, _ = message
        return message

    def get_prompt(self):
        """Render the system text and all messages according to ``sep_style``.

        Returns:
            The prompt string.

        Raises:
            ValueError: if ``sep_style`` is not a supported ``SeparatorStyle``.
        """
        if self.sep_style == SeparatorStyle.SINGLE:
            # "<system><sep>\n" followed by "<role>: <message><sep>" per turn.
            parts = [self.system + self.sep + '\n']
            for role, message in self.messages:
                if message:
                    parts.append(role + ": " + self._message_text(message) + self.sep)
                else:
                    parts.append(role + ":")
            return "".join(parts)

        if self.sep_style == SeparatorStyle.TWO:
            # Even-indexed messages end with ``sep``, odd-indexed ones with ``sep2``.
            seps = [self.sep, self.sep2]
            parts = [self.system + seps[0]]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    parts.append(role + ": " + self._message_text(message) + seps[i % 2])
                else:
                    parts.append(role + ":")
            return "".join(parts)

        if self.sep_style == SeparatorStyle.MPT:
            # Roles are emitted verbatim (no ": "); an empty system text is skipped.
            parts = [self.system + self.sep] if self.system else []
            for role, message in self.messages:
                if message:
                    parts.append(role + self._message_text(message) + self.sep)
                else:
                    parts.append(role)
            return "".join(parts)

        raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        """Append one ``[role, message]`` pair to the history."""
        self.messages.append([role, message])

    def copy(self):
        """Return a new ``Conversation`` with an independent ``messages`` list.

        Each pair is rebuilt as a fresh list, so appending to the copy never
        touches the source. ``version`` is carried over so the copy keeps the
        template's label; ``skip_next`` is left at its default.
        """
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[role, message] for role, message in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version,
        )
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once a keyword appears in the newly generated tokens.

    Two checks run on every call: a token-id comparison for keywords that
    tokenize to exactly one id, and a substring search over the decoded text
    generated after the prompt. Only the first sequence of the batch is
    inspected.

    Args:
        keywords: Strings whose appearance should end generation.
        tokenizer: Tokenizer used both to encode the keywords and to decode
            generated ids.
        input_ids: Prompt ids; the size of dimension 1 marks where generated
            tokens start.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        tokenized = [tokenizer(keyword).input_ids for keyword in keywords]
        # The token-id fast path only covers keywords encoded as a single id.
        self.keyword_ids = [
            ids[0] for ids in tokenized if isinstance(ids, list) and len(ids) == 1
        ]
        self.tokenizer = tokenizer
        self.input_ids = input_ids
        # Fixed up front so the very first generated token is checked too,
        # instead of spending the first call on initialisation.
        self.start_len = input_ids.shape[1]

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when generation should stop.

        Args:
            output_ids: Prompt plus generated ids so far; indexed as 2-D.
            scores: Unused.
        """
        last_token = output_ids[0, -1]
        for keyword_id in self.keyword_ids:
            if last_token == keyword_id:
                return True
        generated = self.tokenizer.batch_decode(
            output_ids[:, self.start_len:], skip_special_tokens=True
        )[0]
        return any(keyword in generated for keyword in self.keywords)
# Prompt template in the two-separator style: get_prompt() ends even-indexed
# messages with " " and odd-indexed ones with "</s>".
# ``messages`` is an empty tuple here; Conversation.copy() rebuilds it as a
# list, so append_message() works on a copy of this template.
conv_vicuna_v1_1 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
class OneChartImageEvalProcessor:
    """Callable image pipeline: square bicubic resize, tensor conversion, normalize.

    Args:
        image_size: Edge length of the square the image is resized to.
    """

    def __init__(self, image_size=1024):
        # Mean 0 / std 1 per channel, i.e. (x - 0) / 1.
        self.normalize = transforms.Normalize((0., 0., 0.), (1., 1., 1.))
        target_size = (image_size, image_size)
        steps = [
            transforms.Resize(target_size, interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            self.normalize,
        ]
        self.transform = transforms.Compose(steps)

    def __call__(self, item):
        """Run ``item`` through the composed transform and return the result."""
        return self.transform(item)
class OneChartConfig(OPTConfig):
    """``OPTConfig`` subclass that sets ``model_type`` to ``"OneChart"``."""

    model_type = "OneChart"
class OneChartModel(OPTModel):
    """OPT model that splices projected vision features into its input embeddings.

    A vision tower built by ``build_SAM_vit_b`` encodes each image, and
    ``mm_projector`` maps the resulting features to the width used for the
    token embeddings. The projected features overwrite the embeddings that
    sit between each image start token and its matching image end token.
    """

    config_class = OneChartConfig

    def __init__(self, config: OPTConfig):
        super(OneChartModel, self).__init__(config)
        self.vision_tower = build_SAM_vit_b()
        # Linear projection from vision features (1024) to embedding width (768).
        self.mm_projector = nn.Linear(1024, 768)

    def embed_tokens(self, x):
        """Return the input embeddings for token ids ``x``."""
        return self.get_input_embeddings()(x)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed the tokens, merge image features if applicable, and run OPT.

        Args:
            input_ids: Token ids; needed to locate the image start/end tokens.
            inputs_embeds: Precomputed embeddings; computed from ``input_ids``
                when omitted.
            images: Iterable with one tensor per sequence, each unpacked as
                ``(P, C, H, W)``; only ``P == 1`` is supported.
            attention_mask, past_key_values, use_cache, output_attentions,
            output_hidden_states, return_dict: Forwarded unchanged to
                ``OPTModel.forward``.

        Returns:
            Whatever ``OPTModel.forward`` returns for the merged embeddings.

        Raises:
            NotImplementedError: if an image has ``P != 1``.
            ValueError: if image start/end tokens are unbalanced or an end
                token is not found right after the patch placeholders.
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        vision_tower_high = getattr(self, 'vision_tower', None)
        # Image features are placed by token id, so merging needs ``input_ids``;
        # callers that pass only ``inputs_embeds`` skip it instead of crashing.
        # Single-token inputs outside training are skipped as well (presumably
        # cached decoding steps that carry no image placeholders).
        should_merge_images = (
            vision_tower_high is not None
            and images is not None
            and input_ids is not None
            and (input_ids.shape[1] != 1 or self.training)
        )

        if should_merge_images:
            im_start_token = getattr(self.config, "im_start_token", -1)
            im_end_token = getattr(self.config, "im_end_token", -1)

            image_features = []
            for image in images:
                P, _, _, _ = image.shape
                if P != 1:
                    raise NotImplementedError("Batch inference needs to be implemented.")
                # The vision tower runs without gradient tracking.
                with torch.no_grad():
                    cnn_feature = vision_tower_high(image)
                    # Flatten the spatial dims and move channels last
                    # (the original code notes "256*1024" for this result).
                    cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1)
                image_features.append(self.mm_projector(cnn_feature))

            new_input_embeds = []
            for cur_input_ids, cur_input_embeds, cur_image_features in zip(
                input_ids, inputs_embeds, image_features
            ):
                num_starts = (cur_input_ids == im_start_token).sum()
                num_ends = (cur_input_ids == im_end_token).sum()
                if num_starts != num_ends:
                    raise ValueError("The number of image start tokens and image end tokens should be the same.")

                image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]
                for image_start_token_pos, per_cur_image_features in zip(
                    image_start_tokens, cur_image_features
                ):
                    per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)
                    num_patches = per_cur_image_features.shape[0]
                    if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
                        raise ValueError("The image end token should follow the image start token.")
                    # Replace exactly ``num_patches`` rows after the start
                    # token; the length is unchanged, so later start
                    # positions remain valid.
                    cur_input_embeds = torch.cat(
                        (
                            cur_input_embeds[:image_start_token_pos + 1],
                            per_cur_image_features,
                            cur_input_embeds[image_start_token_pos + num_patches + 1:],
                        ),
                        dim=0,
                    )
                new_input_embeds.append(cur_input_embeds)

            inputs_embeds = torch.stack(new_input_embeds, dim=0)

        # ``input_ids`` is deliberately dropped: the embeddings already
        # contain everything the decoder needs.
        return super(OneChartModel, self).forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
class OneChartOPTForCausalLM(OPTForCausalLM):
    """OPT causal LM on top of ``OneChartModel`` with an auxiliary number head.

    Besides the language-model head, ``num_decoder`` maps the hidden state at
    the configured ``number_token`` position to 256 values. At inference the
    first 100 of them are stored in ``self.pred_locs`` and ``chat`` compares
    them with the numbers parsed from the generated text.
    """

    config_class = OneChartConfig

    def __init__(self, config):
        super(OneChartOPTForCausalLM, self).__init__(config)
        self.model = OneChartModel(config)
        self.vocab_size = config.vocab_size
        # Auxiliary MLP head: hidden_size -> hidden_size/2 -> hidden_size/2 -> 256.
        self.num_decoder = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size // 2),
            nn.ReLU(),
            nn.Linear(config.hidden_size // 2, config.hidden_size // 2),
            nn.ReLU(),
            nn.Linear(config.hidden_size // 2, 256),
        )
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Latest number-head prediction captured by ``forward`` at inference.
        self.pred_locs = []
        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the underlying ``OneChartModel``."""
        return self.model

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        loc_labels=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the model, compute logits and, when ``labels`` is given, the LM loss.

        ``token_type_ids``, ``position_ids``, ``head_mask``,
        ``encoder_hidden_states``, ``encoder_attention_mask`` and
        ``loc_labels`` are accepted for interface compatibility but are not
        used by this method.

        Side effect: outside training, if ``config.number_token`` occurs in
        ``input_ids``, ``self.pred_locs`` is updated with the first 100
        number-head outputs for the first sequence; otherwise it is left as is.

        Returns:
            ``CausalLMOutputWithPast`` when ``return_dict`` is true, else a
            tuple ``(loss?, logits, *rest_of_model_outputs)``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            images=images,
            return_dict=return_dict
        )
        hidden_states = outputs[0]

        # At inference time, record the value predicted by the number head.
        # The number token is absent on most decoding steps, so its absence is
        # an expected case handled by explicit checks rather than by catching
        # exceptions.
        if not self.training:
            number_token = getattr(self.config, "number_token", None)
            if number_token is not None and input_ids is not None and input_ids.dim() == 2:
                number_positions = torch.where(input_ids == number_token)[1]
                if number_positions.numel() > 0:
                    det_patch_token = number_positions[0]
                    pred_locs = self.num_decoder(hidden_states[:, det_patch_token, :])  # last dim: 256
                    self.pred_locs = pred_locs[0][:100].cpu().tolist()

        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
    ):
        """Build the keyword arguments for one generation step.

        With a non-empty cache only the last token (and last token type /
        position id) is kept. ``inputs_embeds`` is used only when given and
        ``past_key_values`` is ``None``. ``images`` is passed through from
        ``kwargs``.
        """
        token_type_ids = kwargs.get("token_type_ids", None)
        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # Derive positions from the mask; masked slots get position 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None

        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "position_ids": position_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs

    def load_image(self, image_file):
        """Load an image from an http(s) URL or a local path as an RGB image.

        Args:
            image_file: URL starting with ``http://`` / ``https://``, or a
                filesystem path.

        Raises:
            requests.HTTPError: if the download returns an error status.
        """
        if image_file.startswith(('http://', 'https://')):
            # A timeout keeps a stalled server from hanging the caller, and the
            # status check avoids handing an error page to the image decoder.
            response = requests.get(image_file, timeout=30)
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert('RGB')
        # ``convert`` returns a new image, so the file can be closed right away.
        with Image.open(image_file) as source_image:
            return source_image.convert('RGB')

    def disable_torch_init(self):
        """
        Disable the redundant torch default initialization to accelerate model creation.

        Note: this replaces ``reset_parameters`` on ``torch.nn.Linear`` and
        ``torch.nn.LayerNorm`` themselves, so it affects every such layer
        created afterwards in the process, not only this model.
        """
        setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
        setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

    @staticmethod
    def _collect_numeric_values(json_dict):
        """Flatten the numeric leaves of a nested dict into a list.

        Nested dicts are traversed recursively. Numbers are kept as they are;
        other values have ``(n)`` / ``[n]`` markers removed and every character
        except digits, ``.`` and ``-`` stripped before conversion to float.
        Returns ``[]`` if any value is a list or if anything cannot be parsed.
        """
        values = []
        try:
            for value in json_dict.values():
                if isinstance(value, dict):
                    values = values + OneChartOPTForCausalLM._collect_numeric_values(value)
                elif isinstance(value, list):
                    return []
                elif isinstance(value, (float, int)):
                    values.append(value)
                else:
                    value = re.sub(r'\(\d+\)|\[\d+\]', '', value)
                    num_value = re.sub(r'[^\d.-]', '', str(value))
                    # After stripping, only "" and "-" can stand for "no number".
                    if num_value not in ("-", ""):
                        values.append(float(num_value))
        except (AttributeError, TypeError, ValueError) as e:
            print(f"Error: {e}")
            return []
        return values

    @staticmethod
    def _min_max_normalize(rst_list):
        """Min-max scale a list of numbers; lists shorter than 2 are returned as is."""
        if len(rst_list) < 2:
            return rst_list
        min_vals = min(rst_list)
        max_vals = max(rst_list)
        # The epsilon avoids division by zero when all values are equal.
        normalized = (np.array(rst_list) - min_vals) / (max_vals - min_vals + 1e-9)
        return list(normalized)

    def chat(self, tokenizer, image_file, reliable_check=True, print_prompt=False):
        """Generate the chart description for one image and optionally self-check it.

        Args:
            tokenizer: Tokenizer matching this model.
            image_file: Local path or http(s) URL of the chart image.
            reliable_check: When true, compare the numbers under the ``values``
                key of the generated JSON (min-max normalized) with the number
                head's prediction and append the verdict to the response.
            print_prompt: When true, print the full prompt before generating.

        Returns:
            The generated text, followed by the check report if requested.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if device == "cuda" else torch.float32

        self.disable_torch_init()
        image_processor_high = OneChartImageEvalProcessor(image_size=1024)
        image_token_len = 256

        image = self.load_image(image_file)
        image_tensor_1 = image_processor_high(image).to(dtype=dtype, device=device)

        query = 'Convert the key information of the chart to a python dict:'
        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN + query + '\n'
        conv = conv_vicuna_v1_1.copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        if print_prompt:
            print(prompt)

        inputs = tokenizer([prompt])
        input_ids = torch.as_tensor(inputs.input_ids).to(device=device)
        stop_str = '</s>'
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

        # Forget any prediction left over from a previous call so the
        # reliability check can only use values produced for this image.
        self.pred_locs = []

        generate_kwargs = dict(
            images=[image_tensor_1.unsqueeze(0)],
            do_sample=False,
            num_beams=1,
            max_new_tokens=4096,
            stopping_criteria=[stopping_criteria],
        )
        if device == 'cuda':
            with torch.autocast(device, dtype=dtype):
                output_ids = self.generate(input_ids, **generate_kwargs)
        else:
            output_ids = self.generate(input_ids, **generate_kwargs)

        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True)
        outputs = outputs.replace("<Number>", "")
        outputs = outputs.strip()
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        response_str = outputs

        if reliable_check:
            pred_nums = self.pred_locs
            # Any failure here (invalid JSON, missing "values", shape mismatch)
            # is reported in the response instead of being raised.
            try:
                outputs_json = json.loads(outputs)
                list_v = self._collect_numeric_values(outputs_json['values'])
                list_v = [round(x, 4) for x in self._min_max_normalize(list_v)]
                gt_nums = torch.tensor(list_v).reshape(1, -1)
                response_str = response_str + "\n<Chart>: " + str(pred_nums[:len(list_v)])
                pred_nums_ = torch.tensor(pred_nums[:len(list_v)]).reshape(1, -1)
                reliable_distence = F.l1_loss(pred_nums_, gt_nums)
                response_str = response_str + "\nreliable_distence: " + str(reliable_distence)
                if reliable_distence < 0.1:
                    response_str = response_str + "\nAfter OneChart checking, this prediction is reliable."
                else:
                    response_str = response_str + "\nThis prediction may be has error! "
            except Exception as e:
                response_str = response_str + "\nThis prediction may be has error! "
                response_str = response_str + "\n" + str(e)

        return response_str