import torch
from torch import nn
import torch.nn.functional as F
from functools import wraps

from transformers import LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.cache_utils import StaticCache
from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask_with_cache_position
from torch_geometric.nn.pool import global_mean_pool

from .g_retriever_config import GRetrieverConfig
from .gnn import GAT


class GRetrieverModel(LlamaForCausalLM):
    config_class = GRetrieverConfig

    def __init__(self, config):
        super().__init__(config)
        # GNN encoder for the retrieved subgraph.
        self.graph_encoder = GAT(
            in_channels=config.gnn_in_dim,
            out_channels=config.gnn_hidden_dim,
            hidden_channels=config.gnn_hidden_dim,
            num_layers=config.gnn_num_layers,
            dropout=config.gnn_dropout,
            num_heads=config.gnn_num_heads,
        ).to(self.model.dtype)
        # Projects the pooled graph embedding into the LLM token-embedding space.
        self.projector = nn.Sequential(
            nn.Linear(config.gnn_hidden_dim, 2048),
            nn.Sigmoid(),
            nn.Linear(2048, self.get_input_embeddings().embedding_dim),
        ).to(self.model.dtype)

    def encode_graphs(self, graph):
        n_embeds, _ = self.graph_encoder(
            graph.x.to(self.model.dtype),
            graph.edge_index.long(),
            graph.edge_attr.to(self.model.dtype),
        )
        # mean pooling over the nodes of each graph in the batch
        g_embeds = global_mean_pool(n_embeds, graph.batch.to(n_embeds.device))
        return g_embeds

    @wraps(LlamaForCausalLM.forward)
    def forward(
        self,
        input_ids=None,
        graph=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
    ):
        inputs = input_ids.clone()
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Positions holding the reserved id -1 are placeholders for the BOS and graph embeddings.
        if (inputs == -1).any():
            # embed bos prompt
            bos_embeds = self.get_input_embeddings()(
                torch.tensor(self.config.bos_id, device=self.model.device)
            )
            # encode graph
            graph_embeds = self.encode_graphs(graph)
            graph_embeds = self.projector(graph_embeds).to(self.model.device)
            # prepare for reserved ids (bos+graph)
            non_tokenized_ids = (inputs == -1).nonzero()
            non_tokenized_shape = non_tokenized_ids[:, 0], non_tokenized_ids[:, 1]
            # embed inputs
            inputs[non_tokenized_shape] = self.config.pad_token_id
            temp_inputs_embeds = self.get_input_embeddings()(inputs)
            non_tokenized_embeds = torch.cat(
                [bos_embeds.repeat(len(inputs), 1, 1), graph_embeds.unsqueeze(1)], dim=1
            )
            # replace reserved ids with bos+graph
            inputs_embeds = temp_inputs_embeds.clone()
            inputs_embeds[non_tokenized_shape] = non_tokenized_embeds.view(len(non_tokenized_ids), -1)
        else:
            inputs_embeds = self.get_input_embeddings()(inputs)

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        graph=None,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if past_key_values is not None:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
                # `mode="reduce-overhead"`, as otherwise the input `position_ids` would have varying strides
                # during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
                # batch size = 1 case, `position_ids` is already contiguous but with varying stride which
                # retriggers a capture.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            # The clone here is for the same reason as for `position_ids`.
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
            if model_inputs["inputs_embeds"] is not None:
                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
                device = model_inputs["inputs_embeds"].device
            else:
                batch_size, sequence_length = model_inputs["input_ids"].shape
                device = model_inputs["input_ids"].device

            dtype = self.lm_head.weight.dtype
            min_dtype = torch.finfo(dtype).min

            attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                attention_mask,
                sequence_length=sequence_length,
                target_length=past_key_values.get_max_length(),
                dtype=dtype,
                device=device,
                min_dtype=min_dtype,
                cache_position=cache_position,
                batch_size=batch_size,
            )

        model_inputs.update(
            {
                "graph": graph,
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs
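

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the model). It shows the input
# layout that `forward` appears to expect: each row carries two reserved `-1`
# ids that are later swapped for the BOS embedding and the pooled graph
# embedding. The checkpoint path, feature dimensions, and prompt ids below are
# assumptions for demonstration; the graph is a torch_geometric `Batch`
# exposing `x`, `edge_index`, `edge_attr`, and `batch`, as consumed by
# `encode_graphs`.
if __name__ == "__main__":
    from torch_geometric.data import Batch, Data

    prompt_ids = torch.tensor([[10, 11, 12, 13]])         # stand-in for tokenized question ids
    reserved = torch.full((1, 2), -1, dtype=torch.long)   # placeholders for BOS + graph embeddings
    input_ids = torch.cat([reserved, prompt_ids], dim=1)

    graph = Batch.from_data_list([
        Data(
            x=torch.randn(5, 1024),                  # node features; gnn_in_dim assumed to be 1024
            edge_index=torch.randint(0, 5, (2, 8)),  # 8 random edges over 5 nodes
            edge_attr=torch.randn(8, 1024),          # edge features; dimension is an assumption
        )
    ])

    # Hypothetical checkpoint path -- replace with a real G-Retriever checkpoint.
    # model = GRetrieverModel.from_pretrained("path/to/g-retriever-checkpoint")
    # output_ids = model.generate(input_ids=input_ids, graph=graph, max_new_tokens=32)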