from typing import List, Optional, Tuple, Union import warnings, os, torch import torch.nn as nn from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, AutoModelForCausalLM, AutoTokenizer from transformers.modeling_utils import ContextManagers, no_init_weights from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.generation.utils import GenerateOutput from .configuration_apollo import ApolloConfig from .vision_tower import ApolloVisionTower from .mm_connector import MMConnector IGNORE_INDEX = -100 X_TOKEN_INDEX = -200 def get_model_config(config): default_keys = ["llm_cfg", "vision_tower_cfg", "mm_connector_cfg"] if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2: root_path = config._name_or_path else: root_path = config.resume_path return_pths = [] for key in default_keys: cfg = getattr(config, key, None) if isinstance(cfg, dict): try: return_pths.append(os.path.join(root_path, key[:-4])) except: raise ValueError(f"Cannot find resume path in config for {key}!") elif isinstance(cfg, PretrainedConfig): return_pths.append(os.path.join(root_path, key[:-4])) elif isinstance(cfg, str): return_pths.append(cfg) return_list = [] for pth in return_pths: return_list.append(AutoConfig.from_pretrained(pth, trust_remote_code=True)) return return_list def build_llm_and_tokenizer( llm_cfg: str, config: PretrainedConfig, attn_implementation=None, model_max_length=None, *args, **kwargs, ) -> PreTrainedModel: llm_arch = getattr(llm_cfg, "architectures")[0].lower() llm_path = llm_cfg._name_or_path llm = AutoModelForCausalLM.from_pretrained( llm_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs ) tokenizer = AutoTokenizer.from_pretrained( llm_path, model_max_length=llm_cfg.model_max_length, padding_side="right", use_fast=False, legacy=False, **kwargs ) #config.hidden_size = llm.config.hidden_size return llm, tokenizer class ApolloForCausalLM(PreTrainedModel): def __init__(self, config: ApolloConfig, *args, **kwargs): super().__init__(config) llm_cfg, vision_tower_cfg, mm_connector_cfg = get_model_config(config) model_dtype = getattr(config, "model_dtype", "torch.float16") if not hasattr(config, "model_dtype"): warnings.warn("model_dtype not found in config, defaulting to torch.float16.") config.model_dtype = model_dtype # Initialize weights and apply final processing self.lm_head = nn.Linear(llm_cfg.hidden_size, config.vocab_size, bias=False) self.vision_tower = ApolloVisionTower(config, vision_tower_cfg) self.mm_connector = MMConnector.from_pretrained(mm_connector_cfg._name_or_path) self.llm, self.tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs) self.post_init() self.is_loaded = True def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, vision_input: Optional[List[torch.FloatTensor]] = None, data_types: Optional[List[str]] = None, return_dict: Optional[bool] = None, cache_position=None, ) -> Union[Tuple, CausalLMOutputWithPast]: if inputs_embeds is None: ( input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels ) = self.prepare_inputs_labels_for_multimodal( input_ids, position_ids, attention_mask, past_key_values, labels, vision_input, data_types ) return self.get_llm().forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) @torch.no_grad() def generate( self, inputs: Optional[torch.Tensor] = None, vision_input: Optional[List[torch.Tensor]] = None, data_types: Optional[List[str]] = None, **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: position_ids = kwargs.pop("position_ids", None) attention_mask = kwargs.pop("attention_mask", None) if "inputs_embeds" in kwargs: raise NotImplementedError("`inputs_embeds` is not supported") if vision_input is not None: (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal( inputs, position_ids, attention_mask, None, None, vision_input, data_types=data_types) else: inputs_embeds = self.embed_tokens(inputs) return self.get_llm().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): vision_input = kwargs.pop("vision_input", None) data_types = kwargs.pop("data_types", None) inputs = self.get_llm().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) if vision_input is not None: inputs["vision_input"] = vision_input if data_types is not None: inputs["data_types"] = data_types return inputs @classmethod def from_pretrained( cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, cache_dir: Optional[Union[str, os.PathLike]] = None, ignore_mismatched_sizes: bool = False, force_download: bool = False, local_files_only: bool = False, token: Optional[Union[str, bool]] = None, revision: str = "main", use_safetensors: bool = None, **kwargs, ): return cls.load_pretrained( pretrained_model_name_or_path, *model_args, config=config, cache_dir=cache_dir, ignore_mismatched_sizes=ignore_mismatched_sizes, force_download=force_download, local_files_only=local_files_only, token=token, revision=revision, use_safetensors=use_safetensors, **kwargs, ) def get_llm(self): return self.llm def get_vision_tower(self): return self.vision_tower def get_mm_connector(self): return self.mm_connector @classmethod def load_pretrained(cls, model_path_or_config, *args, **kwargs): kwargs.pop("config", None) if isinstance(model_path_or_config, str): config = AutoConfig.from_pretrained(model_path_or_config, trust_remote_code=True, **kwargs) elif isinstance(model_path_or_config, ApolloConfig): config = model_path_or_config else: raise NotImplementedError(f"wrong type, {type(model_path_or_config)} \ {isinstance(model_path_or_config, ApolloConfig)}") model_dtype = getattr(config, "model_dtype", "torch.float16") if not hasattr(config, "model_dtype"): warnings.warn("model_dtype not found in config, defaulting to torch.float16.") config.model_dtype = model_dtype with ContextManagers([no_init_weights(_enable=True), ]): vlm = cls(config, *args, **kwargs) if hasattr(vlm, "llm") and hasattr(vlm, "vision_tower") and hasattr(vlm, "mm_connector"): if vlm.is_loaded: return vlm else: print('loading model failed!') else: print('loading model failed!') def _encode_mm(self, x): x = self.get_vision_tower()(x) x = self.mm_connector(x) return x def encode_mm_minibatch(self, x): split_sizes = [x_s[0].shape[0] for x_s in x] x = [torch.split(torch.cat([x_s[i] for x_s in x], dim=0), self.config.encode_batch_size) for i in range(self.get_vision_tower().num_vision_encoders)] swapped_x = [] for i in range(len(x[0])): swapped_x.append([x_s[i] for x_s in x]) features = [] for xx in swapped_x: xx = self._encode_mm(xx) features.append(xx) x = torch.cat(features, dim=0) x = torch.split(x, split_sizes, dim=0) return [xx.contiguous().view(-1, xx.shape[2]) for xx in x] def prepare_inputs_labels_for_multimodal( self, input_ids, position_ids, attention_mask, past_key_values, labels, vision_input, data_types ): vision_tower = self.get_vision_tower() if vision_tower is None or vision_input is None or input_ids.shape[1] == 1: if ( past_key_values is not None and vision_tower is not None and vision_input is not None and input_ids.shape[1] == 1 ): target_shape = past_key_values[-1][-1].shape[-2] + 1 attention_mask = torch.cat( ( attention_mask, torch.ones( ( attention_mask.shape[0], target_shape - attention_mask.shape[1], ), dtype=attention_mask.dtype, device=attention_mask.device, ), ), dim=1, ) position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 return ( input_ids, position_ids, attention_mask, past_key_values, None, labels, ) ''' vision_input is a list of tuples, and data_type is a list of strings: data_type = ['image', 'video', 'video'..., 'text'] (for one video and two image encoders) vision_input = [ [image(1, T, C, H, W), image(1, T, C, H, W), image(1, T, C, H, W)], [video(Nc1, C, T, H, W), video(Nc1, T, C, H, W), video(Nc1, T, C, H, W)], [video(Nc2, C, T, H, W), video(Nc2, T, C, H, W), video(Nc2, T, C, H, W)], ] -> video encoders typlically expect (C,T,H,W), images expect (C,H,W). ''' # ==================================================================================================== merged_mm_features = self.encode_mm_minibatch(vision_input) if not getattr(self.config, "tune_language_model", True) and getattr(self.config, "use_mm_start_end", False): raise NotImplementedError # ==================================================================================================== # Let's just add dummy tensors if they do not exist, # it is a headache to deal with None all the time. # But it is not ideal, and if you have a better idea, # please open an issue / submit a PR, thanks. _labels = labels _position_ids = position_ids _attention_mask = attention_mask if attention_mask is None: attention_mask = torch.ones_like(input_ids, dtype=torch.bool) else: attention_mask = attention_mask.bool() if position_ids is None: position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) if labels is None: labels = torch.full_like(input_ids, IGNORE_INDEX) # remove the padding using attention_mask input_ids_copy = input_ids.clone() # kentang-mit@: Otherwise tokenizer out of bounds. Embeddings of image tokens will not be used. input_ids_copy[input_ids_copy == X_TOKEN_INDEX] = 0 input_embeds = self.get_llm().model.embed_tokens(input_ids_copy) input_ids = [ cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask) ] input_embeds_1 = [ cur_input_embeds[cur_attention_mask] for cur_input_embeds, cur_attention_mask in zip(input_embeds, attention_mask) ] labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] # input_ids, new_input_embeds = self.inputs_merger(input_ids, input_embeds_1, merged_mm_features) new_labels = [] new_input_embeds = [] # print("BEFORE BATCH LOOP:", len(input_ids), input_ids[0].shape, input_ids[0].device, [(x == X_TOKEN_INDEX).sum() for x in input_ids]) # kentang-mit@: If some part of the model is executed in the loop, the the loop length needs to be a constant. for batch_idx, (cur_labels, cur_input_ids, mm_features) in enumerate( zip(labels, input_ids, merged_mm_features)): cur_input_ids = input_ids[batch_idx] num_mm = (cur_input_ids == X_TOKEN_INDEX).sum() if num_mm == 0: cur_input_embeds_1 = input_embeds_1[batch_idx] cur_input_embeds = torch.cat([cur_input_embeds_1, mm_features[0:0]], dim=0) new_input_embeds.append(cur_input_embeds) new_labels.append(cur_labels) # kenang-mit@: we do not have placeholdr image for text-only data now. continue if mm_features.shape[0] != num_mm: print(data_types[batch_idx]) assert num_mm == len( mm_features), f'Error in {data_types[batch_idx]}{num_mm}=/={len(mm_features)} not the same number of vision tokens in and vision embeddings!' cur_input_embeds = input_embeds_1[batch_idx] image_token_indices = ( [-1] + torch.where(cur_input_ids == X_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] ) cur_input_ids_noim = [] cur_labels = labels[batch_idx] cur_labels_noim = [] cur_input_embeds_no_im = [] for i in range(len(image_token_indices) - 1): cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1: image_token_indices[i + 1]]) cur_labels_noim.append(cur_labels[image_token_indices[i] + 1: image_token_indices[i + 1]]) cur_input_embeds_no_im.append(cur_input_embeds[image_token_indices[i] + 1: image_token_indices[i + 1]]) cur_new_input_embeds = [] cur_new_labels = [] for i in range(num_mm + 1): cur_new_input_embeds.append(cur_input_embeds_no_im[i]) # print("cur_new_input_embeds1", cur_new_input_embeds.shape[-1]) cur_new_labels.append(cur_labels_noim[i]) if i < num_mm: cur_image_features = mm_features[i:i + 1] cur_new_input_embeds.append(cur_image_features) # print("cur_new_input_embeds2", cur_new_input_embeds.shape[-1]) cur_new_labels.append( torch.full( (cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype, ) ) cur_new_input_embeds = torch.cat(cur_new_input_embeds) cur_new_labels = torch.cat(cur_new_labels) new_input_embeds.append(cur_new_input_embeds) new_labels.append(cur_new_labels) # Truncate sequences to max length as image embeddings can make the sequence longer tokenizer_model_max_length = getattr(self.get_llm().config, "tokenizer_model_max_length", None) if tokenizer_model_max_length is not None: if any(len(x) > tokenizer_model_max_length for x in new_input_embeds): priny("Inputs truncated!") new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] new_labels = [x[:tokenizer_model_max_length] for x in new_labels] # Combine them max_len = max(x.shape[0] for x in new_input_embeds) batch_size = len(new_input_embeds) new_input_embeds_padded = [] new_labels_padded = torch.full( (batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device, ) attention_mask = torch.zeros( (batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device, ) position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): cur_len = cur_new_embed.shape[0] if getattr(self.get_llm().config, "tokenizer_padding_side", "right") == "left": new_input_embeds_padded.append( torch.cat( ( torch.zeros( (max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device, ), cur_new_embed, ), dim=0, ) ) if cur_len > 0: new_labels_padded[i, -cur_len:] = cur_new_labels attention_mask[i, -cur_len:] = True position_ids[i, -cur_len:] = torch.arange( 0, cur_len, dtype=position_ids.dtype, device=position_ids.device ) else: new_input_embeds_padded.append( torch.cat( ( cur_new_embed, torch.zeros( (max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device, ), ), dim=0, ) ) if cur_len > 0: new_labels_padded[i, :cur_len] = cur_new_labels attention_mask[i, :cur_len] = True position_ids[i, :cur_len] = torch.arange( 0, cur_len, dtype=position_ids.dtype, device=position_ids.device ) new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) if _labels is None: new_labels = None else: new_labels = new_labels_padded if _attention_mask is None: attention_mask = None else: attention_mask = attention_mask.to(dtype=_attention_mask.dtype) if _position_ids is None: position_ids = None return ( None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels, )