from typing import List, Optional, Tuple, Union
import torch.utils.checkpoint
from torch import nn
from transformers import GenerationConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from .configuration_phantom import PhantomConfig
from .modeling_intern_vit import InternVisionModel
from .modeling_internlm2 import InternLM2ForCausalLM
from utils.utils import *


class PhantomForCausalLM(PreTrainedModel):
    config_class = PhantomConfig
    main_input_name = 'pixel_values'
    _supports_flash_attn_2 = True
    _no_split_modules = ['InternVisionModel', 'InternLM2DecoderLayer']

    def __init__(self, config: PhantomConfig):
        super().__init__(config)

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.template = config.template
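        # Number of visual tokens contributed by one image tile after pixel shuffle.
        # As a rough check (assuming the usual InternViT defaults of a 448x448 input,
        # 14x14 patches and downsample_ratio = 0.5), this evaluates to
        # (448 // 14) ** 2 * 0.5 ** 2 = 256, matching the hard-coded 256 in eval_process.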
        self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
        self.downsample_ratio = config.downsample_ratio

        self.vision_model = InternVisionModel(config.vision_config)
        self.language_model = InternLM2ForCausalLM(config.llm_config)

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.llm_config.hidden_size

        self.vision_proj = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size)
        )

        # prompt rule
        self.prompt_rule = {
            "system_start": "<|im_start|>system\n",
            "system_end": "<|im_end|>",
            "user_start": "<|im_start|>user\n",
            "user_end": "<|im_end|>",
            "assistant_start": "<|im_start|>assistant\n",
            "assistant_end": "<|im_end|>",
            "test_start": "assistant\n",
            "test_end": "<|im_end|>",
            "split": "",
        }

    def eval_process(
        self,
        inputs,
        tokenizer,
        data,
        device,
    ):
        batched_image = []
        batched_qa_prompt = []
        batched_phantom_position = []

        for _input in inputs:
            # making image prompt
            if 'image' in _input.keys() and _input['image'] is not None:
                process_image = dynamic_preprocess(_input['image'].to(device))
                dynamic_process_image = torch.stack([dynamic_transform(image) for image in process_image]).to(device)
                img_token_number = dynamic_process_image.shape[0] * 256
                batched_image.append(dynamic_process_image)

            # make question and answer
            question = _input['question']

            # make instruction (qa pair) and label
            qa_prompt = make_instruction(question, data, self.prompt_rule)

            # adding image special tokens to question
            if 'image' in _input.keys():
                qa_prompt = qa_prompt.replace('<image>', '<img><IMG_CONTEXT></img>')

                # add bundle image tokens if it has <image> token
                qa_prompt = add_bundle_tokens(qa_prompt, '<IMG_CONTEXT>', img_token_number)

            # phantom_position
            label = tokenizer(qa_prompt, return_tensors='pt', add_special_tokens=False).input_ids[0].to(device)
            phantom_position = torch.zeros_like(label)
            phantom_position[0] = 1

            # batched processing
            batched_qa_prompt.append(qa_prompt)
            batched_phantom_position.append(phantom_position.flip(dims=[0]))

        '''For Final Outputs'''
        qa_prompts = tokenizer(batched_qa_prompt, padding='longest', return_tensors="pt", add_special_tokens=False)

        # [1] input_ids
        input_ids = qa_prompts.input_ids.to(device)

        # [2] attention_mask
        attention_mask = qa_prompts.attention_mask.to(device)

        # [3] Phantom Position
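        # Each per-sample mask was flipped above, is right-padded here, and is then
        # flipped back, so the padding effectively sits on the left and the 1 marking
        # each sample's first real token stays aligned with that token.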
        batched_phantom_position = torch.nn.utils.rnn.pad_sequence(batched_phantom_position, batch_first=True, padding_value=0).flip(dims=[1])  # padding left

        if len(batched_image):
            return {"input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "pixel_values": torch.cat(batched_image, dim=0).to(device),
                    "phantom_position": batched_phantom_position.bool()
                    }
        else:
            return {"input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "phantom_position": batched_phantom_position.bool()
                    }

    def extract_feature(self, pixel_values):
        vit_embeds = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=False,
            return_dict=True).last_hidden_state
        vit_embeds = vit_embeds[:, 1:, :]
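        # Reshape the remaining patch tokens (CLS dropped above) into an h x w grid,
        # pixel-shuffle them so the token count shrinks by downsample_ratio ** 2 while
        # the channel dimension grows accordingly, then project to the LLM hidden size.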
        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        vit_embeds = self.vision_proj(vit_embeds)
        return vit_embeds

    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        phantom_position: Optional[torch.BoolTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:

        if pixel_values is not None:
            vit_embeds = self.extract_feature(pixel_values.to(torch.bfloat16))
            input_embeds = self.language_model.get_input_embeddings()(input_ids)
            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)
            input_ids = input_ids.reshape(B * N)
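            # Overwrite the embeddings at every image-placeholder position
            # (identified via config.image_token_index) with the projected vision features.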
            selected = (input_ids == self.config.image_token_index)
            assert selected.sum() != 0
            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
            input_embeds = input_embeds.reshape(B, N, C)
        else:
            input_embeds = self.language_model.get_input_embeddings()(input_ids)

        outputs = self.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            phantom_position=phantom_position,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=True,
            pad_token_id=self.config.eos_token_id,
            eos_token_id=self.config.eos_token_id,
            **generate_kwargs,
        )

        return outputs
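

# Minimal usage sketch (not part of the original file). It assumes a released
# Phantom checkpoint path (CKPT), a compatible tokenizer, an image tensor, and a
# dataset/benchmark name (DATA_NAME) in the form expected by make_instruction;
# all of these placeholders are hypothetical and should be adapted to your setup.
#
#   from transformers import AutoTokenizer
#
#   model = PhantomForCausalLM.from_pretrained(CKPT, torch_dtype=torch.bfloat16).to('cuda').eval()
#   tokenizer = AutoTokenizer.from_pretrained(CKPT, trust_remote_code=True)
#
#   inputs = [{'image': image, 'question': '<image>\nDescribe this image.'}]
#   batch = model.eval_process(inputs, tokenizer, data=DATA_NAME, device='cuda')
#   output_ids = model.generate(**batch, max_new_tokens=128)
#   answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)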