#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.


from abc import ABC, abstractmethod

import torch
import torch.nn as nn
import torch.nn.functional as F

from .multimodal_encoder.builder import build_vision_tower, build_gen_vision_tower
from .multimodal_projector.builder import build_vision_projector, build_down_projector, build_gen_vision_projector

from llava.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_IDX, DEFAULT_IM_START_TOKEN_IDX, DEFAULT_IM_END_TOKEN_IDX


def _select_weights(weights, keyword):
    """Extract one sub-module's entries from a flat checkpoint state dict.

    Keeps every entry whose key contains ``keyword`` and re-keys it to the part
    after ``keyword + '.'`` (e.g. ``model.mm_projector.0.weight`` -> ``0.weight``).
    A key that contains ``keyword`` but not ``keyword + '.'`` raises IndexError.
    """
    return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}


class LlavaMetaModel:
    """Mixin that attaches understanding and generation vision modules to a backbone.

    Relies on the co-inherited backbone for ``__init__(config)``, ``self.dtype``
    and ``self.config``.
    """

    def __init__(self, config):
        super(LlavaMetaModel, self).__init__(config)

        if hasattr(config, "mm_vision_tower"):
            # delay_load=True presumably defers loading the tower weights until
            # initialize_vision_modules() calls load_model() — confirm in builder.
            self.vision_tower = build_vision_tower(config, delay_load=True)
            self.mm_projector = build_vision_projector(config)
            self.down_projector = build_down_projector(config)

            if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
                self.image_newline = nn.Parameter(
                    torch.empty(config.hidden_size, dtype=self.dtype)
                )

        if hasattr(config, "gen_vision_tower"):
            self.gen_vision_tower = build_gen_vision_tower(config, delay_load=True)
            self.gen_projector = build_gen_vision_projector(config)

            # NOTE: overwrites image_newline if the branch above already created it.
            if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
                self.image_newline = nn.Parameter(
                    torch.empty(config.hidden_size, dtype=self.dtype)
                )

    def get_vision_tower(self):
        """Return the understanding vision tower, unwrapping the FSDP one-element list."""
        vision_tower = getattr(self, 'vision_tower', None)
        if isinstance(vision_tower, list):
            vision_tower = vision_tower[0]
        return vision_tower

    def get_gen_vision_tower(self):
        """Return the generation vision tower, unwrapping the FSDP one-element list."""
        gen_vision_tower = getattr(self, 'gen_vision_tower', None)
        if isinstance(gen_vision_tower, list):
            gen_vision_tower = gen_vision_tower[0]
        return gen_vision_tower

    def _init_tower(self, attr_name, current, builder, model_args, use_fsdp):
        """Build the tower stored at ``self.<attr_name>`` if absent, else load the existing one.

        Args:
            attr_name: attribute name holding the tower ('vision_tower' / 'gen_vision_tower').
            current: the tower as returned by the matching getter (None if absent).
            builder: factory called as ``builder(model_args)`` for a new tower.
            model_args: forwarded to ``builder``.
            use_fsdp: when True the tower is stored wrapped in a one-element list.

        Returns:
            The unwrapped tower module.
        """
        if current is None:
            tower = builder(model_args)
            setattr(self, attr_name, [tower] if use_fsdp else tower)
        else:
            stored = getattr(self, attr_name)
            tower = stored[0] if use_fsdp else stored
            tower.load_model()
        return tower

    def _init_image_newline(self):
        """Randomly initialise ``image_newline`` with std 1/sqrt(hidden_size)."""
        embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
        self.image_newline = nn.Parameter(
            torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
        )

    def initialize_vision_modules(self, model_args, fsdp=None):
        """Build/load both vision towers and all projectors; record settings in ``self.config``.

        Args:
            model_args: namespace read for ``vision_tower``, ``gen_vision_tower``,
                ``mm_vision_select_layer``, ``mm_vision_select_feature``,
                ``pretrain_mm_mlp_adapter``, ``pretrain_gen_mlp_adapter``,
                ``mm_patch_merge_type``, ``n_query``, ``gen_pooling`` and, optionally,
                ``vision_tower_pretrained``, ``mm_projector_type``, ``gen_projector_type``.
            fsdp: when non-empty, towers are stored wrapped in a one-element list.
        """
        vision_tower_name = model_args.vision_tower
        gen_vision_tower_name = model_args.gen_vision_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
        pretrain_gen_mlp_adapter = model_args.pretrain_gen_mlp_adapter
        mm_patch_merge_type = model_args.mm_patch_merge_type

        self.config.mm_vision_tower = vision_tower_name
        self.config.gen_vision_tower = gen_vision_tower_name
        self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")

        use_fsdp = fsdp is not None and len(fsdp) > 0
        vision_tower = self._init_tower(
            'vision_tower', self.get_vision_tower(), build_vision_tower, model_args, use_fsdp)
        gen_vision_tower = self._init_tower(
            'gen_vision_tower', self.get_gen_vision_tower(), build_gen_vision_tower, model_args, use_fsdp)

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
        self.config.gen_projector_type = getattr(model_args, 'gen_projector_type', 'linear')
        self.config.mm_hidden_size = vision_tower.hidden_size
        self.config.gen_hidden_size = gen_vision_tower.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature
        self.config.mm_patch_merge_type = mm_patch_merge_type
        self.config.n_query = model_args.n_query
        self.config.gen_pooling = model_args.gen_pooling

        if getattr(self, 'mm_projector', None) is None:
            print("random initiation the mm_project !!!")
            self.mm_projector = build_vision_projector(self.config)
            if 'unpad' in mm_patch_merge_type:
                self._init_image_newline()
        else:
            # In case it is frozen by LoRA
            for p in self.mm_projector.parameters():
                p.requires_grad = True

        if getattr(self, 'gen_projector', None) is None:
            print("random initiation the gen_projector !!!")
            self.gen_projector = build_gen_vision_projector(self.config)
            if 'unpad' in mm_patch_merge_type:
                self._init_image_newline()
        else:
            # In case it is frozen by LoRA
            for p in self.gen_projector.parameters():
                p.requires_grad = True

        if getattr(self, 'down_projector', None) is None:
            print("random initiation the down_projector !!!")
            self.down_projector = build_down_projector(self.config)
        else:
            # In case it is frozen by LoRA
            for p in self.down_projector.parameters():
                p.requires_grad = True

        # NOTE(review): torch.load unpickles the file; only point these at trusted checkpoints.
        if pretrain_mm_mlp_adapter is not None:
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
            self.mm_projector.load_state_dict(_select_weights(mm_projector_weights, 'mm_projector'))

        if pretrain_gen_mlp_adapter is not None:
            gen_projector_weights = torch.load(pretrain_gen_mlp_adapter, map_location='cpu')
            # NOTE(review): keys are filtered on 'mm_projector', not 'gen_projector' —
            # confirm this matches how the gen adapter checkpoint is saved.
            self.gen_projector.load_state_dict(_select_weights(gen_projector_weights, 'mm_projector'))


def unpad_image(tensor, original_size):
    """
    Unpads a PyTorch tensor of a padded and resized image.

    Args:
    tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
    original_size (tuple): The original size of PIL image (width, height).

    Returns:
    torch.Tensor: The unpadded image tensor.
    """
    original_width, original_height = original_size
    current_height, current_width = tensor.shape[1:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        # Wider original: padding was added top/bottom.
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        unpadded_tensor = tensor[:, padding:current_height - padding, :]
    else:
        # Taller (or equal) original: padding was added left/right.
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        unpadded_tensor = tensor[:, :, padding:current_width - padding]

    return unpadded_tensor


class LlavaMetaForCausalLM(ABC):
    """Causal-LM mixin that splices vision embeddings into the token embedding sequence."""

    @abstractmethod
    def get_model(self):
        pass

    def get_vision_tower(self):
        return self.get_model().get_vision_tower()

    def get_gen_vision_tower(self):
        return self.get_model().get_gen_vision_tower()

    def encode_images(self, images):
        """Run the understanding tower on ``images``; optionally 2-D average-pool the tokens.

        When ``gen_pooling`` contains 'pool2d', its trailing ``_<stride>`` suffix is the
        pooling stride, and the token grid is assumed square with side sqrt(n_query)
        (729 when 'early' is in ``gen_pooling``).
        """
        device = self.get_vision_tower().device
        images = images.to(device)
        image_features = self.get_model().get_vision_tower()(images)
        num_img, _, c = image_features.shape
        gen_pooling = self.get_gen_pooling()
        n_query = self.get_n_query() if 'early' not in gen_pooling else 729
        if 'pool2d' in gen_pooling:
            stride = int(gen_pooling.split('_')[-1])
            sqrt_n = int(n_query ** 0.5)
            image_features = image_features.permute(0, 2, 1).view(num_img, -1, sqrt_n, sqrt_n)
            image_features = F.avg_pool2d(image_features, kernel_size=(stride, stride), stride=stride)
            image_features = image_features.reshape(num_img, c, -1).permute(0, 2, 1)
        return image_features

    def get_mm_projector(self):
        return self.get_model().mm_projector

    def get_gen_projector(self):
        return self.get_model().gen_projector

    def get_n_query(self):
        return self.get_model().config.n_query

    def get_gen_pooling(self):
        return self.get_model().config.gen_pooling

    def pool_img(self, image_features):
        """Average-pool a (num_img, n, c) token grid with the stride parsed from ``gen_pooling``.

        The n tokens are treated as a square sqrt(n) x sqrt(n) grid.
        Returns a contiguous (num_img, n', c) tensor.
        """
        num_img, n, c = image_features.shape
        gen_pooling = self.get_gen_pooling()
        stride = int(gen_pooling.split('_')[-1])
        sqrt_n = int(n ** 0.5)
        image_features = image_features.permute(0, 2, 1).view(num_img, c, sqrt_n, sqrt_n)
        image_features = F.avg_pool2d(image_features, kernel_size=(stride, stride), stride=stride)
        image_features = image_features.view(num_img, c, -1).permute(0, 2, 1).contiguous()
        return image_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels,
        gen_images, und_images, image_sizes=None
    ):
        """Replace image placeholder tokens with projected vision embeddings.

        Image tokens (``IMAGE_TOKEN_IDX``) with a supervised label (``!= -100``) receive
        generation-branch embeddings. Image tokens in the prompt (label ``== -100``)
        receive understanding-branch embeddings.

        Returns:
            If there are no images, or ``input_ids`` has length 1 on dim 1:
            ``(input_ids, position_ids, attention_mask, past_key_values, None, labels,
            None, None, None)``.

            Otherwise:
            ``(None, position_ids, attention_mask, past_key_values, inputs_embeds, labels,
            img_loss_indicator, img_indicator, target_image_embeds)``.

        Side effect:
            ``labels`` is modified in place — every image-token position is set to -100.
        """
        vision_tower = self.get_vision_tower()
        mm_projector = self.get_mm_projector()
        gen_vision_tower = self.get_gen_vision_tower()
        gen_projector = self.get_gen_projector()

        if (gen_images is None and und_images is None) or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels, None, None, None

        if gen_images is not None:
            prompt_image_embeds = gen_vision_tower(gen_images)  # TODO: check dimension
            if 'early' in self.get_gen_pooling():
                prompt_image_embeds = self.pool_img(prompt_image_embeds)
            num_img, _, c = prompt_image_embeds.shape  # original note: [batch, 729, 1152]
            prompt_image_embeds = prompt_image_embeds.contiguous().view(-1, c)
            # Targets are the detached tower features before projection.
            target_image_embeds = torch.clone(prompt_image_embeds).detach()
            prompt_image_embeds = gen_projector(prompt_image_embeds)
        else:
            target_image_embeds = None
            # No generation images: push a zero dummy batch through the generation
            # branch, scaled by 1e-9. Presumably this keeps its parameters in the graph
            # (e.g. for DDP) without affecting values — confirm.
            num_img = und_images.shape[0]
            dummy = torch.zeros(num_img, 3, 448, 448, dtype=und_images.dtype, device=und_images.device)  # TODO
            temp = gen_vision_tower(dummy)[:, :729, :]
            num_img, _, c = temp.shape
            temp = temp.contiguous().view(-1, c)
            temp = gen_projector(temp) * 1e-9

        if und_images is not None:
            und_image_embeds = vision_tower(und_images)
            num_img, _, c = und_image_embeds.shape
            und_image_embeds = und_image_embeds.contiguous().view(-1, c)
            und_image_embeds = mm_projector(und_image_embeds)
            if gen_images is None:
                # NOTE(review): only shape-compatible if the understanding tower also
                # yields 729 tokens per image (matching the dummy slice above) — confirm.
                und_image_embeds += temp
        else:
            # No understanding images: same dummy trick for the understanding branch.
            num_img = gen_images.shape[0]
            dummy = torch.zeros(num_img, 3, 384, 384, dtype=gen_images.dtype, device=gen_images.device)
            temp = vision_tower(dummy)
            if 'early' in self.get_gen_pooling():
                temp = temp[:, :64, :]
            num_img, _, c = temp.shape
            temp = temp.contiguous().view(-1, c)
            temp = mm_projector(temp) * 1e-9
            prompt_image_embeds += temp

        image_idx = (input_ids == IMAGE_TOKEN_IDX)
        output_indicator = labels != -100
        input_indicator = labels == -100

        # Supervised image tokens. Both masks are then rotated left by one along the
        # sequence dim, so index t flags that token t + 1 is an image token.
        img_loss_indicator = torch.logical_and(output_indicator, image_idx)
        img_loss_indicator = torch.cat(
            [img_loss_indicator[:, 1:], img_loss_indicator[:, :1]], dim=1)
        img_indicator = torch.cat(
            [image_idx[:, 1:], image_idx[:, :1]], dim=1)

        if target_image_embeds is not None:
            target_image_embeds = target_image_embeds[-img_loss_indicator.sum():, :]

        text_embeds = self.get_model().embed_tokens(input_ids)

        gen_img_idx = torch.logical_and(output_indicator, image_idx)
        if target_image_embeds is not None:
            # NOTE(review): inputs use the FIRST gen_img_idx.sum() projected rows, while
            # targets were sliced from the LAST rows above. The two agree only when the
            # number of rows equals the number of output image tokens — confirm intent.
            text_embeds[gen_img_idx] = prompt_image_embeds.to(text_embeds.device)[:gen_img_idx.sum(), :]
            target_image_embeds = target_image_embeds.to(text_embeds.device)[:gen_img_idx.sum(), :]

        und_img_idx = torch.logical_and(input_indicator, image_idx)
        if und_images is not None:
            text_embeds[und_img_idx] = und_image_embeds.to(text_embeds.device)[:und_img_idx.sum(), :]

        # In place on the caller's tensor: image positions are excluded from the text loss.
        labels[image_idx] = -100

        return None, position_ids, attention_mask, past_key_values, text_embeds, labels, img_loss_indicator, img_indicator, target_image_embeds

    def prepare_inputs_labels_for_understanding(
        self, input_ids, position_ids, attention_mask, past_key_values, labels,
        batch_images, image_sizes=None
    ):
        """Replace every image placeholder token with projected understanding-tower embeddings.

        Returns:
            ``(None, position_ids, attention_mask, past_key_values, inputs_embeds,
            img_indicator, labels)``. ``img_indicator`` is the image-token mask rotated
            left by one along the sequence dim.
        """
        vision_tower = self.get_vision_tower()
        mm_projector = self.get_mm_projector()

        prompt_image_embeds = vision_tower(batch_images)  # TODO: check dimension
        _, _, c = prompt_image_embeds.shape  # original note: [batch, 576, 1024]
        prompt_image_embeds = prompt_image_embeds.contiguous().view(-1, c)
        prompt_image_embeds = mm_projector(prompt_image_embeds)

        image_idx = (input_ids == IMAGE_TOKEN_IDX)
        img_indicator = torch.cat(
            [image_idx[:, 1:], image_idx[:, :1]], dim=1)

        text_embeds = self.get_model().embed_tokens(input_ids)
        text_embeds[image_idx] = prompt_image_embeds.to(text_embeds.device)[:image_idx.sum(), :]

        return None, position_ids, attention_mask, past_key_values, text_embeds, img_indicator, labels

    def initialize_vision_tokenizer(self, model_args, tokenizer):
        """Add the optional image patch / start / end tokens and initialise their embeddings.

        New start/end token embeddings are set to the mean of the pre-existing rows, or
        copied from ``model_args.pretrain_mm_mlp_adapter`` when given.

        Raises:
            ValueError: if the adapter's ``model.embed_tokens.weight`` has an unexpected shape.
        """
        if model_args.mm_use_im_patch_token:
            # DEFAULT_IMAGE_PATCH_TOKEN is absent from the module-level imports; import it
            # here so this branch no longer raises NameError (and a missing constant
            # cannot break importing the module).
            from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if model_args.pretrain_mm_mlp_adapter:
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
        elif model_args.mm_use_im_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False