# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModelForCausalLM, \
LlamaConfig, LlamaModel, LlamaForCausalLM, AutoTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
from llava.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_IDX, DEFAULT_IM_START_TOKEN_IDX, DEFAULT_IM_END_TOKEN_IDX
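# The three classes below follow the standard transformers extension pattern:
# LlavaConfig adds a custom model_type, LlavaLlamaModel mixes the multimodal
# LlavaMetaModel into LlamaModel, and LlavaLlamaForCausalLM adds the language
# modeling head plus the multimodal training/generation logic. The classes are
# registered with the Auto* factories at the bottom of this file.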
class LlavaConfig(LlamaConfig):
model_type = "llava_llama"
class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
config_class = LlavaConfig
def __init__(self, config: LlamaConfig):
super(LlavaLlamaModel, self).__init__(config)
class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
config_class = LlavaConfig
def __init__(self, config):
super(LlamaForCausalLM, self).__init__(config)
self.model = LlavaLlamaModel(config)
self.pretraining_tp = config.pretraining_tp
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.model
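    # forward() mirrors LlamaForCausalLM.forward, but when image inputs are
    # present it first calls prepare_inputs_labels_for_multimodal (defined in
    # LlavaMetaForCausalLM) to build inputs_embeds and obtain the indicator
    # masks (img_loss_indicator / img_indicator) and the target image
    # embeddings used for the auxiliary MSE image loss below.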
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
ids: Optional[list] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
gen_image: Optional[torch.FloatTensor] = None,
und_image: Optional[torch.FloatTensor] = None,
image_sizes: Optional[List[List[int]]] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# print(f"gen_image {gen_image}")
# print(f"und_image {und_image}")
if inputs_embeds is None:
(
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels,
img_loss_indicator,
img_indicator,
target_image_embeds
) = self.prepare_inputs_labels_for_multimodal(
input_ids,
position_ids,
attention_mask,
past_key_values,
labels,
gen_image,
und_image,
image_sizes
)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
# img_indicator=img_indicator,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
# cache_position=cache_position,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
total_loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = torch.nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
# compute image loss
# target_img_embeds = torch.clone(inputs_embeds.detach())[:,1:,:] # get target image emb
            img_loss_funct = torch.nn.MSELoss()
            # Down-project hidden states at image positions; if the batch has no
            # image targets, fall back to a dummy slice so the projector still runs.
            img_hidden_states = self.get_model().down_projector(
                hidden_states[img_loss_indicator] if img_loss_indicator.sum() > 0 else hidden_states[:, :1, :]
            )
            img_loss = 0.0
            if img_loss_indicator.sum() <= 0:
                # No image targets: a zero-valued MSE against a detached copy keeps the
                # image-loss branch in the graph without contributing real gradients.
                img_loss = img_loss_funct(img_hidden_states, torch.clone(img_hidden_states.detach()))
            else:  # there are images in the output
# all, conv2_3, conv2_9, seq_3, seq_9, seq_27
n_query = self.get_n_query()
gen_pooling = self.get_gen_pooling()
if gen_pooling == 'all':
# img_loss = img_loss_funct(img_hidden_states, target_image_embeds)
pass
# if we use early pooling then we don't pool again
# elif 'seq' in gen_pooling and not 'early' in gen_pooling:
# step_size = int(gen_pooling.split('_')[1])
# num_step = img_hidden_states.shape[0] // step_size
# select_idx = torch.range(1, num_step) * step_size - 1
# select_idx = select_idx.to(img_hidden_states.device, dtype = torch.long)
# img_hidden_states = torch.index_select(img_hidden_states, 0, select_idx)
# target_image_embeds = torch.index_select(target_image_embeds, 0, select_idx)
# elif 'pool2d' in gen_pooling and not 'early' in gen_pooling:
# stride = int(gen_pooling.split('_')[1])
# num_img = img_hidden_states.shape[0] // n_query
# # print(f"img_hidden_states.shape {img_hidden_states.shape}, n_query {n_query}")
# # print(f"img_loss_indicator, {img_loss_indicator}")
# sqrt_n = int(n_query**0.5)
# img_hidden_states = img_hidden_states.reshape(num_img, n_query, -1)
# target_image_embeds = target_image_embeds.reshape(num_img, n_query, -1)
# channel = img_hidden_states.shape[-1]
# img_hidden_states = img_hidden_states.permute(0, 2, 1).view(num_img, -1, sqrt_n, sqrt_n)
# target_image_embeds = target_image_embeds.permute(0, 2, 1).view(num_img, -1, sqrt_n, sqrt_n)
# img_hidden_states = F.avg_pool2d(img_hidden_states, kernel_size=(stride, stride), stride=stride)
# target_image_embeds = F.avg_pool2d(target_image_embeds, kernel_size=(stride, stride), stride=stride)
# img_hidden_states = img_hidden_states.reshape(num_img, channel, -1).permute(0,2,1)
# target_image_embeds = target_image_embeds.reshape(num_img, channel, -1).permute(0,2,1)
img_loss = img_loss_funct(img_hidden_states, target_image_embeds)
print(f"img loss {img_loss}, text loss {loss}")
total_loss = loss + img_loss
return CausalLMOutputWithPast(
loss=total_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
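    # generate() covers the understanding path: image features are merged into
    # inputs_embeds up front via prepare_inputs_labels_for_understanding, then
    # the standard transformers generation loop runs on embeddings only
    # (inputs_embeds is passed instead of raw input_ids).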
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
images: Optional[torch.Tensor] = None,
image_sizes: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
position_ids = kwargs.pop("position_ids", None)
attention_mask = kwargs.pop("attention_mask", None)
if "inputs_embeds" in kwargs:
raise NotImplementedError("`inputs_embeds` is not supported")
if images is not None:
(
inputs,
position_ids,
attention_mask,
_,
inputs_embeds,
img_indicator,
_
) = self.prepare_inputs_labels_for_understanding(
inputs,
position_ids,
attention_mask,
None,
None,
images,
image_sizes=image_sizes
)
else:
inputs_embeds = self.get_model().embed_tokens(inputs)
return super().generate(
position_ids=position_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
**kwargs
)
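    # generate_image() produces continuous image embeddings autoregressively:
    # each of the N_QUERY iterations appends one more image token to the text,
    # re-runs the backbone, reads the last hidden states at the image-token
    # positions, and down-projects them into the image-embedding space.
    # Conditioning image features go through mm_projector; embeddings generated
    # in earlier iterations are fed back in through gen_projector.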
@torch.no_grad()
def generate_image(
self,
text: List[str],
tokenizer: AutoTokenizer,
image: Optional[torch.Tensor] = None,
# placeholder: str = DEFAULT_IMG_PLACEHOLDER,
):
vision_tower = self.get_vision_tower()
mm_projector = self.get_mm_projector()
gen_projector = self.get_gen_projector()
N_QUERY = self.get_n_query()
image_placeholder = DEFAULT_IM_START_TOKEN + N_QUERY*DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
if image is not None:
# image: [Batch, 3, 448, 448]
            prompt_image_embeds = vision_tower(image)
num_img, _, c = prompt_image_embeds.shape # [batch, 576, 1024]
all_image_embeds = torch.clone(prompt_image_embeds).detach()
prompt_image_embeds = prompt_image_embeds.contiguous().view(-1, c)
prompt_image_embeds = mm_projector(prompt_image_embeds)
# prompt_image_embeds = prompt_image_embeds.view(-1, self.config.hidden_size)
text = [t.replace(DEFAULT_IMAGE_TOKEN, image_placeholder) for t in text]
target_image_embeds = None
for num_img_token in range(N_QUERY):
if num_img_token == 0:
text = [f"{t}{DEFAULT_IM_START_TOKEN}" for t in text]
else:
text = [f"{t}{DEFAULT_IMAGE_TOKEN}" for t in text]
inputs = tokenizer(text, padding="longest", return_tensors="pt")
device = self.get_model().device
attention_mask = inputs.attention_mask.to(device)
input_ids = inputs.input_ids.to(device) # B x N
text_embeds = self.get_model().embed_tokens(input_ids)
image_idx = (input_ids == IMAGE_TOKEN_IDX)
img_indicator = torch.clone(image_idx)
img_indicator = torch.cat([img_indicator[:, 1:], img_indicator[:, :1]], dim=1)
img_indicator[:,-1] = True
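            # Reversed cumulative sum over image-token positions: conditioning
            # image tokens (earlier in the sequence) get larger counts than the
            # num_img_token image tokens generated so far, so the two groups can
            # be addressed separately below.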
cumsum_idx = torch.flip(torch.cumsum(
torch.flip(image_idx, dims=[1]), dim=1), dims=[1])
if image is not None:
prompt_idx = torch.logical_and(
image_idx, cumsum_idx > num_img_token)
text_embeds[prompt_idx] = prompt_image_embeds.to(
text_embeds.device)
if target_image_embeds is not None:
target_idx = torch.logical_and(image_idx, torch.logical_and(
cumsum_idx > 0, cumsum_idx <= num_img_token))
text_embeds[target_idx] = gen_projector(
target_image_embeds).to(text_embeds.device)
outputs = self.model(
inputs_embeds=text_embeds,
# img_indicator=img_indicator,
# concept_indicator=concept_indicator if self.use_concept_token else None,
attention_mask=attention_mask,
output_hidden_states=True,
return_dict=True,
)
            image_idx = (input_ids == IMAGE_TOKEN_IDX) | (input_ids == DEFAULT_IM_START_TOKEN_IDX)
cumsum_idx = torch.flip(torch.cumsum(
torch.flip(image_idx, dims=[1]), dim=1), dims=[1])
target_idx = torch.logical_and(image_idx, torch.logical_and(
cumsum_idx > 0, cumsum_idx <= num_img_token+1))
hidden_states = outputs.hidden_states[-1]
target_image_embeds = hidden_states[target_idx.to(
hidden_states.device)]
target_image_embeds = target_image_embeds.view(
-1, target_image_embeds.shape[-1])
target_image_embeds = self.get_model().down_projector(target_image_embeds)
_, C = target_image_embeds.shape
B = hidden_states.shape[0]
target_image_embeds = target_image_embeds.view(B, -1, C)
return target_image_embeds
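    # prepare_and_encode_inputs() takes an interleaved list of strings and PIL
    # images, folds them into a single prompt (each image becomes a
    # DEFAULT_IMAGE_TOKEN), and returns the conditioning embeddings for image
    # generation. With classifier-free guidance, an unconditional prompt
    # (zeroed image or empty text) is encoded and concatenated along the batch
    # dimension.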
def prepare_and_encode_inputs(
self,
inputs: List[str | Image.Image],
tokenizer: AutoTokenizer,
do_classifier_free_guidance: bool = False,
):
device = self.get_model().device
dtype = self.get_model().dtype
has_image, has_text = False, False
text_prompt, image_prompt = "", []
img_processor = self.get_vision_tower().image_processor
negative_prompt = {}
for x in inputs:
if isinstance(x, str):
has_text = True
text_prompt += x
else:
has_image = True
text_prompt += DEFAULT_IMAGE_TOKEN
image_prompt.append(img_processor.preprocess(x, return_tensors='pt')['pixel_values'])
if len(image_prompt) == 0:
image_prompt = None
else:
image_prompt = torch.cat(image_prompt)
image_prompt = image_prompt.type(dtype).to(device)
if has_image and not has_text:
prompt = self.encode_images(image_prompt)
if do_classifier_free_guidance:
key = "[NULL_IMAGE]"
if key not in negative_prompt:
negative_image = torch.zeros_like(image_prompt)
negative_prompt[key] = self.encode_images(negative_image)
prompt = torch.cat([prompt, negative_prompt[key]], dim=0)
else:
prompt = self.generate_image(text=[text_prompt], image=image_prompt, tokenizer=tokenizer)
if do_classifier_free_guidance:
key = ""
if key not in negative_prompt:
negative_prompt[key] = self.generate_image(text=[""], tokenizer=tokenizer)
prompt = torch.cat([prompt, negative_prompt[key]], dim=0)
gen_pooling = self.get_gen_pooling()
n_query = self.get_n_query()
num_img, _, c = prompt.shape
        # Late 2D average pooling over the image-token grid (skipped when the
        # projector already pooled "early" or when the prompt contains no text).
        if 'pool2d' in gen_pooling and has_text and 'early' not in gen_pooling:
stride = int(gen_pooling.split('_')[1])
sqrt_n = int(n_query**0.5)
prompt = prompt.permute(0, 2, 1).reshape(num_img, -1, sqrt_n, sqrt_n)
prompt = F.avg_pool2d(prompt, kernel_size=(stride, stride), stride=stride)
prompt = prompt.reshape(num_img, c, -1).permute(0,2,1)
return prompt
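    # prepare_inputs_for_generation() only extends the parent implementation by
    # threading `images` / `image_sizes` through to forward() during generation.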
def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
inputs_embeds=None, **kwargs):
images = kwargs.pop("images", None)
image_sizes = kwargs.pop("image_sizes", None)
inputs = super().prepare_inputs_for_generation(
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
)
if images is not None:
inputs['images'] = images
if image_sizes is not None:
inputs['image_sizes'] = image_sizes
return inputs
AutoConfig.register("llava_llama", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
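# Minimal usage sketch (the checkpoint path below is a placeholder, not a real
# repository): once this module has been imported and the registrations above
# have run, the model can be loaded through the Auto* factories, e.g.
#
#   config = AutoConfig.from_pretrained("path/to/llava-llama-checkpoint")
#   model = AutoModelForCausalLM.from_pretrained("path/to/llava-llama-checkpoint")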