#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.


from typing import List, Optional, Tuple, Union

from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoConfig, AutoModelForCausalLM, \
                         LlamaConfig, LlamaModel, LlamaForCausalLM, AutoTokenizer

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput

from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
from llava.constants import (
    IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN,
    IMAGE_TOKEN_IDX, DEFAULT_IM_START_TOKEN_IDX, DEFAULT_IM_END_TOKEN_IDX,
)
import pdb


class LlavaConfig(LlamaConfig):
    """Llama config registered under the ``llava_llama`` model type."""
    model_type = "llava_llama"


class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    """Llama backbone combined with the LLaVA multimodal mixin."""
    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        super(LlavaLlamaModel, self).__init__(config)


class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    """Causal LM head on top of ``LlavaLlamaModel``.

    Besides next-token prediction, hidden states at selected positions are
    passed through ``down_projector`` and regressed (MSE) onto image target
    embeddings; ``generate_image`` uses the same projector to produce image
    query embeddings autoregressively.
    """
    config_class = LlavaConfig

    def __init__(self, config):
        # Deliberately skip LlamaForCausalLM.__init__ so that it does not build
        # a plain LlamaModel; the multimodal model is installed instead.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        ids: Optional[list] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        gen_image: Optional[torch.FloatTensor] = None,
        und_image: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the model and, when ``labels`` are given, compute the training loss.

        If ``inputs_embeds`` is None, ``prepare_inputs_labels_for_multimodal``
        builds the embeddings from ``input_ids`` together with ``gen_image``,
        ``und_image`` and ``image_sizes``. It also returns the indicator of
        positions used for the image regression loss and the matching
        ``target_image_embeds``.

        Args:
            ids: Accepted for interface compatibility; not used.
            cache_position: Accepted for interface compatibility; not forwarded.

        Returns:
            ``CausalLMOutputWithPast`` whose ``loss`` is the shifted next-token
            cross-entropy plus, when an image indicator is available, the MSE
            image loss. ``loss`` is None when ``labels`` is None.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Only produced by the multimodal preprocessing below. They stay None when
        # the caller supplies ``inputs_embeds`` directly; previously that case
        # raised NameError as soon as ``labels`` was also given.
        img_loss_indicator = None
        target_image_embeds = None

        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
                img_loss_indicator,
                img_indicator,
                target_image_embeds
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                gen_image,
                und_image,
                image_sizes
            )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        total_loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = torch.nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            total_loss = loss

            if img_loss_indicator is not None:
                img_loss = self._image_regression_loss(
                    hidden_states, img_loss_indicator, target_image_embeds
                )
                # NOTE(review): debug print on every training step; consider logging.
                print(f"img loss {img_loss}, text loss {loss}")
                total_loss = loss + img_loss

        return CausalLMOutputWithPast(
            loss=total_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _image_regression_loss(self, hidden_states, img_loss_indicator, target_image_embeds):
        """Return the MSE between down-projected flagged hidden states and targets.

        When ``img_loss_indicator`` flags no position, the loss is computed
        between ``down_projector``'s output on the first token of each sequence
        and a detached copy of that output. The value is zero but still depends
        on ``down_projector``. This is presumably so the projector always takes
        part in the backward pass (e.g. for DDP unused-parameter checks); TODO
        confirm.

        Any late ("non-early") seq/pool2d pooling of the targets is disabled here:
        flagged hidden states are compared with ``target_image_embeds`` as-is for
        every ``gen_pooling`` mode.
        """
        img_loss_funct = torch.nn.MSELoss()
        down_projector = self.get_model().down_projector
        if img_loss_indicator.sum() <= 0:
            img_hidden_states = down_projector(hidden_states[:, :1, :])
            return img_loss_funct(img_hidden_states, torch.clone(img_hidden_states.detach()))
        img_hidden_states = down_projector(hidden_states[img_loss_indicator])
        return img_loss_funct(img_hidden_states, target_image_embeds)

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Text generation, optionally conditioned on ``images``.

        Embeddings are always prepared here and passed to the parent
        ``generate`` as ``inputs_embeds``. They come from
        ``prepare_inputs_labels_for_understanding`` when images are given, and
        from ``embed_tokens`` otherwise.

        Raises:
            NotImplementedError: if the caller passes ``inputs_embeds``.
        """
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                img_indicator,
                _
            ) = self.prepare_inputs_labels_for_understanding(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                image_sizes=image_sizes
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    @torch.no_grad()
    def generate_image(
        self,
        text: List[str],
        tokenizer: AutoTokenizer,
        image: Optional[torch.Tensor] = None,
    ):
        """Autoregressively produce image-query embeddings for each prompt in ``text``.

        Each ``DEFAULT_IMAGE_TOKEN`` in the prompts is first expanded to
        ``IM_START + n_query * IMAGE + IM_END``. If ``image`` is given, its
        vision-tower features pass through ``mm_projector`` and fill those
        prompt image slots.

        Then, for step ``k`` in ``range(n_query)``:
          * append ``IM_START`` (k == 0) or one ``IMAGE`` token;
          * re-encode the whole sequence. The ``k`` previously generated
            embeddings, mapped through ``gen_projector``, are written into the
            trailing image-token slots;
          * take the last-layer hidden states at the trailing ``IM_START`` plus
            ``k`` image tokens. Pass them through ``down_projector`` to get the
            new ``target_image_embeds``.

        Args:
            text: Prompts, tokenized together with ``padding="longest"``.
            tokenizer: Tokenizer used for the prompts.
            image: Optional prompt images passed straight to the vision tower.
                Presumably (batch, 3, H, W) pixel values; TODO confirm.

        Returns:
            Tensor viewed as ``(B, -1, C)``, where ``B`` is the batch size of
            the hidden states and ``C`` is the ``down_projector`` output width.
            Expected to hold ``n_query`` rows per prompt.
        """
        vision_tower = self.get_vision_tower()
        mm_projector = self.get_mm_projector()
        gen_projector = self.get_gen_projector()
        N_QUERY = self.get_n_query()
        device = self.get_model().device

        image_placeholder = DEFAULT_IM_START_TOKEN + N_QUERY * DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN

        if image is not None:
            # Previously this referenced an undefined name (`batch_images`), so
            # every call with an image raised NameError.
            prompt_image_embeds = vision_tower(image)
            # The 3-way unpack requires a 3-D vision-tower output.
            _, _, c = prompt_image_embeds.shape
            prompt_image_embeds = prompt_image_embeds.contiguous().view(-1, c)
            # NOTE(review): assumes the vision tower yields N_QUERY tokens per
            # image so the rows match the placeholder slots below; TODO confirm.
            prompt_image_embeds = mm_projector(prompt_image_embeds)

        text = [t.replace(DEFAULT_IMAGE_TOKEN, image_placeholder) for t in text]

        target_image_embeds = None
        for num_img_token in range(N_QUERY):
            if num_img_token == 0:
                text = [f"{t}{DEFAULT_IM_START_TOKEN}" for t in text]
            else:
                text = [f"{t}{DEFAULT_IMAGE_TOKEN}" for t in text]

            inputs = tokenizer(text, padding="longest", return_tensors="pt")
            attention_mask = inputs.attention_mask.to(device)
            input_ids = inputs.input_ids.to(device)  # B x N
            text_embeds = self.get_model().embed_tokens(input_ids)

            image_idx = (input_ids == IMAGE_TOKEN_IDX)
            # Reverse cumulative count: image tokens at or after each position.
            cumsum_idx = torch.flip(torch.cumsum(
                torch.flip(image_idx, dims=[1]), dim=1), dims=[1])
            if image is not None:
                # Image tokens followed by more than `num_img_token` image tokens
                # belong to the prompt, not to the generated tail.
                prompt_idx = torch.logical_and(image_idx, cumsum_idx > num_img_token)
                text_embeds[prompt_idx] = prompt_image_embeds.to(text_embeds.device)
            if target_image_embeds is not None:
                # The last `num_img_token` image tokens are the generated ones.
                target_idx = torch.logical_and(image_idx, torch.logical_and(
                    cumsum_idx > 0, cumsum_idx <= num_img_token))
                text_embeds[target_idx] = gen_projector(
                    target_image_embeds).to(text_embeds.device)

            outputs = self.model(
                inputs_embeds=text_embeds,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=True,
            )

            # Count IM_START tokens too, so the trailing IM_START plus the
            # `num_img_token` generated image tokens are selected.
            image_idx = (input_ids == IMAGE_TOKEN_IDX) + (input_ids == DEFAULT_IM_START_TOKEN_IDX)
            cumsum_idx = torch.flip(torch.cumsum(
                torch.flip(image_idx, dims=[1]), dim=1), dims=[1])
            target_idx = torch.logical_and(image_idx, torch.logical_and(
                cumsum_idx > 0, cumsum_idx <= num_img_token + 1))

            hidden_states = outputs.hidden_states[-1]
            target_image_embeds = hidden_states[target_idx.to(hidden_states.device)]
            target_image_embeds = target_image_embeds.view(-1, target_image_embeds.shape[-1])
            target_image_embeds = self.get_model().down_projector(target_image_embeds)
            _, C = target_image_embeds.shape
            B = hidden_states.shape[0]
            target_image_embeds = target_image_embeds.view(B, -1, C)

        return target_image_embeds

    def prepare_and_encode_inputs(
        self,
        inputs: List[str | Image.Image],
        tokenizer: AutoTokenizer,
        do_classifier_free_guidance: bool = False,
    ):
        """Encode an interleaved list of strings and images into conditioning embeddings.

        Strings are concatenated into one text prompt. Each non-string item is
        preprocessed by the vision tower's ``image_processor`` and represented
        in the prompt by ``DEFAULT_IMAGE_TOKEN``.

        * Images only: ``encode_images`` on the stacked pixel values. With
          classifier-free guidance, the encoding of an all-zero image batch is
          concatenated along dim 0.
        * Any text: ``generate_image`` on the single combined prompt. With
          classifier-free guidance, ``generate_image`` on an empty prompt is
          concatenated along dim 0.

        When ``gen_pooling`` contains ``'pool2d'`` (and not ``'early'``) and
        there is text, the result is average-pooled over a sqrt(n_query) x
        sqrt(n_query) grid, with the stride parsed from the text after the
        first ``'_'`` in ``gen_pooling``.
        """
        device = self.get_model().device
        dtype = self.get_model().dtype

        has_image, has_text = False, False
        text_prompt, image_prompt = "", []
        img_processor = self.get_vision_tower().image_processor

        for x in inputs:
            if isinstance(x, str):
                has_text = True
                text_prompt += x
            else:
                has_image = True
                text_prompt += DEFAULT_IMAGE_TOKEN
                image_prompt.append(img_processor.preprocess(x, return_tensors='pt')['pixel_values'])

        if len(image_prompt) == 0:
            image_prompt = None
        else:
            image_prompt = torch.cat(image_prompt)
            image_prompt = image_prompt.type(dtype).to(device)

        if has_image and not has_text:
            prompt = self.encode_images(image_prompt)
            if do_classifier_free_guidance:
                negative_image = torch.zeros_like(image_prompt)
                prompt = torch.cat([prompt, self.encode_images(negative_image)], dim=0)
        else:
            prompt = self.generate_image(text=[text_prompt], image=image_prompt, tokenizer=tokenizer)
            if do_classifier_free_guidance:
                negative = self.generate_image(text=[""], tokenizer=tokenizer)
                prompt = torch.cat([prompt, negative], dim=0)

        gen_pooling = self.get_gen_pooling()
        n_query = self.get_n_query()
        num_img, _, c = prompt.shape
        if 'pool2d' in gen_pooling and has_text and 'early' not in gen_pooling:
            stride = int(gen_pooling.split('_')[1])
            sqrt_n = int(n_query ** 0.5)
            prompt = prompt.permute(0, 2, 1).reshape(num_img, -1, sqrt_n, sqrt_n)
            prompt = F.avg_pool2d(prompt, kernel_size=(stride, stride), stride=stride)
            prompt = prompt.reshape(num_img, c, -1).permute(0, 2, 1)
        return prompt

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        """Forward ``images``/``image_sizes`` kwargs alongside the parent's generation inputs."""
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs['images'] = images
        if image_sizes is not None:
            inputs['image_sizes'] = image_sizes
        return inputs


AutoConfig.register("llava_llama", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)