import math
import warnings
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from peft import PeftModel
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
from transformers.cache_utils import Cache
from transformers.models.mistral.modeling_mistral import (
    MistralAttention,
    apply_rotary_pos_emb,
    repeat_kv,
)
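
# The subclass below is a drop-in replacement for the stock eager MistralAttention
# forward. Besides the usual attention output it also computes a "custom attention
# vector": the attention probabilities pushed through a constant (1 / head_dim)
# matrix in place of the value states. The result is stored on `self.attn_vec` so
# the handler can read it back from the last decoder layer after a forward pass.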
|
|
|
|
|
class MistralAttention(MistralAttention): |
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_value: Optional[Cache] = None, |
|
output_attentions: bool = False, |
|
use_cache: bool = False, |
|
**kwargs, |
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
|
if "padding_mask" in kwargs: |
|
warnings.warn( |
|
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" |
|
) |
|
bsz, q_len, _ = hidden_states.size() |
|
|
|
query_states = self.q_proj(hidden_states) |
|
key_states = self.k_proj(hidden_states) |
|
value_states = self.v_proj(hidden_states) |
|
|
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
|
|
kv_seq_len = key_states.shape[-2] |
|
if past_key_value is not None: |
|
if self.layer_idx is None: |
|
raise ValueError( |
|
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " |
|
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " |
|
"with a layer index." |
|
) |
|
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) |
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) |
|
|
|
if past_key_value is not None: |
|
cache_kwargs = {"sin": sin, "cos": cos} |
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) |
|
|
|
|
|
key_states = repeat_kv(key_states, self.num_key_value_groups) |
|
value_states = repeat_kv(value_states, self.num_key_value_groups) |
|
|
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) |
|
|
|
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): |
|
raise ValueError( |
|
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" |
|
f" {attn_weights.size()}" |
|
) |
|
|
|
if attention_mask is not None: |
|
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): |
|
raise ValueError( |
|
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" |
|
) |
|
            attn_weights = attn_weights + attention_mask
|
|
|
|
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) |
|
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) |
|
attn_output = torch.matmul(attn_weights, value_states) |
|
|
|
d = value_states.size(-1) |
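        # Side computation for the handler: push the attention probabilities through
        # a constant averaging matrix (every entry 1 / head_dim) in place of the
        # learned value states, and stash the result on `self.attn_vec` further below.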
|
|
|
        # Keep this on the CPU in float32: `torch.full` creates the averaging matrix
        # there by default, and half-precision matmul support on CPU is limited.
        averaging_matrix = torch.full(value_states.shape, 1 / d)
        custom_attn_output = torch.matmul(attn_weights.to("cpu", torch.float32), averaging_matrix)
|
|
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): |
|
raise ValueError( |
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" |
|
f" {attn_output.size()}" |
|
) |
|
|
|
attn_output = attn_output.transpose(1, 2).contiguous() |
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) |
|
|
|
custom_attn_output = custom_attn_output.transpose(1, 2).contiguous() |
|
custom_attn_output = custom_attn_output.reshape(bsz, q_len, self.hidden_size) |
|
self.attn_vec = custom_attn_output |
|
|
|
attn_output = self.o_proj(attn_output) |
|
|
|
if not output_attentions: |
|
attn_weights = None |
|
|
|
return attn_output, attn_weights, past_key_value |
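
# NOTE: nothing in this module registers the subclass above with transformers, so the
# library's own eager attention (without `attn_vec`) is what actually runs unless the
# class is wired in elsewhere. A minimal sketch of one way to do that, assuming a
# transformers version that exposes MISTRAL_ATTENTION_CLASSES in modeling_mistral:
#
#     from transformers.models.mistral import modeling_mistral
#     modeling_mistral.MISTRAL_ATTENTION_CLASSES["eager"] = MistralAttention
#
# This would have to run before `AutoModel.from_pretrained` below.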
|
|
|
|
|
|
|
class EndpointHandler:
|
def __init__(self, model_dir=''): |
|
self.instruction = 'Given a web search query, retrieve relevant passages that answer the query:\n' |
|
self.max_length = 4096 |
|
self.device = "cuda:0" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) |
|
self.tokenizer.pad_token = '[PAD]' |
|
self.tokenizer.padding_side = 'left' |
|
|
|
        # 8-bit quantized loading. (BitsAndBytesConfig exposes a compute dtype only
        # for 4-bit, via `bnb_4bit_compute_dtype`; there is no 8-bit equivalent.)
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
|
|
|
self.model = AutoModel.from_pretrained( |
|
model_dir, |
|
quantization_config=bnb_config, |
|
device_map="auto", |
|
trust_remote_code=True, |
|
attn_implementation="eager", |
|
) |
|
self.model = PeftModel.from_pretrained(self.model, model_dir, subfolder='lora') |
|
self.model.eval() |
|
|
|
|
|
    def last_token_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        # Pool the hidden state at the last attended (non-padding) position of each
        # sequence; with left padding that is simply the final position.
|
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0]) |
|
if left_padding: |
|
return last_hidden_states[:, -1] |
|
else: |
|
sequence_lengths = attention_mask.sum(dim=1) - 1 |
|
batch_size = last_hidden_states.shape[0] |
|
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths] |
|
|
|
|
|
    def tokenize(self, text, text_type):
        # Queries get the retrieval instruction prefixed; documents are tokenized
        # as-is. An EOS token is appended to match the last-token pooling above.
        if text_type == 'query':
            text = self.instruction + text
        return self.tokenizer(text + self.tokenizer.eos_token, max_length=self.max_length, truncation=True, return_tensors='pt').to(self.device)
|
|
|
|
|
    def extract_attn_vec(self, model):
        # Read back the custom attention vector stashed by the attention module of
        # the last decoder layer during the forward pass.
        return model._modules['layers'][-1].self_attn.attn_vec
|
|
|
|
|
    def embed(self, text, text_type):
        tokens = self.tokenize(text, text_type)
        with torch.no_grad():
            output = self.model(input_ids=tokens['input_ids'], attention_mask=tokens['attention_mask']).last_hidden_state.detach()
|
embedding = self.last_token_pool(output, tokens['attention_mask']) |
|
embedding = F.normalize(embedding, p=2, dim=1) |
|
|
|
attn_vec = self.extract_attn_vec(self.model) |
|
attn_vec = self.last_token_pool(attn_vec, tokens['attention_mask']) |
|
attn_vec = F.normalize(attn_vec, p=2, dim=1) |
|
del output, tokens |
|
torch.cuda.empty_cache() |
|
return embedding, attn_vec |
|
|
|
|
|
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: |
|
""" |
|
data args: |
|
inputs (:obj: `str` | `PIL.Image` | `np.array`) |
|
kwargs |
|
Return: |
|
A :obj:`list` | `dict`: will be serialized and returned |
|
""" |
|
print(data) |
|
id = data.pop("id", data) |
|
text = data.pop("text", data) |
|
type = data.pop("type", data) |
|
print(id) |
|
print(text) |
|
print(type) |
|
|
|
embeddings, attn_vec = self.embed(text, type) |
|
embeddings = embeddings[0].tolist() |
|
attn_vec = attn_vec[0].tolist() |
|
|
|
        if text_type == 'query':
            return {"id": id, "embedding": embeddings, "attention_vec": attn_vec}
        elif text_type == 'document':
            return {"id": id, "embedding": embeddings}
        else:
            raise ValueError(f"Unknown type {text_type!r}; expected 'query' or 'document'")