# krilecy's picture
# Upload handler.py
# 6891925 verified
# raw
# history blame
# 8.33 kB
# handler file for Huggingface Inference API
from typing import Dict, Any
from peft import PeftModel
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig
import transformers
import torch
from torch import Tensor
import torch.nn.functional as F
from torch import nn
# Override standard attention in Mistral to get extractable attention
from transformers.models.mistral.modeling_mistral import MistralAttention
from typing import Optional, Tuple
import warnings
import math
from transformers.cache_utils import Cache
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
class MistralAttention(MistralAttention):
    """Eager Mistral attention that additionally records ``self.attn_vec``.

    Subclasses -- and intentionally shadows the module-level name of -- the
    ``MistralAttention`` imported from ``transformers.models.mistral``. Its
    ``forward`` reproduces the eager attention computation and, as a side
    effect, stores a CPU float32 tensor of shape ``(bsz, q_len, hidden_size)``
    in ``self.attn_vec``.

    Defining this class does not by itself change any model: it only runs for
    attention modules that are instances of it (e.g. a stock attention module
    whose ``__class__`` has been switched to this subclass).

    NOTE(review): ``attn_vec`` is the softmaxed attention weights multiplied by
    a matrix whose entries are all ``1 / head_dim``. Each softmax row sums to 1
    and dropout is the identity outside training, so every entry comes out as
    roughly ``1 / head_dim`` whatever the input -- after L2 normalisation the
    vector is constant. Confirm that this is the intended signal.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run eager attention and record ``self.attn_vec``.

        Args:
            hidden_states: attention input, unpacked as ``(bsz, q_len, _)``.
            attention_mask: optional additive mask; when given it must be
                ``(bsz, 1, q_len, kv_seq_len)``.
            position_ids: positions forwarded to ``apply_rotary_pos_emb``.
            past_key_value: optional KV cache; updated for ``self.layer_idx``.
            output_attentions: when False the returned weights are ``None``.
            use_cache: accepted for signature compatibility; not read here.
            **kwargs: only ``padding_mask`` is inspected (deprecation warning).

        Returns:
            ``(attn_output, attn_weights_or_None, past_key_value)``.

        Raises:
            ValueError: if a cache is passed without ``layer_idx``, or the
                attention weights / mask / output have unexpected shapes.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        # Custom: project the attention weights onto a constant averaging matrix.
        # Both operands are explicitly float32 on CPU: torch.matmul does not
        # promote dtypes, so half-precision attn_weights against the default
        # float32 matrix would raise. detach() keeps attn_vec graph-free.
        d = value_states.size(-1)
        averaging_matrix = torch.full(value_states.shape, 1 / d, dtype=torch.float32)
        custom_attn_output = torch.matmul(attn_weights.detach().to("cpu", torch.float32), averaging_matrix)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        # Custom: same head-merge as attn_output, then stash for extraction.
        custom_attn_output = custom_attn_output.transpose(1, 2).contiguous()
        custom_attn_output = custom_attn_output.reshape(bsz, q_len, self.hidden_size)
        self.attn_vec = custom_attn_output

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
class EndpointHandler:
    """Custom Hugging Face Inference Endpoints handler producing text embeddings.

    Loads an 8-bit quantised base model with a PEFT/LoRA adapter. For each
    request it returns an L2-normalised last-token embedding and, for queries,
    an L2-normalised attention vector recorded by this module's custom
    ``MistralAttention`` on the last decoder layer.
    """

    def __init__(self, model_dir='', lora_dir='/lora'):
        """Load tokenizer, quantised model and adapter; install custom attention.

        Args:
            model_dir: base model / tokenizer location passed to ``from_pretrained``.
            lora_dir: PEFT adapter location (defaults to the original ``'/lora'``).
        """
        self.instruction = 'Given a web search query, retrieve relevant passages that answer the query:\n'
        self.max_length = 4096
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
        # Only single, unpadded texts are tokenized below, so this pad token is
        # never actually emitted.
        self.tokenizer.pad_token = '[PAD]'
        self.tokenizer.padding_side = 'left'
        # NOTE(review): BitsAndBytesConfig defines ``bnb_4bit_compute_dtype`` but
        # no ``bnb_8bit_compute_dtype``; this kwarg is most likely ignored.
        bnb_config = BitsAndBytesConfig(load_in_8bit=True, bnb_8bit_compute_dtype=torch.float16)
        self.model = AutoModel.from_pretrained(
            model_dir,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
            attn_implementation="eager",
        )
        self.model = PeftModel.from_pretrained(self.model, lora_dir)
        self.model.eval()
        self._install_custom_attention()

    @staticmethod
    def _unwrap(model):
        """Return the inner transformer of a PEFT wrapper, or ``model`` itself."""
        get_base = getattr(model, "get_base_model", None)
        return get_base() if callable(get_base) else model

    def _install_custom_attention(self):
        """Switch the last decoder layer's attention to the custom subclass.

        Only the last layer is switched because ``extract_attn_vec`` reads
        ``layers[-1]`` only. The switch happens only when that module is exactly
        the stock class the custom ``MistralAttention`` derives from, so an
        already-custom or unrelated attention module is left untouched.
        """
        layers = getattr(self._unwrap(self.model), "layers", None)
        if layers is None or len(layers) == 0:
            return
        last_attn = layers[-1].self_attn
        if type(last_attn) is MistralAttention.__base__:
            # The subclass only overrides forward() and adds no __init__ state,
            # so re-classing the existing instance in place is safe.
            last_attn.__class__ = MistralAttention

    @staticmethod
    def last_token_pool(last_hidden_states: Tensor,
                        attention_mask: Tensor) -> Tensor:
        """Select each sequence's last non-padding position.

        Args:
            last_hidden_states: tensor indexed as ``[batch, seq, ...]``.
            attention_mask: ``(batch, seq)`` 0/1 mask; may live on another device.

        Returns:
            Tensor indexed as ``[batch, ...]`` with the selected positions.
        """
        # Every row ends in a real token -> left padding (or none): take the last column.
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        # Right padding: last real token sits at (count of ones - 1). Move the
        # index to the data's device (attn_vec is on CPU, the mask may be CUDA).
        sequence_lengths = (attention_mask.sum(dim=1) - 1).to(last_hidden_states.device)
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

    def tokenize(self, text, type):
        """Tokenize one text with a trailing EOS (queries get the instruction prefix).

        Returns the tokenizer's ``BatchEncoding`` of tensors moved to ``self.device``.
        """
        if type == 'query':
            text = self.instruction + text
        tokens = self.tokenizer(text + self.tokenizer.eos_token, max_length=self.max_length, truncation=True, return_tensors='pt').to(self.device)
        # Truncating an over-long text cuts off the appended EOS, but last-token
        # pooling expects it: restore it in the final slot of a full-length input.
        eos_id = self.tokenizer.eos_token_id
        input_ids = tokens['input_ids']
        if eos_id is not None and input_ids.shape[-1] >= self.max_length:
            input_ids[:, -1] = eos_id
        return tokens

    @staticmethod
    def extract_attn_vec(model):
        """Return ``attn_vec`` stored by the last decoder layer's attention module.

        ``model`` may be the bare transformer (exposing ``.layers``) or a PEFT
        wrapper; wrappers register the transformer under another child name, so
        it is unwrapped via ``get_base_model`` before ``layers`` is read.
        """
        base = EndpointHandler._unwrap(model)
        return base.layers[-1].self_attn.attn_vec

    def embed(self, text, type):
        """Return ``(embedding, attn_vec)``, each L2-normalised along dim 1."""
        tokens = self.tokenize(text, type)
        with torch.no_grad():
            output = self.model(tokens['input_ids'], tokens['attention_mask']).last_hidden_state.detach()
        embedding = self.last_token_pool(output, tokens['attention_mask'])
        embedding = F.normalize(embedding, p=2, dim=1)
        # Populated as a side effect of the forward pass above.
        attn_vec = self.extract_attn_vec(self.model)
        attn_vec = self.last_token_pool(attn_vec, tokens['attention_mask'])
        attn_vec = F.normalize(attn_vec, p=2, dim=1)
        del output, tokens
        torch.cuda.empty_cache()
        return embedding, attn_vec

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Embed the text in one request payload.

        Args:
            data: dict with ``"text"`` (str), ``"type"`` (``'query'`` or
                ``'document'``) and optional ``"id"`` (echoed back, ``None``
                when absent). These keys are popped from ``data``.

        Returns:
            ``{"id", "embedding"}``; queries also get ``"attention_vec"``.
            Vectors are plain Python lists.

        Raises:
            ValueError: if ``text`` is not a string or ``type`` is unsupported.
        """
        request_id = data.pop("id", None)
        text = data.pop("text", None)
        text_type = data.pop("type", None)
        if not isinstance(text, str):
            raise ValueError(f"`text` must be a string, got {type(text).__name__}")
        if text_type not in ('query', 'document'):
            raise ValueError(f"`type` must be 'query' or 'document', got {text_type!r}")

        embeddings, attn_vec = self.embed(text, text_type)
        result = {"id": request_id, "embedding": embeddings[0].tolist()}
        if text_type == 'query':
            result["attention_vec"] = attn_vec[0].tolist()
        return result