# handler.py -- custom handler for a Hugging Face Inference Endpoint.
# Wraps a quantized Mistral-based embedding model and additionally exposes the
# last layer's attention vector, extracted via a patched attention forward.
from typing import Dict, Any

import torch
import torch.nn.functional as F
from torch import Tensor
from peft import PeftModel
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig
from transformers.models.mistral.modeling_mistral import MistralAttention

from ExtractableMistralAttention import forward

# Monkey-patch Mistral's attention so each layer stores its attention output
# vector (attn_vec) for later extraction.
MistralAttention.forward = forward


class EndpointHandler:
    def __init__(self, model_dir=''):
        self.instruction = 'Given a web search query, retrieve relevant passages that answer the query:\n'
        self.max_length = 4096
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
        self.tokenizer.pad_token = '[PAD]'
        self.tokenizer.padding_side = 'left'

        # Load the base model in 8-bit to fit on a single GPU. (The original
        # bnb_8bit_compute_dtype kwarg is not a BitsAndBytesConfig parameter
        # and was silently ignored, so it is dropped here.)
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
        self.model = AutoModel.from_pretrained(
            model_dir,
            quantization_config=bnb_config,
            device_map='auto',
            trust_remote_code=True,
            attn_implementation='eager',  # required so the patched forward is used
        )
        # self.model = PeftModel.from_pretrained(self.model, model_dir, subfolder='lora')
        self.model.eval()

    def last_token_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        """Pool a batch of hidden states down to the hidden state of each
        sequence's last real (non-padding) token."""
        # With left padding, the final position is always a real token.
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        # Otherwise index each row at its own final non-padding position.
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

    def tokenize(self, text, request_type):
        # Queries are prefixed with the retrieval instruction; documents are not.
        if request_type == 'query':
            text = self.instruction + text
        return self.tokenizer(
            text + self.tokenizer.eos_token,
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt',
        ).to(self.device)

    def extract_attn_vec(self, model):
        # attn_vec is stored by the patched MistralAttention.forward;
        # take it from the last decoder layer.
        return model.layers[-1].self_attn.attn_vec

    def embed(self, text, request_type):
        tokens = self.tokenize(text, request_type)
        with torch.no_grad():
            output = self.model(tokens['input_ids'], tokens['attention_mask']).last_hidden_state.detach()
        # L2-normalized last-token embedding.
        embedding = self.last_token_pool(output, tokens['attention_mask'])
        embedding = F.normalize(embedding, p=2, dim=1)
        # L2-normalized last-token attention vector from the patched layer.
        attn_vec = self.extract_attn_vec(self.model)
        attn_vec = self.last_token_pool(attn_vec, tokens['attention_mask'])
        attn_vec = F.normalize(attn_vec, p=2, dim=1)
        del output, tokens
        torch.cuda.empty_cache()
        return embedding, attn_vec

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            inputs (:obj:`dict`): {"id": str, "text": str, "type": "query" | "document"}
        Return:
            A :obj:`dict` with the id, the normalized embedding and, for
            queries, the pooled attention vector; serialized and returned.
        """
        inputs = data.pop('inputs', data)
        request_id = inputs.pop('id', None)
        text = inputs.pop('text', None)
        request_type = inputs.pop('type', None)
        embeddings, attn_vec = self.embed(text, request_type)
        embeddings = embeddings[0].tolist()
        attn_vec = attn_vec[0].tolist()
        if request_type == 'query':
            return {'id': request_id, 'embedding': embeddings, 'attention_vec': attn_vec}
        elif request_type == 'document':
            return {'id': request_id, 'embedding': embeddings}
        raise ValueError(f"Unknown request type: {request_type!r}; expected 'query' or 'document'.")
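

# ---------------------------------------------------------------------------
# Usage sketch (local smoke test; not invoked by the Inference Endpoint
# runtime, which instantiates EndpointHandler itself). This is a minimal
# illustration under assumptions: './model' is an assumed local weights path,
# and the payload shape mirrors what __call__ expects above.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    handler = EndpointHandler(model_dir='./model')  # assumed local weights path

    # A query request returns both the embedding and the attention vector.
    query_out = handler({'inputs': {'id': 'q1', 'text': 'what is 8-bit quantization?', 'type': 'query'}})
    print(query_out['id'], len(query_out['embedding']), len(query_out['attention_vec']))

    # A document request returns only the embedding.
    doc_out = handler({'inputs': {'id': 'd1', 'text': 'Quantization stores weights as int8.', 'type': 'document'}})
    print(doc_out['id'], len(doc_out['embedding']))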