krilecy's picture
Upload handler.py
4f3226b verified
raw
history blame
No virus
3.71 kB
# Handler file for the Hugging Face Inference API
from typing import Dict, Any
from peft import PeftModel
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig
import transformers
import torch
from torch import Tensor
import torch.nn.functional as F
from torch import nn
from transformers.models.mistral.modeling_mistral import MistralAttention
from ExtractableMistralAttention import forward
# Monkey-patch: replace MistralAttention.forward on the class itself with the
# project-local `forward` from ExtractableMistralAttention. Because this assigns a
# class attribute, it affects every MistralAttention instance in the process.
# EndpointHandler later reads `self_attn.attn_vec` from the last decoder layer;
# NOTE(review): that attribute is presumably set by this patched forward — confirm
# against ExtractableMistralAttention.
MistralAttention.forward = forward
class EndpointHandler():
    """Inference handler that embeds a single query or document.

    Each request yields an L2-normalised last-token embedding. Queries also get an
    L2-normalised vector pooled from the last decoder layer's ``self_attn.attn_vec``.
    That vector is presumably populated by the patched ``MistralAttention.forward``
    imported from ``ExtractableMistralAttention``. NOTE(review): confirm what it holds.
    """

    def __init__(self, model_dir=''):
        """Load the tokenizer and an 8-bit quantised model from ``model_dir``."""
        # Prefix prepended to queries only (see ``tokenize``); documents are embedded as-is.
        self.instruction = 'Given a web search query, retrieve relevant passages that answer the query:\n'
        # Maximum number of tokens per input, including the trailing EOS token.
        self.max_length = 4096
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
        # NOTE(review): '[PAD]' may not exist in this vocabulary. Inputs are tokenized one
        # at a time without padding, so this only matters if batching is added later.
        self.tokenizer.pad_token = '[PAD]'
        self.tokenizer.padding_side = 'left'
        # NOTE(review): `bnb_8bit_compute_dtype` does not appear to be a BitsAndBytesConfig
        # parameter (only `bnb_4bit_compute_dtype` is), so it is likely ignored — confirm.
        bnb_config = BitsAndBytesConfig(load_in_8bit=True, bnb_8bit_compute_dtype=torch.float16)
        self.model = AutoModel.from_pretrained(
            model_dir,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
            # Presumably required so the plain (patched) MistralAttention class is used
            # rather than an SDPA/flash-attention variant — confirm.
            attn_implementation="eager",
        )
        # Optional LoRA adapter, currently disabled:
        # self.model = PeftModel.from_pretrained(self.model, model_dir, subfolder='lora')
        self.model.eval()

    def last_token_pool(self, last_hidden_states: Tensor,
                        attention_mask: Tensor) -> Tensor:
        """Return the hidden state of the last non-padding token in each row.

        Args:
            last_hidden_states: assumed (batch, seq_len, hidden); dim 0 is the batch
                and dim 1 is the sequence.
            attention_mask: (batch, seq_len) with 1 for real tokens and 0 for padding.

        Returns:
            One vector per row, i.e. ``last_hidden_states`` with the sequence dim removed.
        """
        # If every row's final position is a real token (left padding or no padding),
        # the last column is the answer for all rows.
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        # Right padding: index of the last real token in each row. The index is moved to
        # the hidden states' device because device_map="auto" may split the model across devices.
        sequence_lengths = (attention_mask.sum(dim=1) - 1).to(last_hidden_states.device)
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

    def tokenize(self, text, request_type):
        """Tokenize one input and move it to ``self.device``.

        Queries are prefixed with ``self.instruction``. The sequence always ends with
        the EOS token, even when truncated, because pooling reads the final position.
        """
        if request_type == 'query':
            text = self.instruction + text
        tokens = self.tokenizer(
            text + self.tokenizer.eos_token,
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt',
        )
        # Truncation of long inputs cuts off the appended EOS. Restore it in the last
        # position so the pooled vector still comes from the EOS token.
        eos_id = self.tokenizer.eos_token_id
        input_ids = tokens['input_ids']
        if eos_id is not None and input_ids.shape[-1] > 0 and input_ids[0, -1].item() != eos_id:
            input_ids[0, -1] = eos_id
        return tokens.to(self.device)

    def extract_attn_vec(self, model):
        """Return ``attn_vec`` from the last decoder layer's attention module of ``model``.

        NOTE(review): the attribute is presumably written by the patched
        ``MistralAttention.forward`` during the most recent forward pass — confirm.
        """
        # `.layers` is used rather than `._modules['layers']` so that wrappers which
        # forward attribute access (e.g. a PEFT model) also work.
        return model.layers[-1].self_attn.attn_vec

    def embed(self, text, request_type):
        """Run the model on one input and return ``(embedding, attn_vec)``.

        Both tensors are L2-normalised along dim 1.
        """
        tokens = self.tokenize(text, request_type)
        attention_mask = tokens['attention_mask']
        with torch.no_grad():
            output = self.model(tokens['input_ids'], attention_mask).last_hidden_state.detach()
            embedding = F.normalize(self.last_token_pool(output, attention_mask), p=2, dim=1)
            # Read after the forward pass above, which presumably refreshes attn_vec.
            attn_vec = self.extract_attn_vec(self.model)
            attn_vec = F.normalize(self.last_token_pool(attn_vec, attention_mask), p=2, dim=1)
        del output, tokens
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return embedding, attn_vec

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Embed a single query or document.

        Expected payload (if ``"inputs"`` is absent, ``data`` itself is used)::

            {"inputs": {"id": ..., "text": "<string>", "type": "query" | "document"}}

        Returns:
            ``{"id", "embedding"}`` for documents. Queries also get ``"attention_vec"``.
            Vectors are plain lists of floats.

        Raises:
            ValueError: if the payload is not an object, ``text`` is missing, or
                ``type`` is neither ``"query"`` nor ``"document"``.
        """
        inputs = data.pop("inputs", data)
        if not isinstance(inputs, dict):
            raise ValueError("'inputs' must be an object with 'id', 'text' and 'type' fields")
        request_id = inputs.pop("id", None)
        if "text" not in inputs:
            raise ValueError("missing required field 'text'")
        text = inputs.pop("text")
        request_type = inputs.pop("type", None)
        # Validate before the (expensive) model call. Otherwise an unknown type would be
        # embedded and then match no response branch.
        if request_type not in ('query', 'document'):
            raise ValueError(f"'type' must be 'query' or 'document', got {request_type!r}")
        embeddings, attn_vec = self.embed(text, request_type)
        result = {"id": request_id, "embedding": embeddings[0].tolist()}
        if request_type == 'query':
            result["attention_vec"] = attn_vec[0].tolist()
        return result