# Handler file for the Hugging Face Inference API.
from typing import Dict, Any
from peft import PeftModel
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig
import transformers
import torch
from torch import Tensor
import torch.nn.functional as F
from torch import nn

# Override standard attention in Mistral to get extractable attention
from transformers.models.mistral.modeling_mistral import MistralAttention
from typing import Optional, Tuple
import warnings
import math
from transformers.cache_utils import Cache
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv


class MistralAttention(MistralAttention):
    """Eager Mistral attention that also records an attention-derived vector.

    The forward pass computes regular scaled dot-product attention and, in
    addition, multiplies the attention weights with a constant ``1 / head_dim``
    matrix. That second product is reshaped to ``(bsz, q_len, hidden_size)``
    and stored on the module as ``self.attn_vec`` so it can be read back after
    a forward pass.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run attention and stash the custom attention vector on ``self``.

        Args:
            hidden_states: Input of shape ``(bsz, q_len, hidden)``; the three
                sizes are unpacked directly from ``hidden_states.size()``.
            attention_mask: Optional additive mask; must have shape
                ``(bsz, 1, q_len, kv_seq_len)``.
            position_ids: Positions forwarded to ``apply_rotary_pos_emb``.
            past_key_value: Optional key/value cache, updated in place.
            output_attentions: When false, the returned weights are ``None``.
            use_cache: Accepted for signature compatibility; not read here.
            **kwargs: Only inspected for the deprecated ``padding_mask`` key.

        Returns:
            ``(attn_output, attn_weights_or_None, past_key_value)``.

        Raises:
            ValueError: If a cache is passed without ``layer_idx`` being set,
                or if the weights, mask or output have an unexpected shape.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Split the projections into heads: (bsz, heads, q_len, head_dim).
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        # Custom: multiply the attention weights with a constant 1/d matrix.
        # The matrix is created on the CPU in the default dtype, so the
        # weights are moved AND cast to match it; torch.matmul needs both
        # operands to share a dtype, which a half-precision model would not.
        # NOTE(review): with dropout inactive every softmax row sums to 1, so
        # each entry of this product is approximately 1/d regardless of the
        # input - confirm this is the intended "attention vector".
        d = value_states.size(-1)
        averaging_matrix = torch.full(value_states.shape, 1 / d)
        custom_attn_output = torch.matmul(
            attn_weights.to(device='cpu', dtype=averaging_matrix.dtype),
            averaging_matrix,
        )

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Merge the heads back: (bsz, q_len, hidden_size).
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        custom_attn_output = custom_attn_output.transpose(1, 2).contiguous()  # Custom
        custom_attn_output = custom_attn_output.reshape(bsz, q_len, self.hidden_size)  # Custom
        self.attn_vec = custom_attn_output  # Custom: read back by the endpoint handler

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


def _install_extractable_attention() -> None:
    """Make transformers build Mistral layers with the subclass defined above.

    Defining the subclass in this file does not change what the library
    instantiates, so the override is written into the library module.
    NOTE(review): assumes the decoder layer resolves its eager attention class
    through the module-level ``MistralAttention`` name or through a
    ``MISTRAL_ATTENTION_CLASSES`` registry - confirm against the pinned
    transformers version.
    """
    mistral_module = transformers.models.mistral.modeling_mistral
    mistral_module.MistralAttention = MistralAttention
    registry = getattr(mistral_module, "MISTRAL_ATTENTION_CLASSES", None)
    if registry is not None:
        registry["eager"] = MistralAttention


class EndpointHandler():
    """Inference handler returning an embedding and an attention vector."""

    def __init__(self, model_dir=''):
        """Load tokenizer, 8-bit base model and the LoRA adapter.

        Args:
            model_dir: Directory (or identifier) passed to ``from_pretrained``
                for both the tokenizer and the base model. The LoRA adapter is
                always loaded from the fixed path ``/lora``.
        """
        self.instruction = 'Given a web search query, retrieve relevant passages that answer the query:\n'
        self.max_length = 4096
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
        self.tokenizer.pad_token = '[PAD]'
        self.tokenizer.padding_side = 'left'

        # Must happen before the model is constructed, otherwise the layers are
        # built with the stock attention class and never set `attn_vec`.
        _install_extractable_attention()

        bnb_config = BitsAndBytesConfig(load_in_8bit=True, bnb_8bit_compute_dtype=torch.float16)
        self.model = AutoModel.from_pretrained(
            model_dir,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
            attn_implementation="eager",
        )
        self.model = PeftModel.from_pretrained(self.model, '/lora')
        self.model.eval()

    @staticmethod
    def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        """Select the state of the last non-padding token of every sequence.

        Args:
            last_hidden_states: Tensor indexed as ``[batch, position, ...]``.
            attention_mask: Tensor indexed as ``[batch, position]`` whose
                entries are summed to count real tokens.

        Returns:
            One row per batch element.
        """
        # If every sequence has a real token in the final position, the batch
        # is left-padded and the last position is the one to pool.
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        target_device = last_hidden_states.device
        # Index tensors must live on the same device as the tensor they index;
        # the mask may sit on the GPU while the pooled tensor is on the CPU.
        return last_hidden_states[
            torch.arange(batch_size, device=target_device),
            sequence_lengths.to(target_device),
        ]

    def tokenize(self, text, type):
        """Tokenize ``text`` (plus EOS) and move the tensors to ``self.device``.

        Args:
            text: Input string.
            type: ``'query'`` prepends ``self.instruction``; any other value
                leaves the text unchanged.
        """
        if type == 'query':
            text = self.instruction + text
        return self.tokenizer(
            text + self.tokenizer.eos_token,
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt',
        ).to(self.device)

    @staticmethod
    def extract_attn_vec(model):
        """Return ``attn_vec`` stored by the last layer's attention module.

        Args:
            model: Module exposing a ``layers`` sequence, either directly as a
                registered child or through attribute access.
        """
        layers = model._modules.get('layers')
        if layers is None:
            # A PEFT-wrapped model does not register `layers` as a direct
            # child. NOTE(review): relies on the wrapper forwarding attribute
            # access to the wrapped model - confirm for the PEFT version used.
            layers = model.layers
        return layers[-1].self_attn.attn_vec

    def embed(self, text, type):
        """Embed ``text`` and return ``(embedding, attn_vec)``, both L2-normalized.

        Args:
            text: Input string.
            type: Forwarded to :meth:`tokenize`.
        """
        tokens = self.tokenize(text, type)
        with torch.no_grad():
            output = self.model(tokens['input_ids'], tokens['attention_mask']).last_hidden_state.detach()
            embedding = self.last_token_pool(output, tokens['attention_mask'])
            embedding = F.normalize(embedding, p=2, dim=1)

            attn_vec = self.extract_attn_vec(self.model)
            attn_vec = self.last_token_pool(attn_vec, tokens['attention_mask'])
            attn_vec = F.normalize(attn_vec, p=2, dim=1)
            # Drop the large intermediates before asking CUDA to release memory.
            del output, tokens
            torch.cuda.empty_cache()
            return embedding, attn_vec

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request.

        Args:
            data: Request payload. The keys ``"id"``, ``"text"`` and ``"type"``
                are popped from it; when a key is absent the ``data`` dict
                itself is used as the fallback value.

        Returns:
            For type ``'query'``: ``{"id", "embedding", "attention_vec"}``.
            For type ``'document'``: ``{"id", "embedding"}``.
            For any other type: ``None``.
        """
        request_id = data.pop("id", data)
        text = data.pop("text", data)
        kind = data.pop("type", data)

        embeddings, attn_vec = self.embed(text, kind)
        embeddings = embeddings[0].tolist()
        attn_vec = attn_vec[0].tolist()

        if kind == 'query':
            return {"id": request_id, "embedding": embeddings, "attention_vec": attn_vec}
        elif kind == 'document':
            return {"id": request_id, "embedding": embeddings}
        return None