|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging
import pdb
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaConfig,
    LlamaForCausalLM,
    LlamaModel,
)
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast

from llava.constants import (
    DEFAULT_IM_END_TOKEN,
    DEFAULT_IM_END_TOKEN_IDX,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_START_TOKEN_IDX,
    DEFAULT_IMAGE_TOKEN,
    IGNORE_INDEX,
    IMAGE_TOKEN_IDX,
)
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
|
|
|
class LlavaConfig(LlamaConfig):
    """LLaMA configuration tagged with the custom ``llava_llama`` model type.

    The class body only overrides ``model_type``; everything else is inherited
    from ``LlamaConfig``. The string must match the key passed to
    ``AutoConfig.register`` so Auto* factories can resolve this config.
    """

    # Model-type tag; must match the ``AutoConfig.register`` key.
    model_type = "llava_llama"
|
|
|
|
|
class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    """LLaMA decoder stack mixed with the LLaVA multimodal base model.

    No behavior is added here: the class only binds ``config_class`` to
    ``LlavaConfig`` so transformers loading utilities pair this model with
    that configuration type.
    """

    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        # Zero-argument super() performs the same MRO lookup as
        # super(LlavaLlamaModel, self): resolution starts at LlavaMetaModel,
        # the first listed base.
        super().__init__(config)
|
|
|
|
|
class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    """LLaMA causal LM extended with image understanding and image-query generation.

    * :meth:`forward` adds an MSE regression loss, computed between
      down-projected hidden states at image-target positions and target image
      embeddings, on top of the usual next-token cross-entropy.
    * :meth:`generate` runs text generation, optionally conditioned on images.
    * :meth:`generate_image` runs the LM step by step to produce ``n_query``
      image-query embeddings per prompt.
    * :meth:`prepare_and_encode_inputs` turns an interleaved list of strings
      and PIL images into those conditioning embeddings.
    """

    config_class = LlavaConfig

    def __init__(self, config):
        # super() starts *after* LlamaForCausalLM in the MRO, so its __init__
        # is skipped; presumably so the plain LlamaModel it would build is
        # replaced by the multimodal LlavaLlamaModel below.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Weight initialization / final processing (transformers convention).
        self.post_init()

    def get_model(self):
        """Return the underlying ``LlavaLlamaModel`` backbone."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        ids: Optional[list] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        gen_image: Optional[torch.FloatTensor] = None,
        und_image: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the multimodal LM and, when ``labels`` is given, compute the joint loss.

        Args:
            input_ids, attention_mask, position_ids, past_key_values, labels:
                Standard causal-LM inputs. When ``inputs_embeds`` is None they
                go through ``self.prepare_inputs_labels_for_multimodal``
                (presumably provided by ``LlavaMetaForCausalLM``), which
                returns replacements for all of them.
            inputs_embeds: Precomputed input embeddings. If given, the
                multimodal preparation is skipped and no image-regression
                targets are available (the image loss is then a zero-valued
                placeholder).
            ids: Accepted for caller compatibility; unused.
            gen_image, und_image, image_sizes: Forwarded to
                ``prepare_inputs_labels_for_multimodal``.
            use_cache, output_attentions, output_hidden_states: Forwarded to
                the backbone; the last two default to the config values.
            return_dict: If False, return a tuple ``([loss,] logits, *rest)``
                per the transformers convention. Defaults to
                ``config.use_return_dict``.
            cache_position: Accepted for signature compatibility; not forwarded.

        Returns:
            ``CausalLMOutputWithPast`` (or a tuple, see ``return_dict``). Its
            ``loss`` is ``text_loss + image_loss`` when ``labels`` is given,
            otherwise None.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Image-regression targets only exist when this call builds the
        # multimodal embeddings itself. Initialise them so a caller-supplied
        # ``inputs_embeds`` no longer triggers a NameError below.
        img_loss_indicator = None
        target_image_embeds = None
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
                img_loss_indicator,
                _img_indicator,
                target_image_embeds,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                gen_image,
                und_image,
                image_sizes,
            )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states).float()

        total_loss = None
        if labels is not None:
            text_loss = self._text_loss(logits, labels)
            img_loss = self._image_regression_loss(hidden_states, img_loss_indicator, target_image_embeds)
            # Lazy %-formatting: tensors are only stringified (which forces a
            # device sync) when DEBUG logging is enabled, unlike a bare print.
            logging.getLogger(__name__).debug("img loss %s, text loss %s", img_loss, text_loss)
            total_loss = text_loss + img_loss

        if not return_dict:
            # The backbone returned a tuple; ``outputs.past_key_values`` would fail.
            output = (logits,) + tuple(outputs[1:])
            return (total_loss,) + output if total_loss is not None else output

        return CausalLMOutputWithPast(
            loss=total_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _text_loss(self, logits, labels):
        """Next-token cross-entropy: logits at position t are scored against labels at t + 1."""
        shift_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
        shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
        return nn.CrossEntropyLoss()(shift_logits, shift_labels)

    def _image_regression_loss(self, hidden_states, img_loss_indicator, target_image_embeds):
        """MSE between ``down_projector(hidden_states[img_loss_indicator])`` and the targets.

        When there are no image-target positions (or no indicator at all), a
        zero-valued loss is returned that still routes one hidden state
        through ``down_projector``. This keeps that module in the autograd
        graph, presumably so distributed training does not flag it as an
        unused parameter.
        """
        down_projector = self.get_model().down_projector
        mse = nn.MSELoss()
        has_targets = img_loss_indicator is not None and img_loss_indicator.sum() > 0
        if not has_targets:
            placeholder = down_projector(hidden_states[:, :1, :])
            return mse(placeholder, placeholder.detach().clone())
        predicted = down_projector(hidden_states[img_loss_indicator])
        return mse(predicted, target_image_embeds)

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Text generation, optionally conditioned on ``images``.

        With images, token ids and images are merged into ``inputs_embeds`` by
        ``self.prepare_inputs_labels_for_understanding`` (defined outside this
        file; it is unpacked here as a 7-tuple). Without images, ``inputs`` is
        embedded with the backbone's token embedding. Remaining ``kwargs`` go
        to the transformers ``generate``.

        Raises:
            NotImplementedError: if ``inputs_embeds`` is passed in ``kwargs``.
        """
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _img_indicator,
                _,
            ) = self.prepare_inputs_labels_for_understanding(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                image_sizes=image_sizes,
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    @staticmethod
    def _count_from_right(mask):
        """For each position along dim 1, count the True entries of ``mask`` at or after it."""
        return torch.flip(torch.cumsum(torch.flip(mask, dims=[1]), dim=1), dims=[1])

    @torch.no_grad()
    def generate_image(
        self,
        text: List[str],
        tokenizer: AutoTokenizer,
        image: Optional[torch.Tensor] = None,
    ):
        """Autoregressively produce ``n_query`` image-query embeddings per prompt.

        Each of the ``n_query`` steps appends one token (the image-start token
        first, then image tokens) to every prompt and re-encodes the whole
        batch.

        * Image tokens belonging to prompt images are overwritten with the
          projected vision features.
        * Image tokens generated in earlier steps are overwritten with
          ``gen_projector`` of the previous step's outputs.
        * The last-layer hidden states at the trailing ``step + 1``
          image-start/image positions are passed through ``down_projector``.

        Args:
            text: Prompts. Each ``DEFAULT_IMAGE_TOKEN`` is expanded into a full
                image placeholder when ``image`` is given.
            tokenizer: Tokenizer used to encode ``text`` on every step.
            image: Optional pixel tensor for the prompt images, passed to the
                vision tower. Assumes the tower yields one feature per
                placeholder image token (``n_query`` per image); TODO confirm.

        Returns:
            Tensor of shape ``(batch, n_query, C)`` after the final step, or
            None if ``n_query`` is 0.
        """
        vision_tower = self.get_vision_tower()
        mm_projector = self.get_mm_projector()
        gen_projector = self.get_gen_projector()
        device = self.get_model().device

        n_query = self.get_n_query()
        image_placeholder = DEFAULT_IM_START_TOKEN + n_query * DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN

        prompt_image_embeds = None
        if image is not None:
            # Previously referenced an undefined ``batch_images`` (NameError).
            prompt_image_embeds = vision_tower(image)
            _, _, c = prompt_image_embeds.shape
            prompt_image_embeds = mm_projector(prompt_image_embeds.contiguous().view(-1, c))
            text = [t.replace(DEFAULT_IMAGE_TOKEN, image_placeholder) for t in text]

        target_image_embeds = None
        for step in range(n_query):
            suffix = DEFAULT_IM_START_TOKEN if step == 0 else DEFAULT_IMAGE_TOKEN
            text = [f"{t}{suffix}" for t in text]

            inputs = tokenizer(text, padding="longest", return_tensors="pt")
            attention_mask = inputs.attention_mask.to(device)
            input_ids = inputs.input_ids.to(device)
            text_embeds = self.get_model().embed_tokens(input_ids)

            image_idx = input_ids == IMAGE_TOKEN_IDX
            images_to_right = self._count_from_right(image_idx)

            if prompt_image_embeds is not None:
                # After ``step`` generated image tokens, any image token with
                # more than ``step`` image tokens at/after it is a prompt one.
                prompt_idx = torch.logical_and(image_idx, images_to_right > step)
                text_embeds[prompt_idx] = prompt_image_embeds.to(text_embeds.device)

            if target_image_embeds is not None:
                # The ``step`` right-most image tokens were generated earlier.
                target_idx = torch.logical_and(
                    image_idx,
                    torch.logical_and(images_to_right > 0, images_to_right <= step),
                )
                projected = gen_projector(target_image_embeds)
                # Flatten (B, k, D) -> (B*k, D) to match the boolean-mask
                # selection order (row-major); identical for batch size 1.
                text_embeds[target_idx] = projected.reshape(-1, projected.shape[-1]).to(text_embeds.device)

            outputs = self.model(
                inputs_embeds=text_embeds,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=True,
            )

            # Read out the trailing ``step + 1`` positions among image-start and
            # image tokens: the appended start token plus the generated ones.
            query_idx = (input_ids == IMAGE_TOKEN_IDX) | (input_ids == DEFAULT_IM_START_TOKEN_IDX)
            queries_to_right = self._count_from_right(query_idx)
            readout_idx = torch.logical_and(
                query_idx,
                torch.logical_and(queries_to_right > 0, queries_to_right <= step + 1),
            )

            hidden_states = outputs.hidden_states[-1]
            target_image_embeds = hidden_states[readout_idx.to(hidden_states.device)]
            target_image_embeds = target_image_embeds.view(-1, target_image_embeds.shape[-1])
            target_image_embeds = self.get_model().down_projector(target_image_embeds)

            _, C = target_image_embeds.shape
            B = hidden_states.shape[0]
            target_image_embeds = target_image_embeds.view(B, -1, C)

        return target_image_embeds

    def prepare_and_encode_inputs(
        self,
        inputs: List[Union[str, Image.Image]],
        tokenizer: AutoTokenizer,
        do_classifier_free_guidance: bool = False,
    ):
        """Encode an interleaved list of strings and PIL images into conditioning embeddings.

        Strings are concatenated into one prompt and each image contributes a
        ``DEFAULT_IMAGE_TOKEN``. Image-only inputs are encoded with
        ``self.encode_images``. Anything containing text goes through
        :meth:`generate_image`.

        Args:
            inputs: Items in prompt order. Non-string items are assumed to be
                images accepted by the vision tower's ``image_processor``.
            tokenizer: Passed to :meth:`generate_image`.
            do_classifier_free_guidance: If True, append a negative
                (unconditional) embedding along dim 0. For image-only inputs
                this encodes all-zero images; otherwise it encodes an empty
                text prompt.

        Returns:
            The conditioning tensor. For text inputs it is optionally
            average-pooled in 2D when ``gen_pooling`` names a ``pool2d_<stride>``
            scheme without ``early``.
        """
        device = self.get_model().device
        dtype = self.get_model().dtype

        has_image, has_text = False, False
        prompt_parts, image_prompt = [], []
        img_processor = self.get_vision_tower().image_processor

        for x in inputs:
            if isinstance(x, str):
                has_text = True
                prompt_parts.append(x)
            else:
                has_image = True
                prompt_parts.append(DEFAULT_IMAGE_TOKEN)
                image_prompt.append(img_processor.preprocess(x, return_tensors='pt')['pixel_values'])
        text_prompt = "".join(prompt_parts)

        if image_prompt:
            image_prompt = torch.cat(image_prompt).type(dtype).to(device)
        else:
            image_prompt = None

        if has_image and not has_text:
            prompt = self.encode_images(image_prompt)
            if do_classifier_free_guidance:
                negative = self.encode_images(torch.zeros_like(image_prompt))
                prompt = torch.cat([prompt, negative], dim=0)
        else:
            prompt = self.generate_image(text=[text_prompt], image=image_prompt, tokenizer=tokenizer)
            if do_classifier_free_guidance:
                negative = self.generate_image(text=[""], tokenizer=tokenizer)
                prompt = torch.cat([prompt, negative], dim=0)

        gen_pooling = self.get_gen_pooling()
        n_query = self.get_n_query()
        if 'pool2d' in gen_pooling and has_text and 'early' not in gen_pooling:
            # Treat the n_query tokens as a sqrt(n) x sqrt(n) grid and average-pool it.
            num_img, _, c = prompt.shape
            stride = int(gen_pooling.split('_')[1])
            sqrt_n = int(n_query ** 0.5)
            prompt = prompt.permute(0, 2, 1).reshape(num_img, -1, sqrt_n, sqrt_n)
            prompt = F.avg_pool2d(prompt, kernel_size=(stride, stride), stride=stride)
            prompt = prompt.reshape(num_img, c, -1).permute(0, 2, 1)
        return prompt

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        """Build per-step generation inputs, carrying ``images``/``image_sizes`` through when given."""
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs['images'] = images
        if image_sizes is not None:
            inputs['image_sizes'] = image_sizes
        return inputs
|
|
|
# Register the custom "llava_llama" model type with the transformers Auto*
# factories: AutoConfig maps the type string to LlavaConfig, and
# AutoModelForCausalLM maps LlavaConfig to LlavaLlamaForCausalLM. Runs once at
# import time.
AutoConfig.register("llava_llama", LlavaConfig)

AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
|
|